/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/library/utility/gpu_verification.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/library/utility/gpu_verification.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/3643/include/ck/library/utility/gpu_verification.hpp Source File
gpu_verification.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iomanip>
7 #include <iostream>
8 
11 #include "ck/utility/type.hpp"
15 
16 namespace ck {
17 namespace profiler {
18 
// Result struct for GPU verification with detailed error reporting.
// Provides backward compatibility via operator bool().
struct GpuVerifyResult
{
    // All members carry default initializers so a default-constructed instance is a
    // well-defined "verification passed, nothing compared" state. Previously the
    // members were left uninitialized, so operator bool() on a default-constructed
    // instance read indeterminate values.
    unsigned long long error_count{0}; // Number of elements that exceeded tolerance
    float max_error{0.0f};             // Maximum error value observed
    std::size_t total{0};              // Total number of elements compared
    bool all_zero{false};              // True if device result is all zeros (likely kernel issue)

    // Implicit conversion to bool for backward compatibility.
    // Allows: if (gpu_verify(...)) { ... }
    // True means verification passed (no element exceeded tolerance).
    operator bool() const { return error_count == 0; }

    // Percentage of compared elements that exceeded tolerance, in [0, 100].
    // Returns 0 when no elements were compared (avoids division by zero).
    float error_percentage() const
    {
        if(total == 0)
            return 0.0f;
        return static_cast<float>(error_count) / static_cast<float>(total) * 100.0f;
    }

    // Print error summary to stderr (matches check_err format).
    // Emits nothing when verification passed (error_count == 0).
    void print_error_summary() const
    {
        if(error_count > 0)
        {
            if(all_zero)
            {
                std::cerr << "WARNING: Device result is all zeros - kernel may not have executed "
                             "properly!"
                          << std::endl;
            }
            std::cerr << "max err: " << max_error;
            std::cerr << ", number of errors: " << error_count;
            std::cerr << ", " << std::setprecision(2) << std::fixed << error_percentage()
                      << "% wrong values" << std::endl;
        }
    }
};
58 
59 // Compute relative tolerance for GPU verification
60 // Matches the logic of ck::utils::get_relative_threshold but handles all types
61 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
62 inline float compute_relative_tolerance(const int number_of_accumulations = 1)
63 {
64  using F16 = ck::half_t;
65  using BF16 = ck::bhalf_t;
66  using F32 = float;
67  using I8 = int8_t;
68  using I16 = int16_t;
69  using I32 = int32_t;
70 
71  // For integer types, tolerance is 0
72  if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I16> ||
73  std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>)
74  {
75  return 0.0f;
76  }
77  // For types supported by get_relative_threshold, use it
78  else if constexpr((std::is_same_v<ComputeDataType, F16> ||
79  std::is_same_v<ComputeDataType, BF16> ||
80  std::is_same_v<ComputeDataType, F32>) &&
81  (std::is_same_v<OutDataType, F16> || std::is_same_v<OutDataType, BF16> ||
82  std::is_same_v<OutDataType, F32>) &&
83  (std::is_same_v<AccDataType, F16> || std::is_same_v<AccDataType, BF16> ||
84  std::is_same_v<AccDataType, F32>))
85  {
86  return static_cast<float>(
87  ck::utils::get_relative_threshold<ComputeDataType, OutDataType, AccDataType>(
88  number_of_accumulations));
89  }
90  // For unsupported types (FP8, BF8, etc.), use default tolerances based on output type
91  else
92  {
93  if constexpr(std::is_same_v<OutDataType, F16>)
94  {
95  return 1e-3f;
96  }
97  else if constexpr(std::is_same_v<OutDataType, BF16>)
98  {
99  return 1e-1f;
100  }
101  else
102  {
103  // For FP8/BF8 and other types, use conservative tolerance
104  return 1e-1f;
105  }
106  }
107 }
108 
// Device-side result structure for kernel output.
// Packed into a single struct to minimize device memory allocations.
// Default member initializers give a valid "no errors yet" starting state
// (error_count 0, max_error 0, all_zero assumed true until a nonzero value is
// observed) so that a default-constructed instance no longer holds indeterminate
// values; the host wrapper still memcpys an explicitly initialized copy to the
// device before the kernel runs.
struct GpuVerifyDeviceResult
{
    unsigned long long error_count{0}; // Number of errors found (atomicAdd target)
    float max_error{0.0f};             // Maximum error value (atomicMax target)
    int all_zero{1};                   // 1 = device result is all zeros, 0 = has non-zero values
};
117 
118 // GPU verification kernel - compares device result against reference using relative and absolute
119 // tolerance. Tracks all errors (no early exit) to provide detailed error reporting.
120 //
121 // Uses LDS (shared memory) for block-level reduction to minimize atomic contention.
122 // This reduces atomic operations from O(errors) to O(blocks), providing massive speedup
123 // when there are many errors.
124 //
125 // Assumption: Block size is 256
126 template <typename T>
127 __global__ void gpu_verify_kernel(const T* __restrict__ device_result,
128  const T* __restrict__ reference_result,
129  float rtol,
130  float atol,
131  long long size,
132  GpuVerifyDeviceResult* result)
133 {
134  constexpr int block_size = 256;
135 
136  // Shared memory for block-level reduction
137  __shared__ unsigned long long shared_error_count[block_size];
138  __shared__ float shared_max_error[block_size];
139  __shared__ int shared_has_error[block_size];
140  __shared__ int shared_has_nonzero[block_size];
141 
142  // Thread-local accumulators (in registers)
143  unsigned long long local_error_count = 0;
144  float local_max_error = 0.0f;
145  int local_has_error = 0;
146  int local_has_nonzero = 0;
147 
148  // Grid-stride loop to handle any tensor size
149  long long idx = blockIdx.x * blockDim.x + threadIdx.x;
150  long long stride = blockDim.x * gridDim.x;
151 
152  for(long long i = idx; i < size; i += stride)
153  {
154  // Convert to float for comparison
155  float dev_val = type_convert<float>(device_result[i]);
156  float ref_val = type_convert<float>(reference_result[i]);
157 
158  // Check if device value is non-zero
159  if(dev_val != 0.0f)
160  {
161  local_has_nonzero = 1;
162  }
163 
164  // Compute absolute difference
165  float abs_diff = fabsf(dev_val - ref_val);
166 
167  // Check tolerance (matches CPU check_err logic: err > atol + rtol * abs(ref))
168  if(abs_diff > atol + rtol * fabsf(ref_val))
169  {
170  local_has_error = 1;
171  local_error_count++;
172  local_max_error = fmaxf(local_max_error, abs_diff);
173  }
174  }
175 
176  // Store thread-local results to shared memory
177  shared_error_count[threadIdx.x] = local_error_count;
178  shared_max_error[threadIdx.x] = local_max_error;
179  shared_has_error[threadIdx.x] = local_has_error;
180  shared_has_nonzero[threadIdx.x] = local_has_nonzero;
181  __syncthreads();
182 
183  // Block-level reduction: 256 -> 128 -> 64 -> 32
184  for(unsigned int s = block_size / 2; s >= 32; s >>= 1)
185  {
186  if(threadIdx.x < s)
187  {
188  shared_error_count[threadIdx.x] += shared_error_count[threadIdx.x + s];
189  shared_max_error[threadIdx.x] =
190  fmaxf(shared_max_error[threadIdx.x], shared_max_error[threadIdx.x + s]);
191  shared_has_error[threadIdx.x] |= shared_has_error[threadIdx.x + s];
192  shared_has_nonzero[threadIdx.x] |= shared_has_nonzero[threadIdx.x + s];
193  }
194  __syncthreads();
195  }
196 
197  // Final reduction of remaining 32 elements in thread 0
198  if(threadIdx.x == 0)
199  {
200  for(int i = 1; i < 32; ++i)
201  {
202  shared_error_count[0] += shared_error_count[i];
203  shared_max_error[0] = fmaxf(shared_max_error[0], shared_max_error[i]);
204  shared_has_error[0] |= shared_has_error[i];
205  shared_has_nonzero[0] |= shared_has_nonzero[i];
206  }
207 
208  // Single atomic update per block (reduces contention from O(errors) to O(blocks))
209  if(shared_has_error[0])
210  {
211  atomicAdd(&result->error_count, shared_error_count[0]);
212  atomicMax(&result->max_error, shared_max_error[0]);
213  }
214  // Update all_zero flag: if no nonzero values found, mark as all zero
215  if(!shared_has_nonzero[0])
216  {
217  atomicMin(&result->all_zero, 1);
218  }
219  else
220  {
221  atomicMin(&result->all_zero, 0);
222  }
223  }
224 }
225 
226 // Host-side wrapper for GPU verification with explicit tolerances
227 // Returns GpuVerifyResult with detailed error information
228 template <typename T>
229 GpuVerifyResult gpu_verify(const void* device_result,
230  const void* reference_result,
231  float rtol,
232  float atol,
233  std::size_t size,
234  hipStream_t stream = nullptr)
235 {
236  // Allocate result buffer on device
237  GpuVerifyDeviceResult* result_dev;
238  hip_check_error(hipMalloc(&result_dev, sizeof(GpuVerifyDeviceResult)));
239 
240  // Initialize result struct
241  GpuVerifyDeviceResult result_host;
242  result_host.error_count = 0; // No errors yet
243  result_host.max_error = 0.0f; // No error observed
244  result_host.all_zero = 1; // Start assuming all zeros (will be cleared if nonzero found)
246  hipMemcpy(result_dev, &result_host, sizeof(GpuVerifyDeviceResult), hipMemcpyHostToDevice));
247 
248  // Launch kernel with grid-stride loop
249  // Use 65535 as max grid size (hardware limit for grid dimension in x)
250  // Grid-stride loop handles any tensor size regardless of grid dimensions
251  constexpr int block_size = 256;
252  int grid_size = std::min<int>(65535, (size + block_size - 1) / block_size);
253 
254  gpu_verify_kernel<T>
255  <<<grid_size, block_size, 0, stream>>>(static_cast<const T*>(device_result),
256  static_cast<const T*>(reference_result),
257  rtol,
258  atol,
259  static_cast<long long>(size),
260  result_dev);
261 
262  hip_check_error(hipGetLastError());
263 
264  // Synchronize the stream to ensure kernel completion before reading results
265  hip_check_error(hipStreamSynchronize(stream));
266 
267  // Get result
269  hipMemcpy(&result_host, result_dev, sizeof(GpuVerifyDeviceResult), hipMemcpyDeviceToHost));
270 
271  // Free device memory
272  hip_check_error(hipFree(result_dev));
273 
274  // Build and return result struct
275  GpuVerifyResult result;
276  result.error_count = result_host.error_count;
277  result.max_error = result_host.max_error;
278  result.total = size;
279  result.all_zero = (result_host.all_zero == 1);
280 
281  return result;
282 }
283 
284 // Forward declaration of gpu_reduce_max
285 template <typename T>
286 float gpu_reduce_max(const void* device_buffer, std::size_t size, hipStream_t stream = nullptr);
287 
288 // Host-side wrapper for GPU verification with automatic tolerance computation
289 // Computes max value on GPU, then computes tolerances and verifies
290 // Returns GpuVerifyResult with detailed error information
291 template <typename OutDataType,
292  typename ComputeDataType = OutDataType,
293  typename AccDataType = ComputeDataType>
294 GpuVerifyResult gpu_verify(const void* device_result,
295  const void* reference_result,
296  int number_of_accumulations,
297  std::size_t size,
298  hipStream_t stream = nullptr)
299 {
300  // Compute max absolute value on GPU (only 4 bytes transferred!)
301  double max_abs_value =
302  static_cast<double>(gpu_reduce_max<OutDataType>(reference_result, size, stream));
303 
304  // Compute tolerances based on data types and accumulation count
305  float rtol = compute_relative_tolerance<ComputeDataType, OutDataType, AccDataType>(
306  number_of_accumulations);
307 
308  float atol = 0.0f;
309  // Only compute absolute tolerance for supported types
310  using F16 = ck::half_t;
311  using BF16 = ck::bhalf_t;
312  using F32 = float;
313 
314  if constexpr((std::is_same_v<ComputeDataType, F16> || std::is_same_v<ComputeDataType, BF16> ||
315  std::is_same_v<ComputeDataType, F32>) &&
316  (std::is_same_v<OutDataType, F16> || std::is_same_v<OutDataType, BF16> ||
317  std::is_same_v<OutDataType, F32>) &&
318  (std::is_same_v<AccDataType, F16> || std::is_same_v<AccDataType, BF16> ||
319  std::is_same_v<AccDataType, F32>))
320  {
321  atol = static_cast<float>(
322  ck::utils::get_absolute_threshold<ComputeDataType, OutDataType, AccDataType>(
323  max_abs_value, number_of_accumulations));
324  }
325 
326  // Call the explicit tolerance version
327  return gpu_verify<OutDataType>(device_result, reference_result, rtol, atol, size, stream);
328 }
329 
// GPU reduction kernel for computing max(abs(data))
// This is an internal kernel called only by gpu_reduce_max() wrapper.
//
// Each thread scans the input with a grid-stride loop, keeping a running
// max(|value|) in a register; the block then reduces through shared memory and
// thread 0 publishes a single atomicMax per block into *max_val.
// *max_val must be pre-initialized before launch (the host wrapper sets it to 0).
//
// Assumption: Block size is 256
template <typename T>
__global__ void
gpu_reduce_max_kernel(const T* __restrict__ data, long long size, float* __restrict__ max_val)
{
    constexpr int block_size = 256;
    __shared__ float shared_max[block_size];

    long long idx = blockIdx.x * blockDim.x + threadIdx.x;
    long long stride = blockDim.x * gridDim.x;

    // Thread-local running maximum of |data[i]|; stays 0 if this thread sees no elements
    float local_max = 0.0f;

    for(long long i = idx; i < size; i += stride)
    {
        float val = fabsf(type_convert<float>(data[i]));
        local_max = fmaxf(local_max, val);
    }

    // Publish the per-thread maximum so the whole block can reduce it
    shared_max[threadIdx.x] = local_max;
    __syncthreads();

    // Block-level reduction: 256 -> 128 -> 64 -> 32
    // (__syncthreads() after every step keeps reads and writes ordered)
    for(unsigned int s = block_size / 2; s >= 32; s >>= 1)
    {
        if(threadIdx.x < s)
        {
            shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]);
        }
        __syncthreads();
    }

    // Final reduction of remaining 32 elements in thread 0
    if(threadIdx.x == 0)
    {
        for(int i = 1; i < 32; ++i)
        {
            shared_max[0] = fmaxf(shared_max[0], shared_max[i]);
        }

        // Single atomic update per block
        atomicMax(max_val, shared_max[0]);
    }
}
377 
378 // Host-side wrapper for GPU max reduction
379 // Computes max(abs(data)) and returns as float
380 // Only transfers 4 bytes (the final max value) instead of entire tensor
381 template <typename T>
382 float gpu_reduce_max(const void* device_buffer, std::size_t size, hipStream_t stream)
383 {
384  if(size == 0)
385  {
386  return 0.0f;
387  }
388 
389  // Allocate device memory for result
390  float* max_dev;
391  hip_check_error(hipMalloc(&max_dev, sizeof(float)));
392 
393  // Initialize to zero
394  float init_val = 0.0f;
395  hip_check_error(hipMemcpy(max_dev, &init_val, sizeof(float), hipMemcpyHostToDevice));
396 
397  // Launch reduction kernel
398  // Use 1024 blocks max for reduction to balance occupancy vs. grid-stride iterations
399  // For very large tensors (>256M elements), grid-stride loop handles the remainder
400  constexpr int block_size = 256;
401  int grid_size = std::min<int>(1024, (size + block_size - 1) / block_size);
402 
403  gpu_reduce_max_kernel<T><<<grid_size, block_size, 0, stream>>>(
404  static_cast<const T*>(device_buffer), static_cast<long long>(size), max_dev);
405 
406  hip_check_error(hipGetLastError());
407 
408  // Synchronize if using default stream
409  if(stream == nullptr)
410  {
411  hip_check_error(hipDeviceSynchronize());
412  }
413 
414  // Copy result to host (only 4 bytes!)
415  float max_host;
416  hip_check_error(hipMemcpy(&max_host, max_dev, sizeof(float), hipMemcpyDeviceToHost));
417 
418  // Free device memory
419  hip_check_error(hipFree(max_dev));
420 
421  return max_host;
422 }
423 
424 } // namespace profiler
425 } // namespace ck
__global__ void gpu_verify_kernel(const T *__restrict__ device_result, const T *__restrict__ reference_result, float rtol, float atol, long long size, GpuVerifyDeviceResult *result)
Definition: gpu_verification.hpp:127
float gpu_reduce_max(const void *device_buffer, std::size_t size, hipStream_t stream=nullptr)
Definition: gpu_verification.hpp:382
float compute_relative_tolerance(const int number_of_accumulations=1)
Definition: gpu_verification.hpp:62
__global__ void gpu_reduce_max_kernel(const T *__restrict__ data, long long size, float *__restrict__ max_val)
Definition: gpu_verification.hpp:336
GpuVerifyResult gpu_verify(const void *device_result, const void *reference_result, float rtol, float atol, std::size_t size, hipStream_t stream=nullptr)
Definition: gpu_verification.hpp:229
float F32
32-bit floating point (single precision) type
Definition: check_err.hpp:33
ck_tile::bf16_t BF16
16-bit brain floating point type
Definition: check_err.hpp:31
ck_tile::half_t F16
16-bit floating point (half precision) type
Definition: check_err.hpp:29
int32_t I32
32-bit signed integer type
Definition: check_err.hpp:37
int8_t I8
8-bit signed integer type
Definition: check_err.hpp:35
Definition: ck.hpp:270
_Float16 half_t
Definition: data_type.hpp:31
ushort bhalf_t
Definition: data_type.hpp:30
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:12
signed short int16_t
Definition: stdint.h:122
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: gpu_verification.hpp:112
int all_zero
Definition: gpu_verification.hpp:115
unsigned long long error_count
Definition: gpu_verification.hpp:113
float max_error
Definition: gpu_verification.hpp:114
Definition: gpu_verification.hpp:22
std::size_t total
Definition: gpu_verification.hpp:25
float max_error
Definition: gpu_verification.hpp:24
bool all_zero
Definition: gpu_verification.hpp:26
unsigned long long error_count
Definition: gpu_verification.hpp:23
void print_error_summary() const
Definition: gpu_verification.hpp:41
float error_percentage() const
Definition: gpu_verification.hpp:33