GPT-4やClaude 3のような大規模言語モデル(LLM)とチャットをする際、たとえ50ページのPDFを分析させている最中であっても、私たちはほぼ即座のレスポンスを期待します。しかし、会話が長くなるにつれて、返答が遅くなったり(レイテンシの増加)、ローカルでモデルを動かしている場合には「Out of Memory (OOM)」エラーが発生したりすることに気づくかもしれません。
このシナリオにおける「犯人」であり、同時に「英雄」でもあるのが、KVキャッシュ(Key-Value Cache)です。これは自己回帰モデルにおいて推論速度を最適化するための最も重要なコンポーネントですが、その代償としてメモリを大量に消費します。KVキャッシュを理解することは、LLMのデプロイを最適化したい人、RAGパイプラインを構築する人、あるいは単に「なぜロングコンテキストモデルには巨大なGPUが必要なのか」を知りたい人にとって不可欠です。
核心となる概念:なぜキャッシュが必要なのか
KVキャッシュを理解するには、LLMがどのようにテキストを生成するかを見る必要があります。LLMは自己回帰的(Autoregressive)です。つまり、トークン(単語の一部)を一度に一つずつ生成します。そして、文脈を維持するために、新しいトークンはそれ以前のすべてのトークンに依存して生成されます。
アテンション(Attention)の問題点
Transformerアーキテクチャでは、「自己注意機構(Self-Attention)」によって、モデルが過去のトークンを振り返り、現在のトークンの意味を決定します。
数学的には、すべてのトークンに対して、モデルは以下の3つのベクトルを計算します:
- クエリ ($Q$): 現在のトークンが何を探しているか。
- キー ($K$): 現在のトークンが自分自身をどう定義しているか。
- バリュー ($V$): そのトークンの実際の内容や情報の価値。
キャッシュがない場合、100番目のトークンを生成するには、モデルはトークン1から99までの $K$ と $V$ 行列を再計算する必要があります。101番目を生成するには、1から100までを再び計算し直さなければなりません。これでは計算量が二乗的に増加($O(N^2)$)し、計算リソースの無駄遣いとなります。
解決策
過去のトークンのKey(キー)とValue(バリュー)ベクトルは、一度計算されれば変化しません。そのため、これらをGPUメモリ(VRAM)に保存(キャッシュ)しておくことができます。新しいトークンを生成する際、モデルは新しいトークンの $Q, K, V$ だけを計算し、過去の $K$ と $V$ はキャッシュから取得して利用します。
プロセスの可視化
KVキャッシュを実装することで、推論プロセスは以下のように変化します。
graph TD
A["ユーザープロンプト(入力)"] --> B["プレフィルフェーズ(事前充填)"]
B --> C["全入力トークンのKとVを計算"]
C --> D["KVキャッシュに保存 (VRAM)"]
D --> E{"生成ループ(デコーディング)"}
E --> F["新規トークン入力"]
F --> G["新規トークンのQ, K, V のみ計算"]
G --> H["過去のK, Vをキャッシュから取得"]
H --> I["アテンション実行 (新規Q vs キャッシュ済みK, V)"]
I --> J["次のトークンを出力"]
J --> K["新しいK, Vをキャッシュに追加"]
K --> E
「コード」で見る:基本的なKVキャッシュの実装
深層学習ライブラリはこれを自動的に処理してくれますが、Pythonのコードでロジックを見ると仕組みが明確になります。以下の簡略化されたスニペットは、フォワードパス(順伝播)中にキャッシュがどのように初期化され、更新されるかを示しています。
import torch
class SimpleKVCache:
def __init__(self, max_seq_len, hidden_dim):
# 空のキャッシュ用テンソルを初期化
# 形状: [バッチサイズ, シーケンス長, 隠れ層の次元数]
self.k_cache = torch.zeros(1, max_seq_len, hidden_dim)
self.v_cache = torch.zeros(1, max_seq_len, hidden_dim)
self.current_pos = 0
def update(self, key_state, value_state):
"""
新しいトークンのKeyとValueの状態(State)でキャッシュを更新する
"""
seq_len = key_state.shape[1]
# 事前に確保したバッファに新しいKとVを挿入
self.k_cache[:, self.current_pos : self.current_pos + seq_len, :] = key_state
self.v_cache[:, self.current_pos : self.current_pos + seq_len, :] = value_state
self.current_pos += seq_len
# current_posまでの有効なデータを含むキャッシュのスライス(断片)を返す
return (
self.k_cache[:, :self.current_pos, :],
self.v_cache[:, :self.current_pos, :]
)
# 使用例のシミュレーション
# モデルから最新トークンの new_k と new_v を取得したと仮定
# cache_engine.update(new_k, new_v)
ステップ・バイ・ステップ:推論のライフサイクル
- プレフィルフェーズ(The Prefill Phase):
プロンプトを送信した瞬間、モデルはすべての入力トークンを並列処理します。プロンプト全体の $K$ と $V$ 行列を計算し、キャッシュに保存します。この段階は通常、計算バウンド(Compute-bound)です(GPUがどれだけ速く計算できるかがボトルネックになります)。 - デコーディングフェーズ(トークン生成):
モデルはトークンを一つずつ生成するモードに切り替わります。キャッシュされたデータを取得し、単一の新しいトークンのみを処理します。- ボトルネックの転換: このフェーズは通常、メモリ帯域幅バウンド(Memory-bandwidth bound)です。GPUは計算そのものよりも、巨大なKVキャッシュをVRAMから計算コアへ移動させることに多くの時間を費やします。
- メモリの爆発的増加:
コンテキスト長が長くなるにつれて、KVキャッシュは線形に増加します。- 注:係数の「2」は、KとVの両方を保存するためです。
- コンテキストウィンドウの限界:
最終的に、KVキャッシュがGPUの利用可能なVRAMを埋め尽くします。こうなるとモデルはこれ以上トークンを処理できなくなり、コンテキストの切り捨て(Truncation)やクラッシュが発生します。
メモリへの影響分析
以下の表は、ロングコンテキストモデル(例:128kコンテキスト)を動かす際、モデルの重み(Weights)だけでなく、KVキャッシュがいかに大量のVRAMを必要とするかを示しています。
前提条件:Llama-3-70Bモデル、Float16精度(1要素あたり2バイト)と概算。
| コンテキスト長 (トークン) | KVキャッシュの概算サイズ (GB) | 一般的なハードウェアへの影響 |
|---|---|---|
| 4,096 | ~0.6 GB | 無視できるレベル。ほとんどのGPUで動作可能。 |
| 32,000 | ~5.0 GB | かなり大きい。12GB以上のVRAMを持つカードが必要。 |
| 128,000 | ~20.0 GB | 致命的。モデルの重みと合わせると、RTX 4090 (24GB) でも溢れる。 |
KVキャッシュ最適化のプロ向けヒント
LLMをデプロイしたりアプリケーションを構築する場合、標準的なキャッシングだけでは不十分なことがあります。以下の高度なテクニックを使用して、メモリ使用量を管理しましょう。
- PagedAttention (vLLM) の利用: 従来のキャッシングは連続したメモリブロックを予約するため、断片化(フラグメンテーション)や無駄が生じます。vLLM は、OSの仮想メモリ管理に着想を得た「PagedAttention」を使用し、キーとバリューを不連続なメモリブロックに保存します。これにより、スループットが最大24倍向上します。
- Grouped Query Attention (GQA): Llama 3などの最新モデルはGQAを採用しています。これは、すべてのクエリ(Query)ヘッドに対して個別のKey/Valueヘッドを持つのではなく、複数のクエリでKVヘッドを共有する仕組みです。これにより、パフォーマンスの低下を最小限に抑えつつ、キャッシュサイズを劇的に(多くの場合1/8に)削減できます。
- KVキャッシュの量子化 (Quantization): KVキャッシュ自体を圧縮(例:FP16からINT8やFP8へ)することも可能です。これにより、同じVRAM使用量で最大コンテキスト長を実質2倍にできます。
- FlashAttention: バックエンドで FlashAttention-2 が有効になっているか確認してください。これはアテンション機構の読み書き操作を最適化し、デコーディングフェーズにおけるメモリ帯域幅のボトルネックを軽減します。
KVキャッシュは、VRAMを犠牲にして計算速度を得る、現代のLLMにおける「効率化のエンジン」です。私たちが今日享受しているスムーズなチャット体験を実現しているのはこの仕組みですが、同時にロングコンテキスト・アプリケーションにおける最大の制約要因でもあります。PagedAttentionや量子化などの技術を活用し、このキャッシュをうまく「手なずける」ことで、より小さなハードウェアでより大きなコンテキストを扱うことが可能になります。
