Meet Flash-KMeans: An IO-Aware, Exact K-Means That Runs Over 200× Faster Than FAISS on GPUs



k-means has been an offline tool for decades. You run it once to preprocess data, then move on. A team of researchers from UC Berkeley and UT Austin released Flash-KMeans, a new open-source library that targets a different setting. Modern AI pipelines now call k-means inside training and inference loops. At that frequency, latency per call matters more than theoretical FLOPs.

Flash-KMeans is an IO-aware implementation of standard Lloyd’s k-means. It does not change the math, and it does not approximate. It only restructures how the algorithm moves data on a GPU. On an NVIDIA H200, the research team reported up to 17.9× end-to-end speedup over the best baseline. Against NVIDIA cuML they report 33×. Against FAISS they report over 200×.

What is Flash-KMeans

Flash-KMeans is a batched k-means library written in Triton GPU kernels. It ships under Apache 2.0 and installs with pip install flash-kmeans.

The output is mathematically identical to standard Lloyd’s k-means. The speedup comes from kernel-level dataflow, not from skipping work. That separates it from algorithmic methods like triangle-inequality pruning or coreset sampling.

A standard Lloyd iteration has two stages. The assignment stage computes each point’s distance to every centroid, then picks the nearest. The update stage averages the points in each cluster to form new centroids. Both stages are simple arithmetic. On GPUs, both are bottlenecked by memory, not compute.
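The two stages can be sketched in a few lines of NumPy. This is a textbook baseline for illustration, not code from Flash-KMeans: the function name `lloyd_iteration` is hypothetical, and it deliberately materializes the full N×K distance matrix that the sections below are about avoiding.

```python
import numpy as np

def lloyd_iteration(points, centroids):
    """One textbook Lloyd iteration that builds the full N x K distance matrix."""
    # Assignment stage: squared Euclidean distance from every point to every centroid.
    diff = points[:, None, :] - centroids[None, :, :]   # shape (N, K, d)
    dist = np.einsum("nkd,nkd->nk", diff, diff)         # shape (N, K)
    labels = dist.argmin(axis=1)                        # nearest centroid per point

    # Update stage: average the points assigned to each cluster.
    k = centroids.shape[0]
    counts = np.bincount(labels, minlength=k)
    sums = np.zeros_like(centroids)
    np.add.at(sums, labels, points)                     # scatter-add keyed by cluster id
    new_centroids = centroids.copy()                    # empty clusters keep their old centroid
    nonempty = counts > 0
    new_centroids[nonempty] = sums[nonempty] / counts[nonempty, None]
    return labels, new_centroids
```

Keeping the old centroid for an empty cluster is one common convention, chosen here only to keep the sketch well defined.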

The Two Bottlenecks It Attacks

The first bottleneck is the assignment stage. Standard code builds a full distance matrix D of shape N×K in High Bandwidth Memory (HBM). It writes the matrix, then reads it back to run argmin. For N=65536, K=1024, d=128, B=32, the distance math takes 2.6ms. Writing and consuming D takes about 23ms. The matrix is the cost, not the arithmetic.
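A quick back-of-the-envelope calculation shows why the matrix dominates. The sketch below assumes 2 bytes per element (FP16, matching the benchmark setup described later); the dtype of the intermediate matrix in any given implementation is an assumption, and `assignment_io_bytes` is a hypothetical helper.

```python
def assignment_io_bytes(n, k, d, batch, bytes_per_elem=2):
    """Size of the N x K distance matrix vs. the raw inputs, summed over the batch."""
    matrix = batch * n * k * bytes_per_elem            # full distance matrix D
    inputs = batch * (n * d + k * d) * bytes_per_elem  # points plus centroids
    return matrix, inputs

matrix, inputs = assignment_io_bytes(n=65536, k=1024, d=128, batch=32)
print(f"distance matrix:    {matrix / 2**30:.2f} GiB")
print(f"points + centroids: {inputs / 2**30:.2f} GiB")
print(f"ratio:              {matrix / inputs:.1f}x")
```

Under these assumptions the matrix is 4.00 GiB against 0.51 GiB of inputs, roughly 7.9× larger, and it is written once and then read back.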

Flash-KMeans replaces this with FlashAssign. The design borrows from FlashAttention. FlashAssign streams tiles of points and centroids from HBM into on-chip SRAM and fuses the distance computation with an online argmin, so the full N×K matrix is never materialized. This cuts the dominant IO complexity from O(NK) to O(Nd + Kd). At the kernel level, FlashAssign reaches up to a 21.2× speedup. In one case it cut assignment time from 122.5ms to 5.8ms.
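The online-argmin idea can be illustrated on the CPU. The following is a minimal NumPy sketch of the technique, not the library's Triton kernel: `tiled_assign` and its tile sizes are hypothetical, and the Python loops stand in for what a fused GPU kernel would do in SRAM. Only one tile of distances exists at any time, and a running best distance and index are carried across centroid tiles.

```python
import numpy as np

def tiled_assign(points, centroids, tile_n=256, tile_k=128):
    """Nearest-centroid labels without ever holding the full N x K matrix."""
    n, k = points.shape[0], centroids.shape[0]
    labels = np.empty(n, dtype=np.int64)
    c_sq = np.einsum("kd,kd->k", centroids, centroids)   # ||c||^2, reused by every tile

    for i in range(0, n, tile_n):
        x = points[i:i + tile_n]
        rows = np.arange(x.shape[0])
        best_dist = np.full(x.shape[0], np.inf)          # running minimum per point
        best_idx = np.zeros(x.shape[0], dtype=np.int64)  # running argmin per point
        for j in range(0, k, tile_k):
            c = centroids[j:j + tile_k]
            # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2; ||x||^2 is constant per
            # point, so it can be dropped without changing the argmin.
            partial = c_sq[j:j + tile_k][None, :] - 2.0 * (x @ c.T)
            local_idx = partial.argmin(axis=1)
            local_dist = partial[rows, local_idx]
            better = local_dist < best_dist              # strict: earlier tile wins ties
            best_dist[better] = local_dist[better]
            best_idx[better] = local_idx[better] + j
        labels[i:i + tile_n] = best_idx
    return labels
```

The strict comparison keeps the lowest centroid index on ties, which matches what a single argmin over the full matrix would return.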

The second bottleneck is the centroid update stage. Standard code uses scatter-style atomic adds. Each thread adds its point into a shared sum buffer keyed by cluster id. Many threads hit the same ‘hot’ cluster at once. That causes atomic contention and hardware serialization. The research team measured only 50 GB/s effective bandwidth here on an H200.

Flash-KMeans replaces this with Sort-Inverse Update. It sorts the 1D assignment vector by cluster id using argsort. Identical cluster ids then form contiguous segments. Each thread block reduces a segment on-chip, then issues one atomic add per segment. The heavy point matrix is never physically permuted. The number of atomic operations drops from O(Nd) to O((K + N/B_N)d), where B_N is the number of points handled per thread block. The kernel reaches up to a 6.3× speedup.
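A CPU analogue of the sort-then-reduce pattern is sketched below in NumPy. It is an illustration under simplifying assumptions, not the library's kernel: `sort_inverse_update` is a hypothetical name, each cluster's segment is reduced in one step rather than per thread block, the input is assumed non-empty, and `points[order]` creates a gathered copy, whereas the article states the real kernel never physically permutes the point matrix.

```python
import numpy as np

def sort_inverse_update(points, labels, k):
    """Per-cluster sums and counts via argsort plus a segmented reduction."""
    order = np.argsort(labels, kind="stable")   # sort only the 1D assignment vector
    sorted_labels = labels[order]
    # Start index of each contiguous run of identical cluster ids.
    starts = np.flatnonzero(np.r_[True, sorted_labels[1:] != sorted_labels[:-1]])
    seg_ids = sorted_labels[starts]             # cluster id owning each segment
    # One reduction per segment instead of one scatter-add per point.
    seg_sums = np.add.reduceat(points[order], starts, axis=0)

    sums = np.zeros((k, points.shape[1]), dtype=points.dtype)
    counts = np.zeros(k, dtype=np.int64)
    sums[seg_ids] = seg_sums                    # each cluster id appears once, so no conflicts
    counts[seg_ids] = np.diff(np.r_[starts, labels.size])
    return sums, counts
```

Dividing `sums` by `counts` for the non-empty clusters then gives the new centroids, exactly as a scatter-add update would.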

Benchmark

The research team tested it on an H200 with CUDA 12.8, FP16 data, and d=128. They swept N, K, and batch size B, and compared against four optimized baselines: fast_pytorch_kmeans, fastkmeans, cuML, and FAISS.

Comparison                   | Reported speedup | Workload context
-----------------------------|------------------|----------------------------------------
End-to-end vs best baseline  | up to 17.9×      | N=8M, K=1024 (large N, small K)
vs NVIDIA cuML               | 33×              | industry library
vs FAISS                     | over 200×        | industry library
FlashAssign kernel           | up to 21.2×      | N=1M, K=8192 (assignment)
Sort-Inverse Update kernel   | up to 6.3×       | N=33M, K=4096 (update)
Out-of-core, large scale     | up to 10.5×      | N=400M, K=16384 vs fastkmeans

One failure mode matters for context. Standard PyTorch implementations run out of memory in large-K regimes because they cannot materialize the N×K matrix. FAISS, for reference, is the industry-standard library underneath many production vector-search systems.

The library also runs out-of-core. On one billion points (K=32768, d=128), it finishes an iteration in 41.4s, against 261.8s for the baseline. It uses chunked stream overlap to hide PCIe transfers behind compute. A cache-aware compile heuristic also cuts tuning overhead by up to 175× while staying within 0.3% of tuned speed.
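Chunking does not cost exactness because per-cluster sums and counts are additive across chunks. The NumPy sketch below illustrates only that property, under stated assumptions: `chunked_lloyd_iteration` and `chunk_size` are hypothetical, centroids are held fixed for the whole pass, and the stream overlap itself (copying chunk i+1 while computing on chunk i) is not modeled.

```python
import numpy as np

def chunked_lloyd_iteration(points, centroids, chunk_size=4096):
    """One Lloyd iteration that only touches `chunk_size` points at a time."""
    k, d = centroids.shape
    sums = np.zeros((k, d), dtype=np.float64)
    counts = np.zeros(k, dtype=np.int64)
    c_sq = np.einsum("kd,kd->k", centroids, centroids)

    for start in range(0, points.shape[0], chunk_size):
        # In a real pipeline this slice would be a host-to-device transfer.
        chunk = np.asarray(points[start:start + chunk_size], dtype=np.float64)
        # Argmin of ||x - c||^2 equals argmin of ||c||^2 - 2 x.c for each point x.
        scores = c_sq[None, :] - 2.0 * (chunk @ centroids.T)
        labels = scores.argmin(axis=1)
        np.add.at(sums, labels, chunk)                  # accumulate across chunks
        counts += np.bincount(labels, minlength=k)

    new_centroids = centroids.astype(np.float64, copy=True)
    nonempty = counts > 0                               # empty clusters keep old centroid
    new_centroids[nonempty] = sums[nonempty] / counts[nonempty, None]
    return new_centroids
```

Running it with a small `chunk_size` or with a single chunk covering all points gives the same centroids, up to floating-point summation order.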


Use Cases

Faster exact k-means changes what you can run online, not just offline.

  • Vector search indexing: FAISS builds its search indices with k-means. Faster k-means lets you re-index as data shifts, instead of rebuilding overnight.
  • Sparse attention routing: Routing Transformers and Tactic cluster tokens to route attention. Millisecond k-means makes this viable inside the inference loop.
  • KV-cache compression: ClusterKV clusters tokens in semantic space to compress the cache. Cheaper clustering makes per-layer, per-step compression practical.
  • Low-bit KV quantization: Recent methods cluster KV entries into codebooks, repeatedly. Faster clustering shrinks that preprocessing cost.
  • Diffusion Transformers: Sparse VideoGen2 calls batched k-means during forward passes. It permutes tokens by semantic similarity to exploit sparsity.

Using It

The API mirrors faiss and sklearn. The call below clusters a batched (B, N, d) tensor.

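import torch
from flash_kmeans import batch_kmeans_Euclid

# Batched input of shape (B, N, d): 32 problems, 75600 points each, 128 dimensions.
x = torch.randn(32, 75600, 128, device="cuda", dtype=torch.float16)
# The third return value is not used in this example.
cluster_ids, centers, _ = batch_kmeans_Euclid(
    x, n_clusters=1000, tol=1e-4, verbose=True
)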

A scikit-learn-style interface is also available.

from flash_kmeans import FlashKMeans

km = FlashKMeans(d=128, k=8192, niter=100)
labels = km.fit_predict(large_cpu_tensor)  # device=None uses all visible GPUs

The kernel auto-dispatches by shape and dtype. A small-D path handles d≤512. A split-D path handles larger d without materializing the distance matrix. Multi-GPU runs trigger automatically for large-N data held in CPU memory.

Key Takeaways

  • Flash-KMeans is exact, not approximate: it runs the same Lloyd’s math, sped up purely by GPU dataflow.
  • FlashAssign fuses distance computation with an online argmin, cutting assignment IO from O(NK) to O(Nd+Kd), for up to a 21.2× kernel speedup.
  • Sort-Inverse Update sorts cluster IDs into contiguous segments, replacing per-point scatter atomics with per-segment reductions, for up to a 6.3× kernel speedup.
  • The team reports up to 17.9× end-to-end speedup over the best baseline, 33× over cuML, and more than 200× over FAISS on an H200.
  • It scales out-of-core to one billion points and cuts tuning overhead by up to 175×.
