0. Preface
For the NIXL architecture itself, see "0. nixl research"; this note covers how vLLM uses the NIXL connector to transfer the KV cache.
Structure
Decode side (NixlConnectorWorker)                       Prefill side (NixlConnectorWorker)
══════════════════════════════════════════════════ ══════════════════════════════════════════════
[Startup phase] register_kv_caches(kv_caches)
────────────────────────────────────────────────────────────────────────────────────────────────
Both sides run this independently, with no knowledge of each other:
1. Iterate over all layers in kv_caches (layer_name -> Tensor):
   - split_k_and_v=True (FA): K and V become two separate regions
   - split_k_and_v=False (FlashInfer): K+V share one region; num_regions *= 2 afterwards
   - Record per layer: base_addr, block_len_per_layer[i]
2. nixl_wrapper.register_memory(
       [(base_addr, tensor_bytes, device_id, ""), ...],
       memory_type="VRAM"
   )
   └─ The KV tensors of all layers are registered with NIXL in a single call and become RDMA-accessible memory
3. register_local_xfer_handler():
   blocks_data = []
   for i, base_addr in enumerate(seen_base_addresses):   ← each region (one per layer, or one per K/V)
       for block_id in range(num_blocks):                 ← each block
           addr = base_addr + block_id * block_len_per_layer[i]
           blocks_data.append((addr, kv_block_len, device_id))
   descs = nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
   src_xfer_side_handle = nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
   └─ src dlist: a flat list of num_regions × num_blocks descs
      desc_index = region_i * num_blocks + block_id
4. Build NixlAgentMetadata:
   { engine_id, agent_metadata (RDMA connection info),
     kv_caches_base_addr[0..num_regions-1],
     num_blocks, block_lens[0..num_regions-1],
     kv_cache_layout, block_size }
   wrapped in NixlHandshakePayload { compat_hash, agent_metadata_bytes }
   → xfer_handshake_metadata (exposed to peers during the handshake)
5. (Prefill side) Scheduler._nixl_handshake_listener_t starts
   ZMQ ROUTER bind: side_channel_host:side_channel_port+dp_rank
   └─ waits for the Decode side to come and fetch the agent metadata
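Step 3 is pure address arithmetic. A minimal pure-Python sketch of it follows. It assumes the split-K/V (FA) case, where each desc covers one whole block of one region. `build_local_descs` and `desc_index` are hypothetical helpers, the addresses are invented, and no NIXL API is called:

```python
from typing import List, Tuple

# One transfer descriptor: (address, length in bytes, device id)
Desc = Tuple[int, int, int]


def build_local_descs(base_addrs: List[int],
                      block_len_per_layer: List[int],
                      num_blocks: int,
                      device_id: int = 0) -> List[Desc]:
    """Flatten every (region, block) pair into a single desc list.

    Assumption: kv_block_len equals block_len_per_layer[i]
    (split K/V, no FlashInfer halving, no TP slicing).
    """
    descs: List[Desc] = []
    for i, base_addr in enumerate(base_addrs):
        block_len = block_len_per_layer[i]
        for block_id in range(num_blocks):
            descs.append((base_addr + block_id * block_len, block_len, device_id))
    return descs


def desc_index(region_i: int, block_id: int, num_blocks: int) -> int:
    """Position of (region_i, block_id) in the flat list built above."""
    return region_i * num_blocks + block_id


if __name__ == "__main__":
    # Hypothetical FA-style cache: 2 layers with K and V split -> 4 regions.
    num_blocks, block_len = 8, 4096
    base_addrs = [0x1000_0000 * (r + 1) for r in range(4)]
    descs = build_local_descs(base_addrs, [block_len] * 4, num_blocks)
    print(len(descs))                                   # 4 regions x 8 blocks
    print(hex(descs[desc_index(3, 5, num_blocks)][0]))  # layer1-V, block 5
```

Region order here is layer0-K, layer0-V, layer1-K, layer1-V. Region 3 is therefore layer1-V.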
[Prefill finishes inference, request ends] request_finished()
────────────────────────────────────────────────────────────────────────────────────────────────
Prefill Scheduler:
  request_finished(request, block_ids)
  └─ if params["do_remote_decode"] is True:
         delay_free_blocks = True
         _reqs_need_send[req_id] = now + timeout
         return (True, {
             do_remote_prefill: True,
             remote_block_ids: block_ids,      ← Prefill-side blocks
             remote_engine_id: self.engine_id,
             remote_request_id: req_id,
             remote_host: side_channel_host,
             remote_port: side_channel_port,
             tp_size: tp_size,
         })
     └─ This kv_transfer_params bundle travels back with the HTTP response
        to whoever issues the request on the Decode side
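Below is a simplified sketch of the Prefill-side decision in `request_finished()`. Note the following:

- `PrefillSchedulerSketch` is hypothetical.
- `abort_timeout_s` is a placeholder for `VLLM_NIXL_ABORT_REQUEST_TIMEOUT` and is not its real default.
- Only the fields listed above are returned.

```python
import time
from typing import Any, Dict, List, Optional, Tuple


class PrefillSchedulerSketch:
    """Hypothetical, simplified model of the Prefill scheduler's request_finished()."""

    def __init__(self, engine_id: str, host: str, port: int, tp_size: int,
                 abort_timeout_s: float = 120.0) -> None:
        self.engine_id = engine_id
        self.side_channel_host = host
        self.side_channel_port = port
        self.tp_size = tp_size
        # Placeholder for VLLM_NIXL_ABORT_REQUEST_TIMEOUT (not the real default).
        self.abort_timeout_s = abort_timeout_s
        self._reqs_need_send: Dict[str, float] = {}

    def request_finished(self, req_id: str, params: Optional[Dict[str, Any]],
                         block_ids: List[int]) -> Tuple[bool, Optional[Dict[str, Any]]]:
        if not params or not params.get("do_remote_decode"):
            # Ordinary request: blocks can be freed right away, nothing to hand over.
            return False, None
        # Keep the blocks alive until Decode reads them or the deadline passes.
        self._reqs_need_send[req_id] = time.perf_counter() + self.abort_timeout_s
        return True, {
            "do_remote_prefill": True,
            "remote_block_ids": list(block_ids),  # Prefill-side block ids
            "remote_engine_id": self.engine_id,
            "remote_request_id": req_id,
            "remote_host": self.side_channel_host,
            "remote_port": self.side_channel_port,
            "tp_size": self.tp_size,
        }


if __name__ == "__main__":
    sched = PrefillSchedulerSketch("prefill-0", "10.0.0.1", 5600, tp_size=2)
    delay_free, kv_params = sched.request_finished("req-1", {"do_remote_decode": True}, [3, 7])
    print(delay_free, kv_params["remote_block_ids"])
```

Returning `True` as the first element keeps the Prefill blocks from being freed. The deadline stored in `_reqs_need_send` bounds how long they stay pinned.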
[Decode side receives kv_transfer_params, request is queued]
────────────────────────────────────────────────────────────────────────────────────────────────
Decode Scheduler:
  get_num_new_matched_tokens():
    if params["do_remote_prefill"] is True:
      → return (len(prompt_tokens) - num_computed_tokens, async=True)
  update_state_after_alloc():
    if params["do_remote_prefill"] is True:
      if remote_block_ids is present:
        → _reqs_need_recv[req_id] = (request, local_block_ids)
      params["do_remote_prefill"] = False   ← prevents the transfer from being triggered twice
  build_connector_meta():
    meta.reqs_to_recv[req_id] = ReqMeta(
        local_block_ids,
        remote=RemoteMeta(
            block_ids=remote_block_ids,     ← Prefill-side blocks
            engine_id=remote_engine_id,
            request_id=remote_request_id,
            host=remote_host,
            port=remote_port,
        )
    )
    meta.reqs_to_send = {req_id: expiry}    ← filled on the Prefill side (from _reqs_need_send): deadlines for blocks still waiting to be read by Decode
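The one-shot behaviour of `do_remote_prefill` is the easiest part to get wrong. The sketch below models the three hooks with plain dicts. Note the following:

- `DecodeSchedulerSketch` is hypothetical.
- It stores `params` where the real code stores the request object.
- It covers only the receive path (`reqs_to_recv`).

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass
class RemoteMeta:
    block_ids: List[int]   # Prefill-side block ids
    engine_id: str
    request_id: str
    host: str
    port: int


@dataclass
class ReqMeta:
    local_block_ids: List[int]
    remote: RemoteMeta


class DecodeSchedulerSketch:
    """Hypothetical, simplified model of the Decode-side scheduler hooks."""

    def __init__(self) -> None:
        self._reqs_need_recv: Dict[str, Tuple[Dict[str, Any], List[int]]] = {}

    def get_num_new_matched_tokens(self, params: Dict[str, Any], num_prompt_tokens: int,
                                   num_computed_tokens: int) -> Tuple[int, bool]:
        if params.get("do_remote_prefill"):
            # Every uncomputed prompt token will arrive asynchronously over NIXL.
            return num_prompt_tokens - num_computed_tokens, True
        return 0, False

    def update_state_after_alloc(self, req_id: str, params: Dict[str, Any],
                                 local_block_ids: List[int]) -> None:
        if not params.get("do_remote_prefill"):
            return
        if params.get("remote_block_ids"):
            self._reqs_need_recv[req_id] = (params, list(local_block_ids))
        # One-shot flag: a later scheduling pass must not trigger a second transfer.
        params["do_remote_prefill"] = False

    def build_connector_meta(self) -> Dict[str, ReqMeta]:
        reqs_to_recv: Dict[str, ReqMeta] = {}
        for req_id, (params, local_block_ids) in self._reqs_need_recv.items():
            reqs_to_recv[req_id] = ReqMeta(
                local_block_ids=local_block_ids,
                remote=RemoteMeta(
                    block_ids=params["remote_block_ids"],
                    engine_id=params["remote_engine_id"],
                    request_id=params["remote_request_id"],
                    host=params["remote_host"],
                    port=params["remote_port"],
                ),
            )
        self._reqs_need_recv.clear()  # each request is handed to the worker only once
        return reqs_to_recv
```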
[Every forward step] start_load_kv(metadata)
────────────────────────────────────────────────────────────────────────────────────────────────
Decode Worker:
  for req_id, meta in metadata.reqs_to_recv.items():
      meta.local_physical_block_ids = _logical_to_kernel_block_ids(meta.local_block_ids)
      meta.remote.block_ids = _logical_to_kernel_block_ids(meta.remote.block_ids)
      _recving_metadata[req_id] = meta
      if remote_engine_id not in _remote_agents:
          _background_nixl_handshake(req_id, remote_engine_id, meta)
          └─ submits _nixl_handshake() to a ThreadPoolExecutor and carries on without blocking
      else:
          _read_blocks_for_req(req_id, meta)   ← handshake already done, issue the read now
  while not _ready_requests.empty():
      _read_blocks_for_req(*_ready_requests.get_nowait())   ← requests whose handshake has just completed
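The dispatch pattern above has two branches. If the peer is already known, the read is issued immediately. Otherwise the handshake goes to a thread pool and the request is picked up in a later step.

This can be modelled with the standard library alone. The sketch makes these assumptions:

- `handshake_fn` is a hypothetical callable that returns the remote agent name.
- `HandshakeDispatcherSketch` and `wait_for_handshakes` are illustrative names.
- `issued_reads` stands in for `_read_blocks_for_req`.

```python
import queue
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Dict, List, Tuple


class HandshakeDispatcherSketch:
    """Hypothetical model of start_load_kv(): read now, or handshake in the background first."""

    def __init__(self, handshake_fn: Callable[[str], str]) -> None:
        self._handshake_fn = handshake_fn              # engine_id -> remote agent name
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._handshake_futures: Dict[str, Future] = {}
        self._remote_agents: Dict[str, str] = {}       # only touched on the calling thread
        self._ready_requests: "queue.Queue[Tuple[str, str, str]]" = queue.Queue()
        self.issued_reads: List[Tuple[str, str]] = []  # stands in for _read_blocks_for_req

    def _read_blocks_for_req(self, req_id: str, engine_id: str) -> None:
        self.issued_reads.append((req_id, engine_id))

    def _on_handshake_done(self, fut: Future, req_id: str, engine_id: str) -> None:
        # Runs on the executor thread (or inline if fut had already finished);
        # it only hands the result over through the thread-safe queue.
        if fut.exception() is None:
            self._ready_requests.put((req_id, engine_id, fut.result()))
        # A real connector would mark the request as failed here instead of dropping it.

    def start_load_kv(self, reqs_to_recv: Dict[str, str]) -> None:
        for req_id, engine_id in reqs_to_recv.items():
            if engine_id in self._remote_agents:
                self._read_blocks_for_req(req_id, engine_id)   # handshake already done
                continue
            fut = self._handshake_futures.get(engine_id)
            if fut is None:                                   # one handshake per remote engine
                fut = self._executor.submit(self._handshake_fn, engine_id)
                self._handshake_futures[engine_id] = fut
            fut.add_done_callback(
                lambda f, r=req_id, e=engine_id: self._on_handshake_done(f, r, e))
        # Drain requests whose handshake has finished since the last step.
        while not self._ready_requests.empty():
            req_id, engine_id, agent_name = self._ready_requests.get_nowait()
            self._remote_agents[engine_id] = agent_name
            self._handshake_futures.pop(engine_id, None)
            self._read_blocks_for_req(req_id, engine_id)

    def wait_for_handshakes(self) -> None:
        """Test helper: block until every submitted handshake and its callbacks have run."""
        self._executor.shutdown(wait=True)
```

The callback only pushes into a `queue.Queue`. `_remote_agents` is updated on the calling thread while the queue is drained, so the worker thread never writes shared dicts.

A second request for an engine whose handshake is still running attaches another callback to the same `Future` rather than starting a new handshake.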
[Handshake sub-flow] _nixl_handshake() (in a background thread)
────────────────────────────────────────────────────────────────────────────────────────────────
Decode Worker (background thread):                      Prefill Scheduler (_nixl_handshake_listener_t):
  p_remote_rank = kv_topo.get_target_remote_rank(remote_tp_size)
  ZMQ REQ connect → tcp://remote_host:remote_port
  sock.setsockopt(RCVTIMEO, 5000ms)
  msg = encode((GET_META_MSG, p_remote_rank))
  sock.send(msg) ─────────────────────────────────────→   ROUTER recv_multipart()
                                                          decode msg → (GET_META_MSG, target_tp_rank)
                                                          sock.send_multipart([identity, "",
                                                                               encoded_data[target_tp_rank]])
                                                          └─ encoded_data = NixlHandshakePayload {
                                                                 compat_hash,
                                                                 agent_metadata_bytes: NixlAgentMetadata {
                                                                     kv_caches_base_addr,  ← base addrs of all peer layers
                                                                     num_blocks,
                                                                     block_lens,           ← block_len of each layer
                                                                     device_id,
                                                                     kv_cache_layout,
                                                                     block_size
                                                                 }
                                                             }
  handshake_bytes = sock.recv() ←──────────────────────────────┘
  verify compat_hash
  decode NixlAgentMetadata → nixl_meta
  add_remote_agent(nixl_meta, p_remote_rank, remote_tp_size):
      remote_agent_name = nixl_wrapper.add_remote_agent(nixl_meta.agent_metadata)
      blocks_data = []
      for i, remote_base_addr in enumerate(nixl_meta.kv_caches_base_addr):   ← each region (Prefill side)
          kv_block_len = get_backend_aware_kv_block_len(i)    ← local K (or V) block len, i.e. 1/tp_ratio of the remote one
          rank_offset = (tp_rank % tp_ratio) * kv_block_len   ← heterogeneous TP: slice along the kv_head dim
          for block_id in range(nixl_meta.num_blocks):        ← each Prefill-side block
              addr = remote_base_addr + block_id * nixl_meta.block_lens[i] + rank_offset
              blocks_data.append((addr, kv_block_len, remote_device_id))
      descs = nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
      dst_xfer_side_handles[engine_id] = nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
      └─ dst dlist: num_regions × remote_num_blocks descs
         desc_index = region_i * remote_num_blocks + block_id
  When done: _remote_agents[engine_id][p_remote_rank] = remote_agent_name
  → the Future's done_callback puts (req_id, meta) into _ready_requests
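The remote-address formula with `rank_offset` is easiest to see with small numbers. The sketch makes these assumptions:

- Decode TP is `tp_ratio` times Prefill TP.
- Each Decode rank's kv heads occupy one contiguous byte range inside a block.
- `build_remote_descs` is a hypothetical helper, and the sizes are invented.

```python
from typing import List, Tuple

Desc = Tuple[int, int, int]  # (address, length in bytes, device id)


def build_remote_descs(remote_base_addrs: List[int],
                       remote_block_lens: List[int],
                       remote_num_blocks: int,
                       tp_rank: int,
                       tp_ratio: int,
                       remote_device_id: int = 0) -> List[Desc]:
    """Descs for the slice of every remote block that this Decode rank reads.

    Assumes Decode TP = tp_ratio x Prefill TP and that the kv heads owned by
    each Decode rank form one contiguous byte range inside a block.
    """
    descs: List[Desc] = []
    for i, base_addr in enumerate(remote_base_addrs):
        remote_block_len = remote_block_lens[i]
        kv_block_len = remote_block_len // tp_ratio          # == this rank's local block len
        rank_offset = (tp_rank % tp_ratio) * kv_block_len    # which head slice to read
        for block_id in range(remote_num_blocks):
            addr = base_addr + block_id * remote_block_len + rank_offset
            descs.append((addr, kv_block_len, remote_device_id))
    return descs


if __name__ == "__main__":
    # Hypothetical: Prefill TP=1, Decode TP=2, one region of 3 blocks x 400 bytes.
    for rank in range(2):
        print(rank, build_remote_descs([1000], [400], 3, tp_rank=rank, tp_ratio=2))
```

Rank 0 reads bytes [0, 200) of each 400-byte block and rank 1 reads [200, 400). Together the Decode ranks cover every block exactly once. This is also why the Prefill side waits for `tp_ratio` notifications per request.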
[RDMA transfer core] _read_blocks(local_block_ids, remote_block_ids, dst_engine_id, ...)
────────────────────────────────────────────────────────────────────────────────────────────────
Decode Worker (on the start_load_kv call path):
  # Full prefix-cache hit (0 blocks to transfer)
  if len(local_block_ids) == 0:
      nixl_wrapper.send_notif(remote_agent_name, notif_msg=f"{remote_request_id}:{tp_ratio}".encode())
      return   ← only tells Prefill it may free the blocks; no actual RDMA
  # Align block counts (on a partial prefix hit, only the trailing uncached blocks are transferred)
  remote_block_ids = remote_block_ids[-len(local_block_ids):]
  # Compute the desc index lists (covering all layers in one go)
  _get_block_descs_ids(engine_id, block_ids, layer_idx=None):
      region_ids = arange(num_regions)    # [0, 1, ..., num_regions-1]
      num_blocks = number of blocks registered for engine_id   ← differs between local and remote
      descs_ids = region_ids[:, None] * num_blocks + block_ids[None, :]
      return descs_ids.flatten()
  └─ length = num_regions × len(block_ids)
            = (num_layers × 1 or 2) × num_blocks_to_transfer
     each element is an index into the dlist, pointing at one concrete (addr, len) desc
  local_descs_ids  = _get_block_descs_ids(self.engine_id, local_block_ids)
  remote_descs_ids = _get_block_descs_ids(dst_engine_id, remote_block_ids)
  # A single NIXL RDMA READ moves the given blocks of every layer
  notif_id = f"{remote_request_id}:{tp_ratio}".encode()
  handle = nixl_wrapper.make_prepped_xfer(
      "READ",
      src_xfer_side_handle,    ← local Decode dlist
      local_descs_ids,         ← which local descs (write destinations)
      dst_xfer_side_handle,    ← remote Prefill dlist
      remote_descs_ids,        ← which remote descs (data source)
      notif_msg=notif_id       ← a notification is sent to Prefill automatically once the RDMA completes
  )
  nixl_wrapper.transfer(handle)   ← asynchronous, returns immediately
  _recving_transfers[req_id].append(handle)
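The index computation and the partial-prefix alignment can be checked without NIXL. In this pure-Python sketch:

- `get_block_descs_ids` mirrors the NumPy broadcasting shown above.
- `plan_read` is a hypothetical wrapper. It returns the two index lists passed to `make_prepped_xfer`, or `None` for the notify-only case.

```python
from typing import List, Optional, Tuple


def get_block_descs_ids(num_regions: int, num_blocks: int, block_ids: List[int]) -> List[int]:
    """Pure-Python version of region_ids[:, None] * num_blocks + block_ids[None, :], flattened."""
    return [region * num_blocks + block for region in range(num_regions) for block in block_ids]


def plan_read(local_block_ids: List[int], remote_block_ids: List[int],
              num_regions: int, local_num_blocks: int,
              remote_num_blocks: int) -> Optional[Tuple[List[int], List[int]]]:
    """Return (local_descs_ids, remote_descs_ids), or None for a notify-only request."""
    if not local_block_ids:
        # Full prefix-cache hit. This also guards against remote_block_ids[-0:],
        # which would be the whole list rather than an empty one.
        return None
    # Partial prefix hit: read only the trailing remote blocks.
    remote_block_ids = remote_block_ids[-len(local_block_ids):]
    local_ids = get_block_descs_ids(num_regions, local_num_blocks, local_block_ids)
    remote_ids = get_block_descs_ids(num_regions, remote_num_blocks, remote_block_ids)
    assert len(local_ids) == len(remote_ids)  # local and remote descs are paired by position
    return local_ids, remote_ids


if __name__ == "__main__":
    # Hypothetical: 2 regions, Decode has 100 blocks, Prefill has 50.
    # Decode already caches the first prompt block, so it needs 2 of Prefill's 3 blocks.
    print(plan_read([20, 21], [3, 7, 12], num_regions=2,
                    local_num_blocks=100, remote_num_blocks=50))
```

The early return is not just an optimization. In Python `remote_block_ids[-0:]` is the whole list, so the slicing step alone would not produce an empty read.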
Note: notif_msg is NIXL's built-in transfer-completion notification. Once the RDMA data has landed,
nixl_wrapper.get_new_notifs() on the Prefill side receives the string
"{remote_request_id}:{tp_ratio}".
[During forward] wait_for_layer_load() / save_kv_layer()
────────────────────────────────────────────────────────────────────────────────────────────────
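Decode side: pass (NIXL does no per-layer work; the full transfer was already issued asynchronously in start_load_kv)
Prefill side: pass (the KV cache was registered as RDMA-accessible in register_kv_caches, so no per-layer save is needed)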
[Prefill: wait_for_save() in host_buffer mode]
────────────────────────────────────────────────────────────────────────────────────────────────
if use_host_buffer and copy_blocks:
    save_kv_to_host(metadata)
    └─ for req in reqs_to_save:
           copy_blocks(device_kv → host_xfer_buffers,
                       block_ids, "d2h")   ← blocking D2H copy
(TPU and other devices whose memory cannot be registered with NIXL directly take this path)
[Completion polling] get_finished(), called every forward step
────────────────────────────────────────────────────────────────────────────────────────────────
Decode Worker:
  _pop_done_transfers(_recving_transfers):
    for req_id, handles in list(_recving_transfers.items()):
        in_progress = False
        for handle in handles:
            state = nixl_wrapper.check_xfer_state(handle)
            if state == "DONE":
                nixl_wrapper.get_xfer_telemetry(handle)
                nixl_wrapper.release_xfer_handle(handle)
            elif state == "PROC":
                in_progress = True
            else:                                  ← failure
                _handle_failed_transfer(req_id, handle)
        if not in_progress:                        ← every handle of this request has finished
            done_recving.add(req_id)
            del _recving_transfers[req_id]
Prefill Worker:
  _get_new_notifs():
    for notifs in nixl_wrapper.get_new_notifs().values():   ← remote agent name -> list of messages
        for notif in notifs:
            req_id, tp_ratio = notif.decode().rsplit(":", 1)
            consumer_notification_counts[req_id] += 1
            if consumer_notification_counts[req_id] == int(tp_ratio):   ← every D worker has finished reading
                notified_req_ids.add(req_id)
                del _reqs_to_send[req_id]
                └─ Prefill can now free these KV blocks
After done_recving is returned:
  for req_id in done_recving:
      meta = _recving_metadata.pop(req_id)
      if use_host_buffer:
          sync_recved_kv_to_device(req_id, meta)     ← H2D copy
      if enable_permute_local_kv:
          permute_device_kv(meta.local_physical_block_ids)
      if block_size_ratio > 1:
          blocksize_post_process(...)                 ← reorder for heterogeneous block sizes
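Both polling loops reduce to small pure functions. The sketch makes these assumptions:

- `check_state`, `release` and `on_failed` are injected callables. They stand in for `check_xfer_state`, `release_xfer_handle` and `_handle_failed_transfer`.
- Any state other than "DONE" or "PROC" counts as a failure.
- Notifications arrive as a dict mapping agent name to a list of byte strings.
- Unlike the loop above, this sketch keeps only still-running handles, so a released handle is never checked again.

```python
from typing import Callable, Dict, List, Set


def pop_done_transfers(transfers: Dict[str, List[str]],
                       check_state: Callable[[str], str],
                       release: Callable[[str], None],
                       on_failed: Callable[[str, str], None]) -> Set[str]:
    """Decode side: return req_ids whose handles have all finished (DONE or failed)."""
    done: Set[str] = set()
    for req_id, handles in list(transfers.items()):
        pending: List[str] = []
        for handle in handles:
            state = check_state(handle)
            if state == "DONE":
                release(handle)              # telemetry + release_xfer_handle in the real code
            elif state == "PROC":
                pending.append(handle)       # still in flight, poll again next step
            else:
                on_failed(req_id, handle)    # any other state is treated as an error
        if pending:
            transfers[req_id] = pending      # keep only handles that are still running
        else:
            done.add(req_id)
            del transfers[req_id]
    return done


def consume_notifs(new_notifs: Dict[str, List[bytes]],
                   counts: Dict[str, int],
                   reqs_to_send: Dict[str, float]) -> Set[str]:
    """Prefill side: a request is done once tp_ratio Decode ranks have reported."""
    notified: Set[str] = set()
    for notifs in new_notifs.values():       # remote agent name -> list of messages
        for notif in notifs:
            req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
            counts[req_id] = counts.get(req_id, 0) + 1
            if counts[req_id] == int(tp_ratio):
                notified.add(req_id)
                del counts[req_id]
                reqs_to_send.pop(req_id, None)   # the blocks may now be freed
    return notified
```

With `tp_ratio=2`, the first notification for a request only bumps its counter. The request is reported as notified, and its blocks released, after the second Decode rank reports.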
[Timeout fallback] end of get_finished()
────────────────────────────────────────────────────────────────────────────────────────────────
Prefill Worker:
  while _reqs_to_send:
      req_id, expires = oldest entry
      if now < expires: break
      ← VLLM_NIXL_ABORT_REQUEST_TIMEOUT exceeded
      del _reqs_to_send[req_id]
      force-release: done_sending.add(req_id)
      (logs a warning)
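Below is a sketch of the timeout sweep. It assumes every entry was inserted with the same timeout, so dict insertion order equals deadline order; that is what turns "oldest entry" into a `next(iter(...))`. `sweep_expired` is a hypothetical name.

```python
import logging
import time
from typing import Dict, Optional, Set

logger = logging.getLogger(__name__)


def sweep_expired(reqs_to_send: Dict[str, float], now: Optional[float] = None) -> Set[str]:
    """Force-release requests whose KV blocks Decode never read before the deadline.

    reqs_to_send maps req_id -> deadline (insertion time + a constant timeout).
    Python dicts keep insertion order, so the first entry is the oldest and the
    sweep can stop at the first deadline that has not passed yet.
    """
    now = time.perf_counter() if now is None else now
    done_sending: Set[str] = set()
    while reqs_to_send:
        req_id, expires = next(iter(reqs_to_send.items()))
        if now < expires:
            break
        # Without this deletion the loop would keep looking at the same entry.
        del reqs_to_send[req_id]
        logger.warning("Releasing KV blocks of %s: not read by Decode before timeout", req_id)
        done_sending.add(req_id)
    return done_sending
```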
[Summary of key data structures]
────────────────────────────────────────────────────────────────────────────────────────────────
Desc layout of the src/dst xfer_side_handle (FA example, num_layers=L, num_blocks=N):
  dlist index:  0       1       ...  N-1        N       N+1     ...  2N-1       2N  ...  2L*N-1
               [layer0-K                      ][layer0-V                      ][layer1-K ...]
                block0  block1  ...  blockN-1   block0  block1  ...  blockN-1   ...
  _get_block_descs_ids(engine, [b3, b7, b12]):
      region_ids = [0, 1, 2, 3, ..., 2L-1]
      descs = [0*N+3, 0*N+7, 0*N+12,
               1*N+3, 1*N+7, 1*N+12,
               2*N+3, ...]
  → length = 2L × 3; a single NIXL READ moves all 2L×3 (addr, len) pairs in one batch
  → equivalent to "L layers × 3 blocks × K+V": every layer is covered by one RDMA operation
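Going the other way, from a flat dlist index back to layer, K/V and block, is a quick sanity check of the layout above. The sketch covers the split-K/V (FA) layout only. `decode_desc_index` and `descs_for` are hypothetical helpers, and L and N are example values.

```python
from typing import List, Tuple


def decode_desc_index(idx: int, num_blocks: int) -> Tuple[int, str, int]:
    """Map a flat dlist index back to (layer, 'K' or 'V', block) for the split-K/V (FA) layout."""
    region, block = divmod(idx, num_blocks)
    layer, kv = divmod(region, 2)       # regions alternate layerX-K, layerX-V
    return layer, "KV"[kv], block


def descs_for(num_layers: int, num_blocks: int, block_ids: List[int]) -> List[int]:
    """Same ordering as _get_block_descs_ids with num_regions = 2 * num_layers."""
    return [r * num_blocks + b for r in range(2 * num_layers) for b in block_ids]


if __name__ == "__main__":
    L, N = 2, 16                         # hypothetical sizes
    ids = descs_for(L, N, [3, 7, 12])
    print(len(ids))                      # 2L x 3
    for idx in ids:
        print(idx, decode_desc_index(idx, N))
```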