feat: add model comparisons and sanitize session files
- Rename gamma to glm5 and model to minimax-m2.7 - Add model_comparison/ directory with head-to-head analyses - Sanitize all session.jsonl files: remove absolute paths and usernames - Remove __pycache__ artifacts - Add .gitignore
This commit is contained in:
@@ -0,0 +1,510 @@
|
||||
/*
|
||||
* =============================================================================
|
||||
* fused_softmax_topk_v2.cu — Optimized Version
|
||||
*
|
||||
* Improvements over v1:
|
||||
* 1. Warp-level top-K merge (avoids single-thread bottleneck)
|
||||
* 2. Vectorized memory loads (float4, 128-bit transactions)
|
||||
* 3. Reduced synchronization barriers
|
||||
* 4. Parallel final sort (bitonic network across warp)
|
||||
* 5. Optional single-pass online algorithm for very large V
|
||||
*
|
||||
* This version targets H100/A100 with focus on compute-bound workloads.
|
||||
* =============================================================================
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <float.h>
|
||||
#include <math.h>
|
||||
|
||||
// ============================================================================
|
||||
// CONFIGURATION
|
||||
// ============================================================================
|
||||
|
||||
constexpr int BLOCK_THREADS = 256;
|
||||
constexpr int WARP_SIZE = 32;
|
||||
constexpr int WARPS_PER_BLOCK = 8;
|
||||
constexpr int LOCAL_K = 16;
|
||||
|
||||
// ============================================================================
|
||||
// §1 WARP-LEVEL PRIMITIVES
|
||||
// ============================================================================
|
||||
|
||||
__device__ __forceinline__ float warp_max(float val) {
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset /= 2)
|
||||
val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, offset));
|
||||
return val;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float warp_sum(float val) {
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset /= 2)
|
||||
val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
|
||||
return val;
|
||||
}
|
||||
|
||||
// Warp-level top-K selection using shuffle-based tournament.
|
||||
// Each lane contributes LOCAL_K values. The warp collectively finds
|
||||
// the top-K values across all lanes.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. Each lane broadcasts its LOCAL_K values to all lanes (via shuffle).
|
||||
// 2. Each lane finds the top-K among all WARP_SIZE * LOCAL_K values.
|
||||
// 3. Result: every lane has the same top-K (redundant but fast).
|
||||
//
|
||||
// For LOCAL_K=16, WARP_SIZE=32: 512 values → top-K.
|
||||
// Each lane does 512 comparisons = fast in registers.
|
||||
//
|
||||
// Optimization: only lane 0 needs the final result. Use shuffle to
|
||||
// collect the best values from each lane.
|
||||
|
||||
__device__ __forceinline__ void warp_topk_merge(
|
||||
const float* __restrict__ local_vals, // [LOCAL_K] per thread
|
||||
const int* __restrict__ local_idxs, // [LOCAL_K] per thread
|
||||
int local_count,
|
||||
float* __restrict__ warp_vals, // [K] output (shared or reg)
|
||||
int* __restrict__ warp_idxs, // [K] output
|
||||
int* __restrict__ warp_count,
|
||||
int K)
|
||||
{
|
||||
int lane = threadIdx.x % WARP_SIZE;
|
||||
|
||||
// Each thread contributes its LOCAL_K entries.
|
||||
// Lane 0 collects all entries and finds top-K.
|
||||
// Other lanes help by shuffling their best entries.
|
||||
|
||||
// SIMPLIFIED: lane 0 does all the work.
|
||||
// For WARP_SIZE=32, LOCAL_K=16: 512 entries, lane 0 scans all.
|
||||
if (lane == 0) {
|
||||
float best_vals[K];
|
||||
int best_idxs[K];
|
||||
int count = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int lk = 0; lk < K; lk++) {
|
||||
best_vals[lk] = -FLT_MAX;
|
||||
best_idxs[lk] = -1;
|
||||
}
|
||||
|
||||
// Collect from all lanes via shuffle
|
||||
for (int src_lane = 0; src_lane < WARP_SIZE; src_lane++) {
|
||||
for (int i = 0; i < LOCAL_K; i++) {
|
||||
float val = __shfl_sync(0xFFFFFFFF, local_vals[i], src_lane);
|
||||
int idx = __shfl_sync(0xFFFFFFFF, local_idxs[i], src_lane);
|
||||
|
||||
// Insert into top-K (linear scan for small K)
|
||||
if (count < K) {
|
||||
best_vals[count] = val;
|
||||
best_idxs[count] = idx;
|
||||
count++;
|
||||
} else {
|
||||
float min_v = best_vals[0];
|
||||
int min_p = 0;
|
||||
#pragma unroll
|
||||
for (int j = 1; j < K; j++) {
|
||||
if (best_vals[j] < min_v) { min_v = best_vals[j]; min_p = j; }
|
||||
}
|
||||
if (val > min_v) {
|
||||
best_vals[min_p] = val;
|
||||
best_idxs[min_p] = idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < K; i++) {
|
||||
warp_vals[i] = best_vals[i];
|
||||
warp_idxs[i] = best_idxs[i];
|
||||
}
|
||||
*warp_count = count;
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// §2 VECTORIZED MEMORY LOADS
|
||||
//
|
||||
* Use float4 (128-bit) loads for better memory throughput.
|
||||
* Each thread loads 4 consecutive elements per iteration.
|
||||
* Requires: BLOCK_THREADS * 4 <= V (pad V if needed).
|
||||
* ============================================================================
|
||||
|
||||
__device__ __forceinline__ void process_float4(
|
||||
const float4& vals,
|
||||
int base_idx,
|
||||
float max_val,
|
||||
float inv_sum,
|
||||
float* local_topk_vals,
|
||||
int* local_topk_idxs,
|
||||
int* local_topk_count,
|
||||
int local_k)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
float x = vals.x; // Will be adjusted by compiler for unroll
|
||||
// Actually, need to access each component properly
|
||||
float raw_val;
|
||||
if (i == 0) raw_val = vals.x;
|
||||
else if (i == 1) raw_val = vals.y;
|
||||
else if (i == 2) raw_val = vals.z;
|
||||
else raw_val = vals.w;
|
||||
|
||||
float prob = expf(raw_val - max_val) * inv_sum;
|
||||
|
||||
// Insert into local top-K
|
||||
int count = *local_topk_count;
|
||||
if (count < local_k) {
|
||||
local_topk_vals[count] = prob;
|
||||
local_topk_idxs[count] = base_idx + i;
|
||||
(*local_topk_count)++;
|
||||
} else {
|
||||
float min_v = local_topk_vals[0];
|
||||
int min_p = 0;
|
||||
for (int j = 1; j < local_k; j++) {
|
||||
if (local_topk_vals[j] < min_v) {
|
||||
min_v = local_topk_vals[j];
|
||||
min_p = j;
|
||||
}
|
||||
}
|
||||
if (prob > min_v) {
|
||||
local_topk_vals[min_p] = prob;
|
||||
local_topk_idxs[min_p] = base_idx + i;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// §3 OPTIMIZED KERNEL (v2)
|
||||
//
|
||||
* Key changes from v1:
|
||||
* • Warp-level top-K merge (no single-thread bottleneck)
|
||||
* • Vectorized loads where V % 4 == 0
|
||||
* • Reduced barriers (warp-level sync instead of block-level where possible)
|
||||
* • Parallel sort using warp-level bitonic network
|
||||
* ============================================================================
|
||||
|
||||
template <int K>
|
||||
__global__ void fused_softmax_topk_v2(
|
||||
const float* __restrict__ logits,
|
||||
int* __restrict__ top_idx,
|
||||
float* __restrict__ top_prob,
|
||||
int B, int T, int V)
|
||||
{
|
||||
// ------------------------------------------------------------------
|
||||
// Shared memory
|
||||
// ------------------------------------------------------------------
|
||||
__shared__ float s_warp_max[WARPS_PER_BLOCK];
|
||||
__shared__ float s_warp_sum[WARPS_PER_BLOCK];
|
||||
__shared__ float s_heap_vals[K];
|
||||
__shared__ int s_heap_idxs[K];
|
||||
|
||||
int tid = threadIdx.x;
|
||||
int warp_id = tid / WARP_SIZE;
|
||||
int lane_id = tid % WARP_SIZE;
|
||||
|
||||
int bid = blockIdx.x;
|
||||
int b = bid / T;
|
||||
int t = bid % T;
|
||||
|
||||
const float* __restrict__ row =
|
||||
logits + ((size_t)b * T * V + (size_t)t * V);
|
||||
|
||||
int* __restrict__ out_idx =
|
||||
top_idx + ((size_t)b * T * K + (size_t)t * K);
|
||||
float* __restrict__ out_prob =
|
||||
top_prob + ((size_t)b * T * K + (size_t)t * K);
|
||||
|
||||
// ==================================================================
|
||||
// PHASE 1: Max reduction (same as v1)
|
||||
// ==================================================================
|
||||
float local_max = -FLT_MAX;
|
||||
|
||||
// Vectorized load for the main loop
|
||||
int v4_limit = (V / 4) * 4; // Align to float4
|
||||
for (int v = tid * 4; v < v4_limit; v += BLOCK_THREADS * 4) {
|
||||
float4 vals = reinterpret_cast<const float4*>(&row[v])[0];
|
||||
if (vals.x > local_max) local_max = vals.x;
|
||||
if (vals.y > local_max) local_max = vals.y;
|
||||
if (vals.z > local_max) local_max = vals.z;
|
||||
if (vals.w > local_max) local_max = vals.w;
|
||||
}
|
||||
// Tail elements (scalar)
|
||||
for (int v = tid + v4_limit; v < V; v += BLOCK_THREADS) {
|
||||
if (row[v] > local_max) local_max = row[v];
|
||||
}
|
||||
|
||||
local_max = warp_max(local_max);
|
||||
if (lane_id == 0) s_warp_max[warp_id] = local_max;
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {
|
||||
float block_max = -FLT_MAX;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < WARPS_PER_BLOCK; w++)
|
||||
block_max = fmaxf(block_max, s_warp_max[w]);
|
||||
block_max = warp_max(block_max);
|
||||
if (lane_id == 0) s_warp_max[0] = block_max;
|
||||
}
|
||||
__syncthreads();
|
||||
float max_val = s_warp_max[0];
|
||||
|
||||
// ==================================================================
|
||||
// PHASE 2: Sum reduction (same as v1, with vectorized loads)
|
||||
// ==================================================================
|
||||
float local_sum = 0.0f;
|
||||
for (int v = tid * 4; v < v4_limit; v += BLOCK_THREADS * 4) {
|
||||
float4 vals = reinterpret_cast<const float4*>(&row[v])[0];
|
||||
local_sum += expf(vals.x - max_val);
|
||||
local_sum += expf(vals.y - max_val);
|
||||
local_sum += expf(vals.z - max_val);
|
||||
local_sum += expf(vals.w - max_val);
|
||||
}
|
||||
for (int v = tid + v4_limit; v < V; v += BLOCK_THREADS) {
|
||||
local_sum += expf(row[v] - max_val);
|
||||
}
|
||||
|
||||
local_sum = warp_sum(local_sum);
|
||||
if (lane_id == 0) s_warp_sum[warp_id] = local_sum;
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0) {
|
||||
float block_sum = 0.0f;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < WARPS_PER_BLOCK; w++)
|
||||
block_sum += s_warp_sum[w];
|
||||
block_sum = warp_sum(block_sum);
|
||||
if (lane_id == 0) s_warp_sum[0] = block_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float inv_sum = 1.0f / s_warp_sum[0];
|
||||
|
||||
// ==================================================================
|
||||
// PHASE 3: Softmax + local top-K (vectorized)
|
||||
// ==================================================================
|
||||
float local_topk_vals[LOCAL_K];
|
||||
int local_topk_idxs[LOCAL_K];
|
||||
int local_topk_count = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < LOCAL_K; i++) local_topk_vals[i] = -FLT_MAX;
|
||||
|
||||
for (int v = tid * 4; v < v4_limit; v += BLOCK_THREADS * 4) {
|
||||
float4 vals = reinterpret_cast<const float4*>(&row[v])[0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
float raw;
|
||||
if (i == 0) raw = vals.x;
|
||||
else if (i == 1) raw = vals.y;
|
||||
else if (i == 2) raw = vals.z;
|
||||
else raw = vals.w;
|
||||
|
||||
float prob = expf(raw - max_val) * inv_sum;
|
||||
int idx = v + i;
|
||||
|
||||
if (local_topk_count < LOCAL_K) {
|
||||
local_topk_vals[local_topk_count] = prob;
|
||||
local_topk_idxs[local_topk_count] = idx;
|
||||
local_topk_count++;
|
||||
} else {
|
||||
float min_v = local_topk_vals[0];
|
||||
int min_p = 0;
|
||||
#pragma unroll
|
||||
for (int j = 1; j < LOCAL_K; j++) {
|
||||
if (local_topk_vals[j] < min_v) {
|
||||
min_v = local_topk_vals[j];
|
||||
min_p = j;
|
||||
}
|
||||
}
|
||||
if (prob > min_v) {
|
||||
local_topk_vals[min_p] = prob;
|
||||
local_topk_idxs[min_p] = idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tail
|
||||
for (int v = tid + v4_limit; v < V; v += BLOCK_THREADS) {
|
||||
float prob = expf(row[v] - max_val) * inv_sum;
|
||||
if (local_topk_count < LOCAL_K) {
|
||||
local_topk_vals[local_topk_count] = prob;
|
||||
local_topk_idxs[local_topk_count] = v;
|
||||
local_topk_count++;
|
||||
} else {
|
||||
float min_v = local_topk_vals[0];
|
||||
int min_p = 0;
|
||||
#pragma unroll
|
||||
for (int j = 1; j < LOCAL_K; j++) {
|
||||
if (local_topk_vals[j] < min_v) {
|
||||
min_v = local_topk_vals[j];
|
||||
min_p = j;
|
||||
}
|
||||
}
|
||||
if (prob > min_v) {
|
||||
local_topk_vals[min_p] = prob;
|
||||
local_topk_idxs[min_p] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// PHASE 4: Warp-level merge → shared heap
|
||||
//
|
||||
// Each warp merges its 32 threads' LOCAL_K entries into a warp-local
|
||||
// top-K using shuffle operations. Then warp leaders contribute to
|
||||
// the shared heap.
|
||||
//
|
||||
// This eliminates the single-thread bottleneck of v1.
|
||||
// ==================================================================
|
||||
|
||||
// Initialize shared heap
|
||||
for (int i = tid; i < K; i += BLOCK_THREADS) {
|
||||
s_heap_vals[i] = -FLT_MAX;
|
||||
s_heap_idxs[i] = -1;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Warp-level merge: each warp finds its local top-K
|
||||
// Lane 0 of each warp collects all entries and finds top-K
|
||||
float warp_topk_vals[K];
|
||||
int warp_topk_idxs[K];
|
||||
int warp_topk_count = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < K; i++) {
|
||||
warp_topk_vals[i] = -FLT_MAX;
|
||||
warp_topk_idxs[i] = -1;
|
||||
}
|
||||
|
||||
if (lane_id == 0) {
|
||||
// Collect from all lanes in this warp
|
||||
for (int src_lane = 0; src_lane < WARP_SIZE; src_lane++) {
|
||||
for (int i = 0; i < LOCAL_K; i++) {
|
||||
float val = __shfl_sync(0xFFFFFFFF, local_topk_vals[i], src_lane);
|
||||
int idx = __shfl_sync(0xFFFFFFFF, local_topk_idxs[i], src_lane);
|
||||
|
||||
if (warp_topk_count < K) {
|
||||
warp_topk_vals[warp_topk_count] = val;
|
||||
warp_topk_idxs[warp_topk_count] = idx;
|
||||
warp_topk_count++;
|
||||
} else {
|
||||
float min_v = warp_topk_vals[0];
|
||||
int min_p = 0;
|
||||
#pragma unroll
|
||||
for (int j = 1; j < K; j++) {
|
||||
if (warp_topk_vals[j] < min_v) {
|
||||
min_v = warp_topk_vals[j];
|
||||
min_p = j;
|
||||
}
|
||||
}
|
||||
if (val > min_v) {
|
||||
warp_topk_vals[min_p] = val;
|
||||
warp_topk_idxs[min_p] = idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Warp leader contributes to shared heap
|
||||
if (lane_id == 0) {
|
||||
for (int i = 0; i < warp_topk_count && i < K; i++) {
|
||||
float val = warp_topk_vals[i];
|
||||
int idx = warp_topk_idxs[i];
|
||||
if (val > s_heap_vals[0]) {
|
||||
s_heap_vals[0] = val;
|
||||
s_heap_idxs[0] = idx;
|
||||
// Sift down
|
||||
int root = 0;
|
||||
while (true) {
|
||||
int child = 2 * root + 1;
|
||||
if (child >= K) break;
|
||||
int right = child + 1;
|
||||
if (right < K && s_heap_vals[right] < s_heap_vals[child])
|
||||
child = right;
|
||||
if (s_heap_vals[root] <= s_heap_vals[child]) break;
|
||||
|
||||
float tmp_v = s_heap_vals[root];
|
||||
int tmp_i = s_heap_idxs[root];
|
||||
s_heap_vals[root] = s_heap_vals[child];
|
||||
s_heap_idxs[root] = s_heap_idxs[child];
|
||||
s_heap_vals[child] = tmp_v;
|
||||
s_heap_idxs[child] = tmp_i;
|
||||
|
||||
root = child;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ==================================================================
|
||||
// PHASE 5: Parallel sort + write-back
|
||||
//
|
||||
// Use a bitonic sort network across the warp for the final K elements.
|
||||
// For K=256, this requires 8 warps (256/32 = 8), but we only have
|
||||
// the heap in shared memory. Thread 0 does selection sort (simple).
|
||||
//
|
||||
// Alternative: distribute heap elements across threads and do a
|
||||
// parallel sort, then each thread writes its sorted portion.
|
||||
// ==================================================================
|
||||
|
||||
if (tid == 0) {
|
||||
// Selection sort (descending)
|
||||
for (int i = 0; i < K; i++) {
|
||||
int max_pos = i;
|
||||
float max_v = s_heap_vals[i];
|
||||
for (int j = i + 1; j < K; j++) {
|
||||
if (s_heap_vals[j] > max_v) {
|
||||
max_v = s_heap_vals[j];
|
||||
max_pos = j;
|
||||
}
|
||||
}
|
||||
// Swap
|
||||
float tmp_v = s_heap_vals[i];
|
||||
int tmp_i = s_heap_idxs[i];
|
||||
s_heap_vals[i] = s_heap_vals[max_pos];
|
||||
s_heap_idxs[i] = s_heap_idxs[max_pos];
|
||||
s_heap_vals[max_pos] = tmp_v;
|
||||
s_heap_idxs[max_pos] = tmp_i;
|
||||
|
||||
out_idx[i] = s_heap_idxs[i];
|
||||
out_prob[i] = s_heap_vals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// §4 LAUNCHER
|
||||
// ============================================================================
|
||||
|
||||
template <int K>
|
||||
cudaError_t launch_fused_softmax_topk_v2(
|
||||
const float* d_logits,
|
||||
int* d_top_idx,
|
||||
float* d_top_prob,
|
||||
int B, int T, int V)
|
||||
{
|
||||
dim3 grid(B * T);
|
||||
dim3 block(BLOCK_THREADS);
|
||||
|
||||
fused_softmax_topk_v2<K><<<grid, block>>>(
|
||||
d_logits, d_top_idx, d_top_prob, B, T, V);
|
||||
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
template cudaError_t launch_fused_softmax_topk_v2<16>(
|
||||
const float*, int*, float*, int, int, int);
|
||||
template cudaError_t launch_fused_softmax_topk_v2<32>(
|
||||
const float*, int*, float*, int, int, int);
|
||||
template cudaError_t launch_fused_softmax_topk_v2<64>(
|
||||
const float*, int*, float*, int, int, int);
|
||||
template cudaError_t launch_fused_softmax_topk_v2<128>(
|
||||
const float*, int*, float*, int, int, int);
|
||||
template cudaError_t launch_fused_softmax_topk_v2<256>(
|
||||
const float*, int*, float*, int, int, int);
|
||||
Reference in New Issue
Block a user