Infinite-LLM: Efficient LLM Service for Long Context with DistAttention and Distributed KV Cache

本文主要研究云端/分布式环境下LLM对于长上下文任务的KV Cache管理问题

Challenges to LLM serving on Cloud(Motivation)

Challenge 1: significant disparities in memory demands obstacles efficient model parallelism

如[^Table 1]所示。

[^Table 1]: LLaMA2-13B, KV Cache size with context legnth

Context length 10k 100k 500k 1000k
KV Cache size 8.19GB 81.9GB 409.6GB 819.2GB
Misc size 26GB 26GB 26GB 26GB

为了满足长上下文任务所需的大量 KV 缓存,必须增加 GPU 的数量。然而,其他层的张量维度并不会随着上下文长度的增加而扩展。因此,传统的模型并行方法在分布到更多 GPU 上时,会对这些层进行更细粒度的划分,从而导致资源利用效率降低,如图所示。

Fig.1

假设将模型分成4个部分,每个GPU处理1个部分,这种划分方式使得每个GPU的工作量较为均衡。

  • 细粒度划分:如果将模型进一步细分成8个部分,那么每个GPU只处理更小的一部分。此时,GPU之间的通信频率增加,负载不均衡问题也更加显著。

Challenge 2: dynamicity of KV Cache size leads to inefficient resource management in the cloud environment

由于LLM生成的自回归特性,最终Sequence长度不确定,因此,所需内存大小是动态且不可测的,这使得无法提前规划资源分配。

若对于某一实例,上下文所需的内存大小超出了当前GPU的容量:

  • 将整个任务迁移到有着更多GPU的设备实例(live migration)
  • 从一开始就分配更多计算资源:在短上下文情况造成资源浪费

Main Idea

为了解决Challenge 1,提出一个新的Attention算法,把原先对Attention的计算细分给当前分布式环境下的各个设备,从而把KV Cache分成小的子块来进行管理

为了解决Challenge 2,提出一个新的Model DistKV-LLM,进行 KV Cahce管理,协调所有实例 GPU 和 CPU 的内存使用。当一个 LLM 服务实例因 KV Cache增加而出现内存不足时,DistKV-LLM 会主动识别并从其他容量过剩的实例中借用可用内存空间。

DistAttention

Key: 将Attention的计算划分为 Micro Attentions (MAs), 每个MA都与一个子序列tokens及相应KV Cache对应这些MAs可以分别独立计算,最终attention结果通过聚合各MA的结果得到

MAij=exp(QiKjTmax(QiKjT))(1)MA_{ij} = \exp(Q_iK_j^T - \max(Q_iK_j^T)) \tag{1}

Attention(Q,K,V)=Reduce(Scale([MAij]j=1Bkv))=Reduce([exp(QiKjTmaxi)MAij]j=1Bkv)=j=1Bkvexp(QiKjTmaxi)MAijsumi\begin{align*} Attention(Q,K,V) &= Reduce(Scale([MA_{ij}]_{j=1}^{B_{kv}})) \\ &=Reduce([\exp(Q_iK_j^T - \max_i)MA_{ij}]_{j=1}^{B_{kv}}) &\tag{2} \\ &=\sum_{j=1}^{B_{kv}}\frac{\exp(Q_iK_j^T - \max_i)MA_{ij}}{sum_i} \end{align*}

maxi=max(max(QiK1T,,max(QiKBT)))sumi=exp(QiKjTmaxi)\begin{align*} max_i &=\max(\max(Q_iK_1^T,\dots,\max(Q_iK_B^T))) \\ sum_i &=\sum{\exp(Q_iK_j^T - \max_i)} \end{align*}

1(1)式为各MA单独计算公式,2(2)​​式为最终attention聚合计算公式

Standard Self-Attention:

Attention(Q,K,V)=softmax(QKTdk)VAttention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V

rBlock & rManager

实现DistAttention后,KV Caches就被分成了更小的单元,称为rBlocks

每个LLM service实例都有一个rBlock manager, 称为rManager,负责管理本地设备的所有rBlock。

实际上,当前实例上的GPU显存空间虚拟化为一个个大小固定的rBlock

rManager的作用如下:

  1. 维护一个逻辑rBlock到物理rBlock的映射表。
  2. 提供操作接口
    • 为新生成的KV Cache分配 rBlock,当没有足够显存分配时,启动借用程序从其他实例借用显存空间
    • 释放不需要的rBlock

Fig.2

rBlock中包含元数据(metadata):

  • rBlock ID, Instance ID: 表明该rBlock属于当前实例还是远程实例
  • device ID, physical ID: 表明该rBlock的物理地址(CPU或某一个GPU)

gManager & Contract Protocal

gManager是一个全局管理者,维护所有实例的显存信息。

各实例定期向gManager发送心跳信号(heartbeat),其中包含各实例剩余的显存空间大小。

Fig.3

当一个实例所有rBlock都已分配完毕,该实例上的rManager就要从其他实例上借用CPU或GPU显存,在此之前,作为一个debtor

  1. gManager初始化一个请求
  2. gManager查表Global Debt Ledger(定期排序)
  3. gManager向该debtor返回可能的creditor地址(原则:通讯cost最低,剩余空间最多)
  4. debtorgManager返回的creditor依次借用空间
  5. 直到有一个creditor成功借出空间(若都没有空间,返回1.)

Fragmented Memory Management

每个实例既是debtor又是creditor,由于上下文长度的动态变化,各实例会根据需要借入或借出内存(e.g. 处理长上下文时,实例所需空间可能会不断增大,需要从远程实例借用空间;处理短上下文时则会更快地释放空间,然后借给实例或分配给新请求)。

这种动态性导致了数据局部性的恶化,当实例频繁访问存储在远程实例的数据时,就会增加时延,降低吞吐量。

为此,提出DGFM, debt-graph-based fragmented memory management algorithm,将已经借出的空间交换为本实例空间(回路内结点)。为防止频繁交换,设定交换时所需的最小块数。

Fig.4


补充

简单介绍一下几种并行方法(parallelism)[^1][^2]:

  • Data Parallelism(数据并行)
  • Model Parallelism(模型并行)
    • Tensor Parallelism(层内)
    • Pipeline(流水线)Parallelism(层间)

Fig.5

Data parallelism

在各设备上都有一份完整模型参数,各设备彼此间可独立计算,给各设备的Input data切分后并行输入并计算。每隔一段时间(若干batch)后同步各设备上模型权重的梯度。

现在随着模型大小不断增大,单GPU显存已经无法容纳现在的LLM,因此有了Model Parallelism。

Model Parallelism

  1. Pipeline Parallelism:
    模型做层间划分(inter-layer parallelism)
    如上图,若原本有6层,一个GPU不够存,那么两个GPU分别存3层

  2. Tensor Parallelism:
    模型做层内划分(intra-layer parallelism)
    如图,将线性层按照行或列对权重划分。原本为 Y=W1W2XY = W_1W_2X 这里将W1W_1按列划分,W2W_2
    按行划分。每个GPU只需存一半的权重即可,最后通过All-Reduce补全结果后能继续下一层的计算。

Fig.6

Reference

[^1]: Li S, Liu H, Bian Z, et al. Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training[M].
arXiv, 2023.
[^2]: 深度学习并行训练算法一锅炖: DDP, TP, PP, ZeRO - marsggbo