Modulo to the Rescue


One of the problems from KernelBench that Tensara poses as a challenge is computing the Argmax of a tensor. Simply put, we pick a dimension d, find the index of the maximum element along d, and reduce the tensor along that dimension, replacing each run of values along d with that single index. Sounds confusing? Here is an example to help you:

>>> A
tensor([[[35, 21, 24, 11],
         [ 8, 48, 39, 48],
         [10, 42, 20, 16]],

        [[16, 37,  6, 46],
         [40, 44, 44, 47],
         [31,  7, 40, 10]]])
>>> A.argmax(dim=0)
tensor([[0, 1, 0, 1],
        [1, 0, 1, 0],
        [1, 0, 1, 0]])
>>> A.argmax(dim=1)
tensor([[0, 1, 1, 1],
        [1, 1, 1, 1]])
>>> A.argmax(dim=2)
tensor([[0, 1, 1],
        [3, 3, 2]])

In each of the argmax calculations, we went from a 3-D to a 2-D tensor, replacing every group of values along the chosen dimension with the index of its largest element.

[Figures: Argmax image1, Argmax image2]

This is a visualization of a 3-D tensor. Dim 0 picks one of the two 2-D matrices, moving along Dim 1 walks down a column, and moving along Dim 2 walks across a row.

To calculate Argmax across Dim 1 (down each column), we find the largest element among A[n][0][j], A[n][1][j] and A[n][2][j], where n is in [0..1] and j is in [0..3].
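Before going parallel, it helps to have a plain serial version to check against. The sketch below is a minimal CPU version of argmax along dim 1 for a contiguous (N, R, C) tensor. The function name argmax_dim1 and its parameters are hypothetical and are not part of Tensara or of the kernel later in this post. The function walks down each column and keeps the first index of the maximum. This matches how the REPL output above breaks the tie in the row [8, 48, 39, 48].

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: serial argmax over dim 1 of a contiguous row-major
// tensor with shape (N, R, C). For every (n, j) it compares
// A[n][0][j], A[n][1][j], ..., A[n][R-1][j] and keeps the first index of the max.
std::vector<int> argmax_dim1(const std::vector<float>& a, int N, int R, int C) {
    assert(static_cast<int>(a.size()) == N * R * C);
    std::vector<int> out(N * C);  // output shape is (N, C)
    for (int n = 0; n < N; n++) {
        for (int j = 0; j < C; j++) {
            int best = 0;
            for (int i = 1; i < R; i++) {
                // A[n][i][j] lives at flat offset n*R*C + i*C + j
                if (a[n * R * C + i * C + j] > a[n * R * C + best * C + j]) {
                    best = i;  // strict '>' keeps the first maximum on ties
                }
            }
            out[n * C + j] = best;
        }
    }
    return out;
}
```

Called on A flattened in row-major order with N=2, R=3 and C=4, it returns the dim=1 result shown earlier: [0, 1, 1, 1] followed by [1, 1, 1, 1].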

So how can we parallelize argmax with CUDA?

Note that my kernel might not be the most efficient implementation, but I still found the algorithm interesting enough to write a blog post about it.

To parallelize effectively, let’s figure out how many threads we must launch and which memory addresses each thread must read and update to calculate the argmax.

From the image above, what happens if we choose dim=1 to calculate Argmax? Since we calculate one argmax per column and there are 4 columns in each of A[0] and A[1], we calculate 2x4 = 8 values, giving an output of shape (2, 4).

Now let’s choose dim=2, that is, one argmax per row. With 3 rows in each of A[0] and A[1], we calculate 2x3 = 6 values, giving an output of shape (2, 3).

With dim=0, we compare A[0][0][0] with A[1][0][0], A[0][0][1] with A[1][0][1], and so on, and thus end up calculating 3x4 = 12 values, giving an output of shape (3, 4).

Notice a pattern? For argmax along dim d (where d is 0, 1 or 2), we simply remove shape[d] from the shape and multiply the remaining sizes. For A, with shape = (2, 3, 4), the number of outputs for argmax along dim=0 is shape[1]*shape[2], for dim=1 it is shape[0]*shape[2], and for dim=2 it is shape[0]*shape[1].
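The counting rule fits in a tiny host-side helper. This is a sketch: num_outputs is a hypothetical name, and it assumes the shape is an ordinary host array, here a std::vector.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: number of values argmax produces along `dim`,
// i.e. the product of every entry of `shape` except shape[dim].
// This is also the number of CUDA threads to launch (one per output value).
int num_outputs(const std::vector<int>& shape, int dim) {
    assert(dim >= 0 && dim < static_cast<int>(shape.size()));
    int count = 1;
    for (int i = 0; i < static_cast<int>(shape.size()); i++) {
        if (i != dim) {
            count *= shape[i];  // skip the dimension being reduced
        }
    }
    return count;
}
```

For the (2, 3, 4) tensor A, it gives 12, 8 and 6 for dim=0, 1 and 2. For the 5-D example used later in this post, it gives 180.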

Great, now that we know how many values to calculate, we can launch one CUDA thread per output value. Which input locations must each thread compare to find its argmax?

For argmax along a specific dimension of a tensor, we’re essentially:

  1. Breaking the tensor into 1-D “slices” that run along our target dimension, one slice per output value
  2. Processing each slice independently

I found the algorithm hard to understand without a concrete example, so let’s start with an example and generalize it into an algorithm.

Let’s apply our algorithm to a 5-D tensor with shape [2, 3, 4, 5, 6] and compute the argmax over dimension 2.

The first thing we do is calculate the stride of each dimension, starting from the last one (dim=4). Since the data is laid out in memory as a 1-D array, the stride of dim d tells us how many elements we jump when we increment the index in dim d by 1. For dim=4, stride[4] = 1, as every step moves us to the next element (think of it as moving from column 1 -> 2 in the same row).

For dim=3, stride[3] = shape[4] = 6 (think of this as moving from one row to the next: we skip over the number of elements per row).

For dim=2, stride[2] = shape[3]*stride[3] = 5*6 = 30 (think of this as jumping from one 2-D matrix to the next).

Generalizing, we start from stride[n-1] = 1 and calculate stride[i] = shape[i+1]*stride[i+1] going from i = n-2 down to 0. For our example, this gives strides = [360, 120, 30, 6, 1].
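Here is the stride recurrence as a small C++ sketch. It assumes a contiguous row-major tensor whose shape is held in a host std::vector. compute_strides is a hypothetical helper name; the full program at the end does the same thing with malloc'd arrays.

```cpp
#include <cassert>
#include <vector>

// Strides of a contiguous row-major tensor:
// stride[n-1] = 1, then stride[i] = shape[i+1] * stride[i+1] for i = n-2 down to 0.
std::vector<int> compute_strides(const std::vector<int>& shape) {
    int n = static_cast<int>(shape.size());
    assert(n > 0);
    std::vector<int> strides(n);
    strides[n - 1] = 1;  // stepping along the last dim moves to the adjacent element
    for (int i = n - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }
    return strides;
}
```

For shape [2, 3, 4, 5, 6], it returns [360, 120, 30, 6, 1]. For the (2, 3, 4) tensor A, it returns [12, 4, 1].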

Next up, we calculate weights. The weights are used later to map the 1-D index of a CUDA thread to the corresponding multi-dimensional index in the input tensor. Calculating the weights is similar to calculating the strides, except that we skip over the dimension along which we are calculating the Argmax.

    // dim is the dimension we are calculating argmax over
    // Calculate weights for dimension mapping
    int* h_weights = new int[ndim];
    int running_weight = 1;
    
    // Process from right to left, skipping the target dimension
    for (int i = ndim-1; i >= 0; i--) {
        if (i == dim) {
            h_weights[i] = 0;  // Not used for the target dimension
            continue;
        }
        
        h_weights[i] = running_weight;
        running_weight *= h_shape[i];
    }

For Argmax(dim=2), our h_weights array looks like this: [90, 30, 0, 6, 1].

Why is the h_weights array calculated this way, especially with h_weights[dim] = 0? Each CUDA thread finds the max of all shape[dim] values along dim. Inside the kernel, we plug in values of i ranging from 0 to shape[dim]-1 in a loop to calculate the index of each element that the thread compares. The thread index itself therefore never needs to encode a position along dim.

Also, note that there are fewer threads than elements, since we “remove” shape[dim] before calculating threads = prod(shape). The weights array therefore helps each thread calculate the right multi-dimensional index in the space where dim has been removed from the shape.

Let us continue using the example above:

Total elements: 2 × 3 × 4 × 5 × 6 = 720

Elements along dimension 2: 4

Number of slices: 720 ÷ 4 = 180

Let’s map thread ID 42 to 4-D indices (excluding dimension 2). Recall that h_weights = [90, 30, 0, 6, 1].

We are launching a 1-D grid of 1-D thread blocks, so each thread has a linear index. From it, we derive the index in each dimension (stored in the indices array) to help us iterate over the input array.
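The host-side arithmetic for this launch configuration is sketched below. The helper names are hypothetical. In a real kernel, the linear index comes from the built-ins blockDim.x, blockIdx.x and threadIdx.x; the second function only imitates that with plain integers so it can run on the CPU.

```cpp
#include <cassert>

// Hypothetical helper: blocks needed for a 1-D grid of 1-D blocks.
// Rounds up so every slice gets a thread; surplus threads in the last block exit early.
int num_blocks(int slices, int threads_per_block) {
    assert(slices > 0 && threads_per_block > 0);
    return (slices + threads_per_block - 1) / threads_per_block;
}

// Hypothetical CPU stand-in for the kernel's
//   int tid = blockDim.x * blockIdx.x + threadIdx.x;
int linear_thread_id(int block_dim, int block_idx, int thread_idx) {
    return block_dim * block_idx + thread_idx;
}
```

With 180 slices and 256 threads per block, a single block is enough. The 76 threads with tid >= 180 return immediately because of the tid >= slices check in the kernel.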

We start with the thread’s linear index as the remainder and, for each dimension other than dim, calculate its position by:

  1. Dividing by the “weight” of each dimension
  2. Taking the remainder for the next division

For dim 0: 42 ÷ 90 = 0 remainder 42

For dim 1: 42 ÷ 30 = 1 remainder 12

For dim 3: 12 ÷ 6 = 2 remainder 0

For dim 4: 0 ÷ 1 = 0 remainder 0

This is exactly like converting between number systems, where each position has a different “base” (the size of the dimension).
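The division-and-remainder walk can be sketched on the CPU as follows. decode_thread_index is a hypothetical name. It assumes the weights were computed as above, with weights[dim] = 0, and it leaves indices[dim] at 0 because the argmax loop fills that slot in later.

```cpp
#include <cassert>
#include <vector>

// Mixed-radix decode of a linear thread index into per-dimension indices:
// walk the dims left to right, skip `dim`, divide by that dim's weight,
// and carry the remainder forward. indices[dim] stays 0.
std::vector<int> decode_thread_index(int tid, const std::vector<int>& weights, int dim) {
    assert(dim >= 0 && dim < static_cast<int>(weights.size()));
    std::vector<int> indices(weights.size(), 0);
    int rem = tid;
    for (int i = 0; i < static_cast<int>(weights.size()); i++) {
        if (i == dim) continue;  // weights[dim] is 0, so never divide by it
        indices[i] = rem / weights[i];
        rem %= weights[i];
    }
    return indices;
}
```

For thread 42 with weights [90, 30, 0, 6, 1] and dim=2, it returns [0, 1, 0, 2, 0], matching the hand calculation. The last thread, 179, maps to [1, 2, 0, 4, 5], which is the largest valid index in every dimension except dim.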

From here, it is easy to calculate the 1-D index of each element that a thread must access and compare to calculate the Argmax.

    // Find the maximum value along the specified dimension
    float max_val = -INFINITY;
    int max_idx = -1;
    
    // Iterate through all positions along the target dimension
    for (int i = 0; i < shape[dim]; i++) {
        // Set the index for the dimension we're iterating over
        indices[dim] = i;
        
        // Calculate the flat index into the input array
        int flat_idx = 0;
        for (int j = 0; j < ndim; j++) {
            flat_idx += indices[j] * strides[j];
        }
        
        // Check if this value is larger than our current max
        if (input[flat_idx] > max_val) {
            max_val = input[flat_idx];
            max_idx = i;
        }
    }

I tried to write a paragraph explaining how this loop works, but I will admit that I cannot describe it in words very well. The best explanation is something like this: if you are comparing along dim=2, then consecutive elements you have to compare are shape[3]*shape[4] elements apart from one another (that is, stride[2]). We used the generic formula position = i₁ × (D₂ × D₃ × ... × Dₙ) + i₂ × (D₃ × ... × Dₙ) + ... + iₙ. Here iₖ is the index along dimension k, taken from the indices array we calculated for each thread, and Dₖ is the size of dimension k. Each bracketed product is just the stride of that dimension. I urge you to sit down with pen and paper (and Claude or any other LLM if it helps) and work through this, as drawing shapes on paper is worth 10,000 or more digital characters.
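To make the formula concrete, the sketch below lists the flat offsets that a single thread compares. slice_offsets is a hypothetical helper. It assumes the strides of a contiguous row-major tensor and an indices array produced by the decoding step. It computes the base offset once and then steps by strides[dim], which is the same approach the full program uses.

```cpp
#include <cassert>
#include <vector>

// Flat offsets of the shape[dim] elements one thread compares:
//   base   = sum over j != dim of indices[j] * strides[j]
//   offset = base + i * strides[dim], for i = 0 .. shape[dim]-1
std::vector<int> slice_offsets(const std::vector<int>& indices,
                               const std::vector<int>& strides,
                               const std::vector<int>& shape, int dim) {
    assert(indices.size() == strides.size() && strides.size() == shape.size());
    assert(dim >= 0 && dim < static_cast<int>(shape.size()));
    int base = 0;
    for (int j = 0; j < static_cast<int>(indices.size()); j++) {
        if (j != dim) base += indices[j] * strides[j];
    }
    std::vector<int> offsets;
    for (int i = 0; i < shape[dim]; i++) {
        offsets.push_back(base + i * strides[dim]);  // step along the reduced dim
    }
    return offsets;
}
```

Consider thread 42 in the running example, with indices [0, 1, 0, 2, 0] and strides [360, 120, 30, 6, 1]. The function returns [132, 162, 192, 222]. The base offset is 1*120 + 2*6 = 132, and each candidate is stride[2] = 30 elements after the previous one.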

Storing the output is straightforward. Since the output has prod(shape)/shape[dim] values, each thread writes exactly one value after iterating over shape[dim] values. So the kernel ends like this:

    // Iterate over the dimension we're reducing
    for (int i = 0; i < shape[dim]; i++) {
        ...
    }

    output[tid] = max_idx;
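A serial reference is handy for checking GPU output on small inputs. The sketch below is a hypothetical CPU port that performs the same per-thread steps, one tid at a time, for a contiguous row-major tensor of any rank. argmax_reference is not part of the original program. It starts each scan from the slice's first element rather than from -INFINITY, which gives the same indices as the kernel for inputs without NaNs.

```cpp
#include <cassert>
#include <vector>

// Hypothetical serial reference: the kernel's per-thread logic, one tid at a time.
// Returns one index per slice, in the same (tid) order the kernel writes them.
std::vector<int> argmax_reference(const std::vector<float>& input,
                                  const std::vector<int>& shape, int dim) {
    int ndim = static_cast<int>(shape.size());
    assert(dim >= 0 && dim < ndim);

    // Strides of the row-major input (stride[ndim-1] = 1)
    std::vector<int> strides(ndim, 1);
    for (int i = ndim - 2; i >= 0; i--) strides[i] = strides[i + 1] * shape[i + 1];

    // Weights: same recurrence with dim skipped (weights[dim] stays 0).
    // The running product ends up equal to the number of slices.
    std::vector<int> weights(ndim, 0);
    int slices = 1;
    for (int i = ndim - 1; i >= 0; i--) {
        if (i == dim) continue;
        weights[i] = slices;
        slices *= shape[i];
    }

    std::vector<int> output(slices);
    for (int tid = 0; tid < slices; tid++) {  // on the GPU, each tid is its own thread
        // Decode tid (mixed radix) straight into the slice's base offset
        int rem = tid, base = 0;
        for (int i = 0; i < ndim; i++) {
            if (i == dim) continue;
            base += (rem / weights[i]) * strides[i];
            rem %= weights[i];
        }
        // Scan the slice; strict '>' keeps the first maximum on ties
        int max_idx = 0;
        for (int i = 1; i < shape[dim]; i++) {
            if (input[base + i * strides[dim]] > input[base + max_idx * strides[dim]]) {
                max_idx = i;
            }
        }
        output[tid] = max_idx;
    }
    return output;
}
```

On A from the start of the post, it reproduces all three PyTorch results. On the 0..23 ramp used in the program's main(), it returns eight 2s for dim=1.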

Conclusion

This was a fun kernel exercise for me. It took me a few days to figure out, including a stupid size calculation bug that gave incorrect outputs for a long while before I fixed it. I also sat on writing this article for a long time due to other commitments. I am pasting the entire program in this post (as the leetgpu snippet sharing feature isn’t working yet). If you find any bugs or an improved algorithm, please do reach out to me, as I am always on the lookout for advice from CUDA kernel programming masters.

#include <cuda_runtime.h>
#include <math.h>   // INFINITY
#include <stdio.h>

#define MAX_DIM 16
#define TOTAL_SIZE (2*3*4)
__global__ void argmax(const float* input, int dim, int* output, int* shape, int ndim, int* strides, int* weights, int slices) {
   int tid = blockDim.x * blockIdx.x + threadIdx.x;

   if (tid >= slices) { return; }
   
   // Initialize indices for this particular slice
   int indices[MAX_DIM];
   for (int i = 0; i < ndim; i++) {
      indices[i] = 0;
   }
   
   // Convert linear index to multi-dimensional indices
   int rem = tid;
   for (int i = 0; i < ndim; i++) {  // Note: forward order
      if (i != dim) {
         int idx_i = rem / weights[i];
         indices[i] = idx_i;
         rem %= weights[i];
      }
   }

   
   // Find the maximum value along the specified dimension
   float max_val = -INFINITY;
   int max_idx = 0;  // Default to first element

   // Calculate the base flat index for this slice (excluding the dimension we're maxing over)
   int base_idx = 0;
   for (int j = 0; j < ndim; j++) {
      if (j != dim) {
         base_idx += indices[j] * strides[j];
      }
   }
   
   // Iterate over the dimension we're reducing
   for (int i = 0; i < shape[dim]; i++) {
      // Calculate flat index - base index plus offset in the dimension we're iterating
      int flat_idx = base_idx + i * strides[dim];

      // Find maximum
      if (input[flat_idx] > max_val) {
         max_val = input[flat_idx];
         max_idx = i;  // Store the index along dim dimension
      }
   }

   output[tid] = max_idx;
}

// Note: input (float32), output (int32) and shape (ndim int32 values) are all device pointers
void solution(const float* input, int dim, int* output, int* shape, int ndim) {
   // The kernel stores per-thread indices in a fixed-size array of MAX_DIM entries
   if (ndim > MAX_DIM || dim < 0 || dim >= ndim) {
      fprintf(stderr, "argmax: unsupported ndim=%d or dim=%d\n", ndim, dim);
      return;
   }

   // shape lives on the device, so copy it to the host before reading it here
   int* h_shape = (int*) malloc(ndim * sizeof(int));
   cudaMemcpy(h_shape, shape, ndim * sizeof(int), cudaMemcpyDeviceToHost);

   // Calculate strides (how many elements to move for each dimension)
   int* h_strides = (int*) malloc(ndim * sizeof(int));
   h_strides[ndim-1] = 1;
   for (int i = ndim-2; i >= 0; i--) {
      h_strides[i] = h_strides[i+1] * h_shape[i+1];
   }

   // Calculate total number of slices (elements in the output)
   int slices = 1;
   for (int i = 0; i < ndim; i++) {
      if (i != dim) {
         slices *= h_shape[i];
      }
   }

   // Weights map a linear thread index to indices in the shape with dim removed
   int* h_weights = (int*) malloc(ndim * sizeof(int));
   for (int i = 0; i < ndim; i++) {
      h_weights[i] = 0;
   }
   int weight = 1;
   for (int i = ndim-1; i >= 0; i--) {
      if (i != dim) {
         h_weights[i] = weight;
         weight *= h_shape[i];
      }
   }

   // Allocate and copy device memory
   int *d_strides = 0;
   int *d_weights = 0;
   cudaMalloc((void**)&d_strides, sizeof(int) * ndim);
   cudaMemcpy(d_strides, h_strides, sizeof(int) * ndim, cudaMemcpyHostToDevice);
   cudaMalloc((void**)&d_weights, sizeof(int) * ndim);
   cudaMemcpy(d_weights, h_weights, sizeof(int) * ndim, cudaMemcpyHostToDevice);

   // Launch kernel (the kernel reads shape[dim] through the device pointer)
   int threadsPerBlock = 256;
   int numBlocks = (slices + threadsPerBlock - 1) / threadsPerBlock;
   argmax<<<numBlocks, threadsPerBlock>>>(input, dim, output, shape, ndim, d_strides, d_weights, slices);

   cudaDeviceSynchronize();

   // Free memory
   cudaFree(d_strides);
   cudaFree(d_weights);
   free(h_shape);
   free(h_strides);
   free(h_weights);
}


int main() {
        float* h_input = (float*)malloc(sizeof(float)*TOTAL_SIZE);
        for(int i=0;i<TOTAL_SIZE;i++) {
          h_input[i] = (float)i;
        }
        int h_shapes[] = {2, 3, 4};
        float* d_input = 0;
        int* d_output = 0;
        int* d_shapes = 0;
        int items = 2*4;  // argmax over dim=1 of a (2, 3, 4) tensor has shape (2, 4)
        cudaMalloc((void**)&d_input, sizeof(float)*TOTAL_SIZE);
        cudaMemcpy(d_input, h_input, sizeof(float)*TOTAL_SIZE, cudaMemcpyHostToDevice);

        cudaMalloc((void**)&d_output, sizeof(int)*items);  // output holds int indices
        cudaMalloc((void**)&d_shapes, sizeof(h_shapes));
        cudaMemcpy(d_shapes, h_shapes, sizeof(h_shapes), cudaMemcpyHostToDevice);
        solution(d_input, 1, d_output, d_shapes, 3);
        int* output = (int *)malloc(sizeof(int)*items);
        cudaMemcpy(output, d_output, sizeof(int)*items, cudaMemcpyDeviceToHost);
        // Input is 0..23, so every column's max is in the last row: expect 2 for all 8 outputs
        for(int i=0;i<items;i++) {
            printf("%d ", output[i]);
        }
        printf("\n");

        cudaFree(d_input);
        cudaFree(d_output);
        cudaFree(d_shapes);
        free(h_input);
        free(output);
        return 0;
}