0. Preface

For the NIXL architecture, see "0. nixl research". This note focuses on how vLLM uses the NIXL connector to transfer the KV cache.

Structure

nixl connector
Decode side (NixlConnectorWorker)                    Prefill side (NixlConnectorWorker)
══════════════════════════════════════════════════   ══════════════════════════════════════════════
 
[Startup] register_kv_caches(kv_caches)
────────────────────────────────────────────────────────────────────────────────────────────────
Each side runs this on its own, unaware of the other:
 
1. Iterate over kv_caches for every layer (layer_name -> Tensor):
   - split_k_and_v=True (FA): K and V go into two separate regions
   - split_k_and_v=False (FlashInfer): K+V share one region; num_regions *= 2 is applied afterwards
   - Record per layer: base_addr, block_len_per_layer[i]
 
2. nixl_wrapper.register_memory(
     [(base_addr, tensor_bytes, device_id, ""), ...],
     memory_type="VRAM"
   )
   └─ The KV tensors of all layers are registered with NIXL in one call and become RDMA-visible memory
 
3. register_local_xfer_handler():
   blocks_data = []
   for i, base_addr in enumerate(seen_base_addresses):   ← each region (per layer, or per K/V)
     for block_id in range(num_blocks):                  ← each block
       addr = base_addr + block_id * block_len_per_layer[i]
       blocks_data.append((addr, kv_block_len, device_id))
   descs = nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
   src_xfer_side_handle = nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
   └─ src dlist: num_regions × num_blocks descs, stored as one flat list
      desc_index = region_i * num_blocks + block_id
 
4. Build NixlAgentMetadata:
   { engine_id, agent_metadata (RDMA connection info),
     kv_caches_base_addr[0..num_regions-1],
     num_blocks, block_lens[0..num_regions-1],
     kv_cache_layout, block_size }
   Wrapped in NixlHandshakePayload { compat_hash, agent_metadata_bytes }
   → xfer_handshake_metadata (exposed to peers while waiting for a handshake)
 
5. (Prefill side) Scheduler._nixl_handshake_listener_t starts
   ZMQ ROUTER bind: side_channel_host:side_channel_port+dp_rank
   └─ Waits for the Decode side to pull the agent metadata
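
The flat descriptor layout from step 3 can be sketched in plain Python, without touching NIXL at all. This is a minimal illustration under stated assumptions: build_block_descs and desc_index are hypothetical helper names, the base addresses are made up, and the stride and the descriptor length are taken to be the same value (the notes above use block_len_per_layer[i] for the stride and kv_block_len for the length).

```python
from typing import List, Tuple

Desc = Tuple[int, int, int]  # (addr, length_in_bytes, device_id)


def build_block_descs(
    base_addrs: List[int],
    block_lens: List[int],
    num_blocks: int,
    device_id: int,
) -> List[Desc]:
    """Flatten every (region, block) pair into one descriptor list."""
    descs: List[Desc] = []
    for region_i, base_addr in enumerate(base_addrs):
        block_len = block_lens[region_i]
        for block_id in range(num_blocks):
            # Blocks of one region are laid out back to back.
            addr = base_addr + block_id * block_len
            descs.append((addr, block_len, device_id))
    return descs


def desc_index(region_i: int, block_id: int, num_blocks: int) -> int:
    """Position of a (region, block) pair inside the flat list."""
    return region_i * num_blocks + block_id


# Two hypothetical regions, 4 blocks of 256 bytes each.
descs = build_block_descs([0x1000, 0x9000], [256, 256], num_blocks=4, device_id=0)
idx = desc_index(region_i=1, block_id=2, num_blocks=4)
print(len(descs), idx, hex(descs[idx][0]))  # 8 6 0x9200
```

Because the list is region-major, a block id alone is enough to find its descriptor in every region, which is what later lets one transfer cover all layers.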
 
[Prefill finishes inference, request ends] request_finished()
────────────────────────────────────────────────────────────────────────────────────────────────
                                                     Prefill Scheduler:
                                                     request_finished(request, block_ids)
                                                     └─ if params["do_remote_decode"] is True:
                                                         delay_free_blocks = True
                                                         _reqs_need_send[req_id] = now + timeout
                                                         return (True, {
                                                           do_remote_prefill: True,
                                                           remote_block_ids: block_ids,   ← block IDs on the Prefill side
                                                           remote_engine_id: self.engine_id,
                                                           remote_request_id: req_id,
                                                           remote_host: side_channel_host,
                                                           remote_port: side_channel_port,
                                                           tp_size: tp_size,
                                                         })
                                                     └─ This kv_transfer_params bundle is sent back in the HTTP response
                                                        to the request initiator on the Decode side
 
[Decode side receives kv_transfer_params, request is enqueued]
────────────────────────────────────────────────────────────────────────────────────────────────
Decode Scheduler:
get_num_new_matched_tokens():
  if params["do_remote_prefill"] is True:
    → return (len(prompt_tokens) - num_computed_tokens, async=True)
 
update_state_after_alloc():
  if params["do_remote_prefill"] is True and remote_block_ids is present:
    → _reqs_need_recv[req_id] = (request, local_block_ids)
    → params["do_remote_prefill"] = False   ← prevents re-triggering
 
build_connector_meta():
  meta.reqs_to_recv[req_id] = ReqMeta(
    local_block_ids,
    remote=RemoteMeta(
      block_ids=remote_block_ids,   ← block IDs on the Prefill side
      engine_id=remote_engine_id,
      request_id=remote_request_id,
      host=remote_host,
      port=remote_port,
    )
  )
  meta.reqs_to_send = {req_id: expiry}   ← Prefill side: it must learn when Decode has fetched the blocks
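
The two scheduler hooks above boil down to "report the remote tokens once, then clear the flag". A minimal sketch, assuming simplified signatures: the real hooks take a request object, and here _reqs_need_recv is reduced to a plain dict of local block ids.

```python
from typing import Any, Dict, List, Tuple


def get_num_new_matched_tokens(
    prompt_tokens: List[int], num_computed_tokens: int, params: Dict[str, Any]
) -> Tuple[int, bool]:
    """Return (tokens to fetch remotely, load_is_async)."""
    if params.get("do_remote_prefill"):
        return len(prompt_tokens) - num_computed_tokens, True
    return 0, False


def update_state_after_alloc(
    req_id: str,
    params: Dict[str, Any],
    local_block_ids: List[int],
    reqs_need_recv: Dict[str, List[int]],
) -> None:
    """Queue the request for a remote read exactly once."""
    if params.get("do_remote_prefill") and params.get("remote_block_ids"):
        reqs_need_recv[req_id] = local_block_ids
        params["do_remote_prefill"] = False  # guard against re-triggering


params = {"do_remote_prefill": True, "remote_block_ids": [3, 7, 12]}
pending: Dict[str, List[int]] = {}
first = get_num_new_matched_tokens(list(range(48)), 0, params)
update_state_after_alloc("req-1", params, [0, 1, 2], pending)
second = get_num_new_matched_tokens(list(range(48)), 0, params)
print(first, second, pending)
```

After the flag is cleared, a second pass over the same request reports no remote tokens, so the read is scheduled only once.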
 
[Every forward step] start_load_kv(metadata)
────────────────────────────────────────────────────────────────────────────────────────────────
Decode Worker:
for req_id, meta in metadata.reqs_to_recv.items():
  meta.local_physical_block_ids = _logical_to_kernel_block_ids(meta.local_block_ids)
  meta.remote.block_ids          = _logical_to_kernel_block_ids(meta.remote.block_ids)
  _recving_metadata[req_id] = meta
 
  if remote_engine_id NOT in _remote_agents:
    _background_nixl_handshake(req_id, remote_engine_id, meta)
    └─ submits _nixl_handshake() to a ThreadPoolExecutor and carries on without blocking
  else:
    _read_blocks_for_req(req_id, meta)   ← handshake already completed, issue the read right away
 
while not _ready_requests.empty():
  _read_blocks_for_req(*_ready_requests.get_nowait())  ← handles requests whose handshake has just completed
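
The "handshake in the background, read later" pattern above can be sketched with the standard library alone. This is an illustration under assumptions: fake_handshake stands in for _nixl_handshake(), metadata is a plain dict, and the sketch does not deduplicate concurrent handshakes to the same engine.

```python
import queue
from concurrent.futures import ThreadPoolExecutor

remote_agents = {}                # engine_id -> agent name, filled by the handshake
ready_requests = queue.Queue()    # (req_id, meta) pairs whose handshake finished
issued_reads = []                 # req_ids for which a read was started
executor = ThreadPoolExecutor(max_workers=1)


def fake_handshake(engine_id):
    """Stand-in for _nixl_handshake(); a real one would contact the peer."""
    remote_agents[engine_id] = f"agent-{engine_id}"


def read_blocks_for_req(req_id, meta):
    issued_reads.append(req_id)


def start_load_kv(reqs_to_recv):
    for req_id, meta in reqs_to_recv.items():
        engine_id = meta["remote_engine_id"]
        if engine_id not in remote_agents:
            fut = executor.submit(fake_handshake, engine_id)
            # Once the handshake is done, park the request for a later step.
            fut.add_done_callback(
                lambda _f, r=req_id, m=meta: ready_requests.put((r, m))
            )
        else:
            read_blocks_for_req(req_id, meta)
    # Start reads for requests whose handshake completed in the meantime.
    while not ready_requests.empty():
        read_blocks_for_req(*ready_requests.get_nowait())


start_load_kv({"r1": {"remote_engine_id": "p0"}})  # kicks off the handshake
executor.shutdown(wait=True)                        # demo only: wait for the thread
start_load_kv({"r2": {"remote_engine_id": "p0"}})  # r2 reads directly, r1 is drained
print(sorted(issued_reads))
```

The forward step never waits on the handshake itself; a request simply becomes readable in whichever later step finds it in the queue.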
 
[Handshake sub-flow] _nixl_handshake() (in a background thread)
────────────────────────────────────────────────────────────────────────────────────────────────
Decode Worker (background thread):                   Prefill Scheduler (_nixl_handshake_listener_t):
 
p_remote_rank = kv_topo.get_target_remote_rank(remote_tp_size)
ZMQ REQ connect tcp://remote_host:remote_port
sock.setsockopt(RCVTIMEO, 5000ms)
msg = encode((GET_META_MSG, p_remote_rank))
sock.send(msg)                          ──────────→  ROUTER recv_multipart()
                                                     decode msg (GET_META_MSG, target_tp_rank)
                                                     sock.send_multipart([identity, "",
                                                       encoded_data[target_tp_rank]])
                                                     └─ encoded_data = NixlHandshakePayload {
                                                          compat_hash,
                                                          agent_metadata_bytes: NixlAgentMetadata {
                                                            kv_caches_base_addr,   ← base addr of every layer on the peer
                                                            num_blocks,
                                                            block_lens,            ← block_len of every layer
                                                            device_id,
                                                            kv_cache_layout,
                                                            block_size
                                                          }
                                                        }
handshake_bytes = sock.recv()  ←────────────────────┘
 
Verify compat_hash
decode NixlAgentMetadata
 
add_remote_agent(metadata, p_remote_rank, remote_tp_size):
  nixl_wrapper.add_remote_agent(metadata.agent_metadata) → remote_agent_name
 
  blocks_data = []
  for i, remote_base_addr in enumerate(nixl_meta.kv_caches_base_addr):   ← each region (on the Prefill side)
    kv_block_len = get_backend_aware_kv_block_len(i) // tp_ratio
    rank_offset  = (tp_rank % tp_ratio) * kv_block_len         ← heterogeneous TP: slice along the kv_head dim
    for block_id in range(nixl_meta.num_blocks):               ← each block on the Prefill side
      addr = remote_base_addr + block_id * block_lens[i] + rank_offset
      blocks_data.append((addr, kv_block_len, remote_device_id))
 
  descs = nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
  dst_xfer_side_handles[engine_id] =
    nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
  └─ dst dlist: num_regions × remote_num_blocks desc
     desc_index = region_i * remote_num_blocks + block_id
 
On completion: _remote_agents[engine_id][rank] = remote_agent_name
→ the Future's done_callback puts (req_id, meta) into _ready_requests
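
The address arithmetic in add_remote_agent is easiest to check with concrete numbers. The sketch below is an assumption-laden simplification: remote_block_descs is a hypothetical helper, and the slice length is derived as the remote block length divided by tp_ratio, whereas the notes above obtain it from get_backend_aware_kv_block_len.

```python
from typing import List, Tuple


def remote_block_descs(
    remote_base_addrs: List[int],
    remote_block_lens: List[int],
    remote_num_blocks: int,
    tp_rank: int,
    tp_ratio: int,
    remote_device_id: int,
) -> List[Tuple[int, int, int]]:
    """Descriptors for the slice of each remote block that this rank reads."""
    descs = []
    for i, base in enumerate(remote_base_addrs):
        # Assumption: each Decode rank reads 1/tp_ratio of every remote block.
        kv_block_len = remote_block_lens[i] // tp_ratio
        rank_offset = (tp_rank % tp_ratio) * kv_block_len
        for block_id in range(remote_num_blocks):
            addr = base + block_id * remote_block_lens[i] + rank_offset
            descs.append((addr, kv_block_len, remote_device_id))
    return descs


# One region of 1024-byte blocks; Decode rank 3 with tp_ratio 2 reads the upper half.
descs = remote_block_descs([0x4000], [1024], 2, tp_rank=3, tp_ratio=2, remote_device_id=0)
print([(hex(a), n) for a, n, _ in descs])  # [('0x4200', 512), ('0x4600', 512)]
```

The stride between blocks stays the full remote block length; only the start offset and the length of each descriptor shrink with tp_ratio.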
 
[RDMA transfer core] _read_blocks(local_block_ids, remote_block_ids, dst_engine_id, ...)
────────────────────────────────────────────────────────────────────────────────────────────────
Decode Worker (on the start_load_kv call chain):
 
# Full prefix-cache hit (0 blocks to transfer)
if len(local_block_ids) == 0:
  nixl_wrapper.send_notif(remote_agent_name, notif_msg=f"{remote_req_id}:{tp_ratio}")
  return   ← only notifies Prefill that it may free its blocks; no actual RDMA is issued
 
# Align the block counts (on a partial prefix hit, only the trailing, non-cached blocks are transferred)
remote_block_ids = remote_block_ids[-len(local_block_ids):]
 
# Compute the list of desc indices (covers all layers at once)
_get_block_descs_ids(engine_id, block_ids, layer_idx=None):
  region_ids = arange(num_regions)          # [0, 1, ..., num_regions-1]
  descs_ids  = region_ids[:, None] * num_blocks + block_ids[None, :]
  return descs_ids.flatten()
  └─ length = num_regions × len(block_ids)
     = (num_layers × 1 or 2) × num_blocks_to_transfer
     each element is an index into the dlist and points at one concrete (addr, len) desc
 
local_descs_ids  = _get_block_descs_ids(self.engine_id,  local_block_ids)
remote_descs_ids = _get_block_descs_ids(dst_engine_id,   remote_block_ids)
 
# A single NIXL RDMA READ transfers the selected blocks of all layers
notif_id = f"{remote_request_id}:{tp_ratio}".encode()
handle = nixl_wrapper.make_prepped_xfer(
  "READ",
  src_xfer_side_handle,   ← Decode-local dlist
  local_descs_ids,        ← which local descs (write destinations)
  dst_xfer_side_handle,   ← remote Prefill dlist
  remote_descs_ids,       ← which remote descs (data sources)
  notif_msg=notif_id      ← once the RDMA completes, a notification is sent to Prefill automatically
)
nixl_wrapper.transfer(handle)  ← issued asynchronously, returns immediately
_recving_transfers[req_id].append(handle)
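
The index computation and the tail alignment above can be reproduced without NumPy. A minimal sketch, assuming hypothetical helper names and made-up block counts (32 local blocks, 64 remote blocks, 4 regions):

```python
from typing import List


def align_remote_blocks(local_block_ids: List[int], remote_block_ids: List[int]) -> List[int]:
    """Keep only the trailing remote blocks that have a local counterpart."""
    return remote_block_ids[-len(local_block_ids):]


def get_block_descs_ids(num_regions: int, num_blocks: int, block_ids: List[int]) -> List[int]:
    """Region-major flat indices: region_i * num_blocks + block_id."""
    return [r * num_blocks + b for r in range(num_regions) for b in block_ids]


local_block_ids = [10, 11]           # two blocks still missing on Decode
remote_block_ids = [3, 7, 12, 15]    # Prefill computed four blocks

remote_tail = align_remote_blocks(local_block_ids, remote_block_ids)
local_ids = get_block_descs_ids(num_regions=4, num_blocks=32, block_ids=local_block_ids)
remote_ids = get_block_descs_ids(num_regions=4, num_blocks=64, block_ids=remote_tail)

print(remote_tail)  # [12, 15]
print(local_ids)    # [10, 11, 42, 43, 74, 75, 106, 107]
print(remote_ids)   # [12, 15, 76, 79, 140, 143, 204, 207]
```

Both index lists have the same length and the same region-major order, so element k on the local side pairs with element k on the remote side. Note that in Python a slice with -0 returns the whole list, which is one reason the zero-block case has to return early before the alignment step.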
 
Note: notif_msg is NIXL's built-in RDMA completion notification mechanism.
      Once the RDMA data has landed, nixl_wrapper.get_new_notifs() on the Prefill side receives
      the string "{remote_request_id}:{tp_ratio}"
 
[During forward] wait_for_layer_load() / save_kv_layer()
────────────────────────────────────────────────────────────────────────────────────────────────
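Decode side: pass (NIXL does no per-layer work; the full transfer was already issued asynchronously in start_load_kv)
Prefill side: pass (the KV cache was registered as RDMA-visible in register_kv_caches, so no per-layer save is needed)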
 
[Prefill: wait_for_save() in host_buffer mode]
────────────────────────────────────────────────────────────────────────────────────────────────
                                                     if use_host_buffer and copy_blocks:
                                                       save_kv_to_host(metadata)
                                                       └─ for req in reqs_to_save:
                                                            copy_blocks(device_kv → host_xfer_buffers,
                                                              block_ids, "d2h")  ← blocking D2H copy
                                                     (devices such as TPU that cannot register memory with NIXL directly take this path)
 
[Completion polling] get_finished(): called every forward step
────────────────────────────────────────────────────────────────────────────────────────────────
Decode Worker:                                       Prefill Worker:
 
_pop_done_transfers(_recving_transfers):             _get_new_notifs():
  for req_id, handles in _recving_transfers:           for notifs in nixl_wrapper.get_new_notifs():
    for handle in handles:                               req_id, tp_ratio = notif.split(":")
      state = nixl_wrapper.check_xfer_state(handle)     consumer_notification_counts[req_id] += 1
      if state == "DONE":                                if count == tp_ratio:   ← every D worker has finished pulling
        nixl_wrapper.get_xfer_telemetry(handle)            notified_req_ids.add(req_id)
        nixl_wrapper.release_xfer_handle(handle)           del _reqs_to_send[req_id]
        done_recving.add(req_id)                           └─ Prefill may now free these KV blocks
      elif state == "PROC":
        in_progress = True
      else:  ← failure
        _handle_failed_transfer(req_id, handle)
 
After done_recving is returned:
  for req_id in done_recving:
    meta = _recving_metadata.pop(req_id)
    if use_host_buffer:
      sync_recved_kv_to_device(req_id, meta)  ← H2D copy
    if enable_permute_local_kv:
      permute_device_kv(meta.local_physical_block_ids)
    if block_size_ratio > 1:
      blocksize_post_process(...)  ← rearrangement for heterogeneous block sizes
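
The Prefill-side bookkeeping in the right-hand column amounts to counting notifications per request until tp_ratio of them have arrived. A minimal sketch, assuming notifications arrive as UTF-8 bytes of the form "req_id:tp_ratio"; SendTracker is a hypothetical name, and rsplit is used here so that a request id containing a colon would still parse.

```python
from collections import defaultdict
from typing import Dict, Iterable, Set


class SendTracker:
    """Counts per-request notifications until all consumers have pulled."""

    def __init__(self, reqs_to_send: Dict[str, float]) -> None:
        self.reqs_to_send = dict(reqs_to_send)  # req_id -> expiry timestamp
        self.counts: Dict[str, int] = defaultdict(int)

    def on_notifs(self, notifs: Iterable[bytes]) -> Set[str]:
        """Return the request ids whose blocks may now be freed."""
        done: Set[str] = set()
        for notif in notifs:
            req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
            self.counts[req_id] += 1
            if self.counts[req_id] == int(tp_ratio):
                done.add(req_id)
                del self.counts[req_id]
                self.reqs_to_send.pop(req_id, None)
        return done


tracker = SendTracker({"req-a": 100.0})
after_first = tracker.on_notifs([b"req-a:2"])   # one of two Decode workers is done
after_second = tracker.on_notifs([b"req-a:2"])  # second worker is done
print(after_first, after_second, tracker.reqs_to_send)
```

Only the final notification releases the request, so blocks stay pinned until every Decode worker that reads from this Prefill worker has finished.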
 
[Timeout fallback] end of get_finished()
────────────────────────────────────────────────────────────────────────────────────────────────
                                                     Prefill Worker:
                                                     while _reqs_to_send:
                                                       req_id, expires = oldest entry
                                                       if now < expires: break
                                                       → past VLLM_NIXL_ABORT_REQUEST_TIMEOUT:
                                                       force-release, done_sending.add(req_id)
                                                       (with a warning log)
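
The expiry sweep can be sketched as a loop over an insertion-ordered dict. This assumes every request gets the same timeout, so insertion order equals expiry order and the first non-expired entry ends the scan; expire_sends is a hypothetical helper name.

```python
from typing import Dict, Set


def expire_sends(reqs_to_send: Dict[str, float], now: float) -> Set[str]:
    """Force-release requests whose deadline has passed; oldest entries first."""
    done_sending: Set[str] = set()
    while reqs_to_send:
        req_id, expires = next(iter(reqs_to_send.items()))  # oldest entry
        if now < expires:
            break  # everything behind it expires even later
        del reqs_to_send[req_id]
        done_sending.add(req_id)
    return done_sending


reqs = {"a": 10.0, "b": 20.0, "c": 30.0}
expired = expire_sends(reqs, now=25.0)
print(sorted(expired), list(reqs))  # ['a', 'b'] ['c']
```

The early break keeps the sweep cheap on the common path where nothing has expired, since only the oldest entry is inspected.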
 
[Summary of key data structures]
────────────────────────────────────────────────────────────────────────────────────────────────
desc layout of the src/dst xfer_side_handle (FA as the example, num_layers=L, num_blocks=N):
 
  dlist index:   0          1     ...  N-1    N         N+1   ... 2N-1   2N   ...  2L*N-1
                [layer0-K                   ] [layer0-V                ] [layer1-K  ...  ]
                 block0  block1  ... blockN-1  block0  block1 ... blockN-1  ...
 
  _get_block_descs_ids(engine, [b3, b7, b12]):
    region_ids = [0, 1, 2, 3, ..., 2L-1]
    descs = [0*N+3, 0*N+7, 0*N+12,
             1*N+3, 1*N+7, 1*N+12,
             2*N+3, ...]
    → length = 2L × 3; a single NIXL READ bundles these 2L×3 (addr, len) pairs into one transfer
    → equivalent to "L layers × 3 blocks × K+V"; all layers complete in one RDMA
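
Going the other way, a flat dlist index can be decoded back into (layer, K or V, block), which is handy when debugging a transfer. A minimal sketch, assuming the FA layout drawn above (regions ordered layer0-K, layer0-V, layer1-K, ...); decode_desc_index is a hypothetical helper.

```python
from typing import Tuple


def decode_desc_index(idx: int, num_blocks: int) -> Tuple[int, str, int]:
    """Map a flat dlist index to (layer, 'K' or 'V', block_id) for the FA layout."""
    region, block_id = divmod(idx, num_blocks)
    layer, kv = divmod(region, 2)  # even regions hold K, odd regions hold V
    return layer, "KV"[kv], block_id


N = 16  # hypothetical num_blocks
print(decode_desc_index(0 * N + 3, N))   # (0, 'K', 3)
print(decode_desc_index(1 * N + 7, N))   # (0, 'V', 7)
print(decode_desc_index(2 * N + 12, N))  # (1, 'K', 12)
```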