0. intro
在 pd 分离中, mooncake 会在 prefill 和 decode 上都调用 batch_register_memory 把注册一大块 buffer 给 kv cache了,后续在注册的地址上完成ibv_post_send()。
P 进来 300 个 8k request 的时候,假设每个BatchSize 是 8,模型是 48 层,128 的 hidden_head,4个头,bf16。于是乎250GB kv cache 要传。mooncake 则是一个 batch 内所有 request 和 attention 算出来 kv cache 都会 pass,Prefill 会等待 batch 的所有 kv cache 计算完毕(靠.wait()保证),然后 batch_transfer_sync_write()一下把这个 kv cache 传给 decode,然后走 tcp 告诉 decode 我传完了。外面互相走 tcp交换的信息还有 decode 告诉 prefill 我要哪个请求和哪个blockId已经 ready等等。
def wait_for_layer_load(self, layer_name: str) -> None:
"""MooncakeConnector does not do layerwise saving."""
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""MooncakeConnector does not save explicitly."""
pass
def wait_for_save(self):
pass上面所说的,在一个 batch 内大致就是传输 6GB 左右 cache 给 decode,走一次同步的 batch_transfer_sync_write 操作。
调用栈
# vllm 的 mooncake connector
engine.batch_transfer_sync_write()
└─ TransferEnginePy::batchTransferSync()
└─ engine_->submitTransfer(batch_id, entries)
└─ MultiTransport::submitTransfer()
└─ RdmaTransport::submitTransferTask()
├─ 每个 TransferRequest 切 64KB Slice
├─ 按本地 RNIC context 分组
└─ context->submitPostSend()
└─ WorkerPool::submitPostSend()
├─ 查 remote rkey / peer_nic_path
├─ 放入 worker shard queue
└─ transferWorker() 2 个线程
├─ 建 endpoint/QP handshake
├─ endpoint->submitPostSend()
│ └─ ibv_post_send(qp, wr_chain)
└─ performPollCq()
└─ ibv_poll_cq() → Slice markSuccess()- 每个 rank 默认 QP 是 2,在
RdmaEndPoint::construct函数内ibv_create_qp出来两个 QP。 - 生产: 具体的传输会透传到 rdma 的
submitTransferTask实现,这里for loop 去除 request,再 for loop 一个 request 内的 slice(每个 slice 切成 64K),然后每个 slice 提交给对应的 NIC的RdmaContext::submitPostSend()。 - 消费: WorkerPool::submitPostSend() → 缓冲到 8 个 shard 队列 → 2 个 transferWorker 线程消费 → 每个线程进 performPostSend() → 调到 RdmaEndPoint::submitPostSend() → ibv_post_send。
1. 结构体设计
传输层
一个 nic 就是一个 rdmaContext,纯面向 NIC 对象设计,这个 NIC 上有多个 worker,以及 ibv 需要的所有字段,在外面需要传输的时候只需要描述要搬什么给 worker 就行。
RdmaTransport
└── context_list_ : vector<shared_ptr<RdmaContext>> ← 每张 NIC 一个
└── RdmaContext (per NIC)
├── device_name_ "mlx5_3" ← NIC 的 key
├── context_ ibv_context* ← libibverbs 句柄
├── pd_ ibv_pd*
├── cq_list_ vector<RdmaCq>
├── endpoint_store_
└── worker_pool_ shared_ptr<WorkerPool> ← 内嵌