MaxSim

A fast, memory-efficient exact MaxSim kernel for late-interaction retrieval and reranking (PyLate / ColBERT-style), packaged as a Hugging Face kernels repo.

The kernel computes

score(q, d) = sum_i max_j  <q_i, d_j>

over a batch of (query, document) pairs without materialising the full [Lq, Ld] similarity matrix. It tiles over document tokens, keeps a running maximum per query token in shared memory, and sums those maxima into the per-pair score.
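
The tiling idea can be illustrated without a GPU. The following is a minimal plain-Python sketch, not the kernel's actual code: `maxsim_naive`, `maxsim_tiled`, and the `tile` parameter are hypothetical names chosen for illustration. It shows that visiting the document in tiles while keeping one running maximum per query token gives the same score as building the whole similarity matrix first.

```python
import random

def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def maxsim_naive(query, doc):
    """Build the full [Lq, Ld] similarity matrix, then max over j and sum over i."""
    sim = [[dot(q, d) for d in doc] for q in query]
    return sum(max(row) for row in sim)

def maxsim_tiled(query, doc, tile=4):
    """Visit the document tile by tile, storing only one running max per query token."""
    running_max = [float("-inf")] * len(query)
    for start in range(0, len(doc), tile):
        doc_tile = doc[start:start + tile]
        for i, q in enumerate(query):
            tile_max = max(dot(q, d) for d in doc_tile)
            running_max[i] = max(running_max[i], tile_max)
    # Final reduction: sum the per-query-token maxima into one score.
    return sum(running_max)

random.seed(0)
dim, Lq, Ld = 8, 5, 11
query = [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(Lq)]
doc = [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(Ld)]

naive = maxsim_naive(query, doc)
tiled = maxsim_tiled(query, doc, tile=4)
assert abs(naive - tiled) < 1e-9
```

The tiled version never holds more than `Lq` maxima plus one tile of similarities, which is why the working set stays small regardless of `Ld`.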

Install

uv add kernels        # or: pip install kernels

Usage

Two entry points. The packed/ragged form is the canonical kernel-facing API; the padded form is an ergonomic wrapper around the same kernel for the common batched-reranking case.

Packed (ragged segments)

import torch
from kernels import get_kernel

maxsim = get_kernel("erikkaum/maxsim", version=1)

scores = maxsim.score_pairs_packed(
    queries,           # [total_q_tokens, dim]
    query_offsets,     # [num_queries + 1], int32/int64
    documents,         # [total_d_tokens, dim]
    document_offsets,  # [num_documents + 1], int32/int64
    pair_query_ids,    # [num_pairs]
    pair_document_ids, # [num_pairs]
)
# scores.shape == [num_pairs], dtype == float32
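
To make the packed layout concrete, here is a small plain-Python sketch of how the six arguments relate. It assumes the offsets are prefix sums, so that segment k occupies rows offsets[k] to offsets[k + 1]; that reading is inferred from the [num_queries + 1] shapes above rather than stated by the source. `pack` and `score_pairs_packed_sketch` are hypothetical helpers, not part of the package.

```python
from itertools import accumulate

def pack(sequences):
    """Concatenate variable-length token lists; return (flat_tokens, offsets)."""
    flat = [tok for seq in sequences for tok in seq]
    offsets = [0] + list(accumulate(len(seq) for seq in sequences))
    return flat, offsets

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def score_pairs_packed_sketch(queries, query_offsets, documents,
                              document_offsets, pair_query_ids, pair_document_ids):
    """Reference loop: slice each segment out of the flat arrays and apply MaxSim."""
    scores = []
    for qi, di in zip(pair_query_ids, pair_document_ids):
        q_tokens = queries[query_offsets[qi]:query_offsets[qi + 1]]
        d_tokens = documents[document_offsets[di]:document_offsets[di + 1]]
        scores.append(sum(max(dot(q, d) for d in d_tokens) for q in q_tokens))
    return scores

# Two queries (2 and 1 tokens) and three documents (1, 2 and 3 tokens), dim = 2.
queries, query_offsets = pack([
    [[1.0, 0.0], [0.0, 1.0]],
    [[1.0, 1.0]],
])
documents, document_offsets = pack([
    [[1.0, 0.0]],
    [[0.0, 2.0], [3.0, 3.0]],
    [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]],
])

# Score (query 0, doc 0), (query 0, doc 1) and (query 1, doc 2).
scores = score_pairs_packed_sketch(
    queries, query_offsets, documents, document_offsets,
    pair_query_ids=[0, 0, 1], pair_document_ids=[0, 1, 2],
)
assert query_offsets == [0, 2, 3]
assert document_offsets == [0, 1, 3, 6]
assert scores == [1.0, 6.0, 1.0]
```

Because pairs are given by index, one query or document can appear in many pairs without being copied.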

Padded (batched reranking)

scores = maxsim.score_candidates_padded(
    queries,        # [B, Lq, dim]
    documents,      # [B, candidates, Ld, dim]
    query_lengths,  # [B]
    doc_lengths,    # [B, candidates]
)
# scores.shape == [B, candidates], dtype == float32
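
One way a padded batch could map onto the packed arguments is sketched below in plain Python. This is an illustration only: the package's real wrapper may work differently, `padded_to_packed` is a hypothetical name, and the sketch assumes valid tokens come first in each row, with padding at the end.

```python
from itertools import accumulate

def padded_to_packed(queries, documents, query_lengths, doc_lengths):
    """Flatten a padded rerank batch into the six packed-form arguments."""
    # Keep only the valid (leading) tokens of each query.
    q_segments = [q[:n] for q, n in zip(queries, query_lengths)]
    # Flatten [B, candidates] row-major, again dropping padding tokens.
    d_segments = [doc[:n]
                  for docs_b, lens_b in zip(documents, doc_lengths)
                  for doc, n in zip(docs_b, lens_b)]
    flat_q = [tok for seg in q_segments for tok in seg]
    flat_d = [tok for seg in d_segments for tok in seg]
    q_offsets = [0] + list(accumulate(len(s) for s in q_segments))
    d_offsets = [0] + list(accumulate(len(s) for s in d_segments))
    # Pair (b, c) scores query b against flattened document b * candidates + c.
    num_candidates = len(documents[0])
    pair_q = [b for b in range(len(queries)) for _ in range(num_candidates)]
    pair_d = list(range(len(d_segments)))
    return flat_q, q_offsets, flat_d, d_offsets, pair_q, pair_d

PAD = [9.0, 9.0]  # deliberately large so a leaked padding token would be obvious
queries = [[[1.0, 0.0], PAD],
           [[0.0, 1.0], [1.0, 1.0]]]                    # [B=2, Lq=2, dim=2]
documents = [[[[0.5, 0.0], [2.0, 0.0]], [[-1.0, 0.0], PAD]],
             [[[3.0, 3.0], PAD], [[0.0, 1.0], [1.0, 0.0]]]]  # [B=2, C=2, Ld=2, dim=2]
query_lengths = [1, 2]
doc_lengths = [[2, 1], [1, 2]]

flat_q, q_offsets, flat_d, d_offsets, pair_q, pair_d = padded_to_packed(
    queries, documents, query_lengths, doc_lengths)
assert q_offsets == [0, 1, 3]
assert d_offsets == [0, 2, 3, 4, 6]
assert pair_q == [0, 0, 1, 1] and pair_d == [0, 1, 2, 3]
assert PAD not in flat_q and PAD not in flat_d
```

Dropping padding by length, rather than zero-filling it, matters for correctness: a zero document token has similarity 0 with every query token, which would beat a true maximum that happens to be negative.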

A pure-PyTorch reference (maxsim.maxsim_reference, maxsim.score_pairs_packed_reference, maxsim.score_candidates_padded_reference) ships alongside for tests and benchmarks.

Supported

| Backend | Devices | Input dtypes | Accum / output |
|---|---|---|---|
| Metal | Apple Silicon (MPS) | fp32 / fp16 / bf16 | fp32 |
| CUDA | sm_80, sm_86, sm_89 (Ampere + Lovelace) | fp32 / fp16 / bf16 | fp32 |

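Any dim is accepted. The fast simdgroup_matrix (Metal) and WMMA (CUDA) paths are used when dim % 8 == 0 on Metal or dim % 16 == 0 on CUDA, which covers the typical embedding sizes (64 / 96 / 128). On CUDA the fast path also requires Lq_max % 16 == 0 (see Limitations).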

Benchmarks

Three padded-API workloads taken straight from the design plan, comparing the kernel to a vectorised but naïve PyTorch baseline that materialises the [Lq, Ld] similarity matrix.

Apple M3 Pro (Metal, fp16, dim=128)

| Workload | Kernel | Naive | Speedup |
|---|---|---|---|
| SmallRerank (B=32, C=10, Lq=32, Ld=180) | 0.45 ms | 1.44 ms | 3.18× |
| HeavyRerank (B=32, C=100, Lq=32, Ld=256) | 4.34 ms | 16.63 ms | 3.83× |
| LongDocStress (B=8, C=16, Lq=64, Ld=1024) | 1.69 ms | 3.70 ms | 2.19× |

NVIDIA CUDA (fp16, dim=128)

| Workload | A10G (sm_86) | L4 (sm_89) | A100 (sm_80) |
|---|---|---|---|
| SmallRerank | 2.28× | 2.05× | 2.80× |
| HeavyRerank | 4.48× | 5.18× | 5.29× |
| LongDocStress | 3.41× | 6.21× | 1.89× |

(A100's naive einsum is so well-tuned by cuBLAS that LongDocStress barely benefits there; on memory-bandwidth-bound GPUs like L4 the kernel pulls ahead significantly.)
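
For a sense of scale, the arithmetic below estimates how much memory the similarity matrices occupy for each workload. It assumes the naive baseline holds the whole batch's [Lq, Ld] matrices at once and stores them in fp16 (2 bytes per element); an fp32 intermediate would double these figures. The helper names are hypothetical.

```python
# Workload shapes are copied from the benchmark tables above.
WORKLOADS = {
    "SmallRerank": dict(B=32, C=10, Lq=32, Ld=180),
    "HeavyRerank": dict(B=32, C=100, Lq=32, Ld=256),
    "LongDocStress": dict(B=8, C=16, Lq=64, Ld=1024),
}

def sim_matrix_bytes(B, C, Lq, Ld, bytes_per_element=2):
    """Bytes needed to hold every [Lq, Ld] similarity matrix in the batch at once."""
    return B * C * Lq * Ld * bytes_per_element

def score_bytes(B, C, bytes_per_element=4):
    """Bytes of the fp32 [B, candidates] score tensor."""
    return B * C * bytes_per_element

footprint = {name: sim_matrix_bytes(**shape) for name, shape in WORKLOADS.items()}
for name, nbytes in footprint.items():
    print(f"{name}: {nbytes / 2**20:.2f} MiB")
# SmallRerank: 3.52 MiB
# HeavyRerank: 50.00 MiB
# LongDocStress: 16.00 MiB
```

By comparison, the score tensor for HeavyRerank is 32 × 100 × 4 = 12,800 bytes.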

Reproduce with:

kernels benchmark erikkaum/maxsim

Limitations

  • No backward pass (forward-only scoring kernel for now).
  • No argmax-position output (just the score).
  • CUDA fast path requires dim % 16 == 0 and Lq_max % 16 == 0; other shapes fall back to a correctness-preserving scalar kernel.
  • Hopper (sm_90) is supported via PTX forward compatibility but doesn't yet use WGMMA, so Ampere and Lovelace are the best-tuned targets.
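
The shape conditions for the fast paths can be restated as a small check, together with one possible workaround for an awkward dim. This is a plain-Python sketch with hypothetical helper names. Zero-padding the embedding dimension leaves every inner product unchanged, but whether the extra memory is worth it for this kernel is an untested assumption, and the sketch does not address the Lq_max condition.

```python
def cuda_fast_path_eligible(dim, lq_max):
    """Restates the documented CUDA condition: both must be multiples of 16."""
    return dim % 16 == 0 and lq_max % 16 == 0

def metal_fast_path_eligible(dim):
    """Restates the documented Metal condition: dim must be a multiple of 8."""
    return dim % 8 == 0

def pad_dim(tokens, multiple=16):
    """Zero-pad each embedding so its length is a multiple of `multiple`."""
    dim = len(tokens[0])
    target = -(-dim // multiple) * multiple  # round up to the next multiple
    return [tok + [0.0] * (target - dim) for tok in tokens]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [[1.0, 2.0, 3.0]]
d = [[4.0, 5.0, 6.0]]
q_padded, d_padded = pad_dim(q), pad_dim(d)

assert cuda_fast_path_eligible(128, 32)
assert not cuda_fast_path_eligible(100, 32)   # dim not a multiple of 16
assert not cuda_fast_path_eligible(96, 30)    # Lq_max not a multiple of 16
assert len(q_padded[0]) == 16
assert dot(q_padded[0], d_padded[0]) == dot(q[0], d[0]) == 32.0
```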

Source / contribute

Source: https://github.com/erikkaum/maxsim.

License: Apache-2.0.
