K-Means 演算法大翻身:Flash-KMeans 如何在 GPU 上飆速 200 倍,革新 AI 資料處理流程?
編輯核心觀點
- ✦由柏克萊與德州大學奧斯汀分校研究團隊開發的 Flash-KMeans 函式庫,透過優化 GPU 資料流而非改變數學演算法,實現了 K-Means 演算法高達 200 倍以上的驚人加速。
- ✦這項創新主要解決了 K-Means 在 GPU 上遇到的兩大記憶體瓶頸:指派階段的距離矩陣具體化與更新階段的原子操作爭用,大幅降低了輸入/輸出(IO)成本。
- ✦Flash-KMeans 的高速運算能力,讓 K-Means 從傳統的離線預處理工具,轉變為向量搜尋、稀疏注意力路由及鍵值快取壓縮等現代 AI 流程中的即時核心組件。

K-Means 演算法數十年來一直是資料預處理的離線工具,通常僅執行一次。然而,隨著人工智慧(AI)技術的快速發展,K-Means 已被整合到訓練與推論的迴圈中,需要頻繁呼叫,這使得每次呼叫的延遲時間遠比理論上的浮點運算次數(FLOPs)更為關鍵。
為了解決這個痛點,來自加州大學柏克萊分校(UC Berkeley)與德州大學奧斯汀分校(UT Austin)的研究團隊,共同發布了一款名為 Flash-KMeans 的開源函式庫。這項創新並未改變標準勞埃德 K-Means(Lloyd’s k-means)演算法的數學基礎,也沒有採用近似方法,而是透過重新架構資料在圖形處理器(GPU)上的移動方式,實現了前所未有的加速。
Flash-KMeans:不改數學,只改資料流
Flash-KMeans 是一個基於 Triton GPU 核心(Triton GPU kernels)編寫的批次 K-Means 函式庫,採用 Apache 2.0 授權,可透過 pip install flash-kmeans 輕鬆安裝。其輸出結果與標準勞埃德 K-Means 在數學上完全一致,但速度提升的關鍵在於核心層級的資料流優化,而非跳過任何運算步驟。這也讓它與三角不等式剪枝(triangle-inequality pruning)或核心集取樣(coreset sampling)等演算法層面的加速方法有所區隔。
標準的勞埃德 K-Means 迭代包含兩個階段:
- 指派階段(Assignment Stage):計算每個資料點到所有質心(centroid)的距離,然後選擇最近的質心進行指派。
- 更新階段(Update Stage):將每個叢集(cluster)中的資料點取平均,形成新的質心。
這兩個階段都涉及簡單的算術運算,但在 GPU 上,它們的瓶頸往往在於記憶體存取,而非計算能力。
突破兩大記憶體瓶頸
Flash-KMeans 的高效能來自於它對 K-Means 演算法在 GPU 上遇到的兩大記憶體瓶頸的精準攻擊:
瓶頸一:指派階段的資料矩陣具體化
傳統的 K-Means 程式碼在指派階段,會在高頻寬記憶體(HBM)中建立一個完整的 N×K 距離矩陣 D(其中 N 是資料點數量,K 是叢集數量),先寫入矩陣,再讀回以執行最小值索引(argmin)操作。研究團隊指出,在特定條件下(N=65536, K=1024, d=128, B=32),距離計算本身僅需 2.6 毫秒,但寫入和讀取矩陣 D 卻耗費約 23 毫秒。這顯示了矩陣的輸入/輸出(IO)成本才是主要瓶頸。
Flash-KMeans 透過 FlashAssign 取代了這一過程。其設計靈感來自 FlashAttention,透過將資料點和質心的區塊(tiles)從 HBM 串流傳輸到晶片內靜態隨機存取記憶體(SRAM),並將距離計算與線上最小值索引操作融合。如此一來,完整的 N×K 矩陣從未被具體化,從根本上將主導性的 IO 複雜度從 O(NK) 降低到 O(Nd + Kd)。在核心層級,FlashAssign 的速度提升高達 21.2 倍,在一個案例中將指派時間從 122.5 毫秒縮短至 5.8 毫秒。
瓶頸二:質心更新階段的原子操作爭用
在質心更新階段,傳統程式碼使用分散式(scatter-style)的原子加法(atomic adds)。每個執行緒(thread)會將其資料點加到一個以叢集 ID 為鍵的共享總和緩衝區中。當許多執行緒同時寫入同一個「熱門」叢集時,會導致原子操作爭用(atomic contention)和硬體序列化(hardware serialization)。研究團隊在 NVIDIA H200 上測得的有效頻寬僅為每秒 50 GB。
Flash-KMeans 則採用 Sort-Inverse Update 方法。它首先使用排序索引(argsort)將一維指派向量按叢集 ID 排序,使相同的叢集 ID 形成連續的區段。每個執行緒區塊(thread block)在晶片內對一個區段進行歸約(reduce),然後為每個區段發出一個原子加法操作。這種方式避免了實體上重新排列龐大的資料點矩陣,並將原子操作次數從 O(N·d) 降低到 O((K + N/B)·d)。此核心的速度提升高達 6.3 倍。
效能實測:超越業界標準 200 倍
研究團隊在搭載 NVIDIA H200 GPU、CUDA 12.8 環境下,使用 FP16 資料和維度 d=128 進行了廣泛測試,並與四個優化過的基準函式庫(fast_pytorch_kmeans, fastkmeans, cuML, FAISS)進行比較。結果顯示,Flash-KMeans 展現了令人驚豔的效能提升:
- 端到端(end-to-end)效能:相較於最佳基準函式庫,在 N=800 萬、K=1024 的工作負載下,速度提升高達 17.9 倍。
- 對比 NVIDIA cuML:速度提升 33 倍。
- 對比 FAISS:相較於許多生產級向量搜尋系統採用的業界標準函式庫 FAISS,速度提升超過 200 倍。
- FlashAssign 核心:在 N=100 萬、K=8192 的指派任務中,速度提升高達 21.2 倍。
- Sort-Inverse Update 核心:在 N=3300 萬、K=4096 的更新任務中,速度提升高達 6.3 倍。
- 核心外(Out-of-core)大規模運算:在 N=4 億、K=16384 的資料集上,相較於 fastkmeans 速度提升高達 10.5 倍。處理 10 億個資料點(K=32768, d=128)的單次迭代僅需 41.4 秒,而基準函式庫則需 261.8 秒。
值得一提的是,標準的 PyTorch 實作在處理大規模 K 值時常因記憶體不足而失效,無法具體化 N×K 矩陣。Flash-KMeans 則能有效應對這些挑戰,甚至支援核心外運算,透過分塊串流重疊(chunked stream overlap)技術隱藏 PCIe 傳輸延遲,並利用快取感知編譯啟發式(cache-aware compile heuristic)將調優開銷降低多達 175 倍。
開啟 K-Means 即時應用的新可能
更快的精確 K-Means 演算法,將 K-Means 的應用範圍從傳統的離線處理拓展到即時線上運算,為現代 AI 流程帶來革命性的影響:
- 向量搜尋索引(Vector search indexing):FAISS 等系統利用 K-Means 建立搜尋索引。Flash-KMeans 的高速運算能力,讓資料變動時能即時重新索引,而非等待隔夜重建。
- 稀疏注意力路由(Sparse attention routing):在路由 Transformer(Routing Transformers)和 Tactic 等模型中,K-Means 用於叢集化 token 以路由注意力。毫秒級的 K-Means 運算使其在推論迴圈中變得可行。
- 鍵值快取壓縮(KV-cache compression):ClusterKV 等技術透過在語義空間中叢集化 token 來壓縮鍵值快取。更便宜的叢集運算使每層、每步驟的壓縮成為實用選項。
- 低位元鍵值量化(Low-bit KV quantization):近期方法反覆將鍵值條目叢集化為碼本(codebooks)。更快的叢集運算可大幅縮減預處理成本。
- 擴散式 Transformer 模型(Diffusion Transformers):Sparse VideoGen2 在前向傳播(forward passes)期間調用批次 K-Means,透過語義相似性排列 token 以利用稀疏性。
Flash-KMeans 的應用程式介面(API)設計與 faiss 和 scikit-learn 相似,易於開發者整合。例如,透過簡單的幾行 Python 程式碼,即可對批次張量進行叢集化,或使用類似 scikit-learn 的介面進行操作。
這項技術的重要性在於,它將 K-Means 從一個耗時的離線工具,轉變為現代 AI 系統中不可或缺的即時組件。透過優化底層的資料流,而非犧牲演算法的精確性,Flash-KMeans 為處理大規模資料集和實現更高效能的 AI 應用開闢了廣闊空間。



