1. Conclusion

2. Specifics

KV cache transfer, from locating libnccl.so all the way to calling its interfaces, is layered into 4 levels: 1. find the communication library .so; 2. wrap it in a Python layer; 3. connection setup + transport component; 4. the KV cache connector.

2.1 Layer 1: finding the NCCL .so

  • If the VLLM_NCCL_SO_PATH environment variable is set, vLLM loads the library we point it at.
  • Otherwise, it picks the communication library matching the current backend, CUDA or HIP (ROCm).
find nccl so
def find_nccl_library() -> str:
    """Return NCCL/RCCL shared library name to load.
 
    Uses `VLLM_NCCL_SO_PATH` if set; otherwise chooses by torch backend.
    """
    so_file = envs.VLLM_NCCL_SO_PATH
    if so_file:
        logger.info(
            "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file
        )
    else:
        if torch.version.cuda is not None:
            so_file = "libnccl.so.2"
        elif torch.version.hip is not None:
            so_file = "librccl.so.1"
        else:
            raise ValueError("NCCL only supports CUDA and ROCm backends.")
        logger.debug_once("Found nccl from library %s", so_file)
    return so_file
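
As a quick sanity check of the override path, a sketch (the import location of find_nccl_library can vary across vLLM versions, and the .so path here is made up):

import os

# Set before vLLM reads its environment so the override is picked up.
os.environ["VLLM_NCCL_SO_PATH"] = "/opt/nccl/lib/libnccl.so.2"  # hypothetical path

from vllm.utils import find_nccl_library  # import path may differ by version

print(find_nccl_library())  # -> /opt/nccl/lib/libnccl.so.2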

2.2 Layer 2: Python wrapper around the NCCL C API (pynccl_wrapper.py)

The nccl member of the PyNcclCommunicator class is created via self.nccl = NCCLLibrary(library_path), which instantiates NCCLLibrary.

The NCCLLibrary implementation does two things:

a) It opens the NCCL .so found in layer 1 with ctypes.CDLL.

Loading the shared library
def __init__(self, so_file: str | None = None):
    so_file = so_file or find_nccl_library()

    try:
        # Process-wide cache: each .so is dlopen'ed at most once.
        if so_file not in NCCLLibrary.path_to_library_cache:
            lib = ctypes.CDLL(so_file)
            NCCLLibrary.path_to_library_cache[so_file] = lib
        self.lib = NCCLLibrary.path_to_library_cache[so_file]
    except Exception as e:
        logger.error(
            "Failed to load NCCL library from %s. "
            # ... error message ...
        )
        raise e
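
Note the two class-level caches: path_to_library_cache keeps the dlopen handle, and path_to_dict_mapping (used below) keeps the bound function table. However many PyNcclCommunicator instances a process creates, each .so is loaded and its signatures are bound exactly once.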

b) A signature table declares, for each C API exported by the .so, its return type and the types of the arguments passed in; ctypes then marshals values down into C and back up into Python.

Binding the signatures of all exported functions
if so_file not in NCCLLibrary.path_to_dict_mapping:
    _funcs: dict[str, Any] = {}
    for func in NCCLLibrary.exported_functions:
        try:
            f = getattr(self.lib, func.name)
            f.restype = func.restype
            f.argtypes = func.argtypes
            _funcs[func.name] = f
        except AttributeError:
            # ... (elided: the symbol is missing from this NCCL build)
            pass
    NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
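
To make the signature-table idea concrete, here is a minimal standalone version reduced to one real NCCL entry point, ncclGetVersion; the Function record and the table below are illustrative, not vLLM's actual definitions:

import ctypes
from dataclasses import dataclass
from typing import Any

ncclResult_t = ctypes.c_int  # ncclResult_t is a C enum, i.e. an int

@dataclass
class Function:
    name: str
    restype: Any
    argtypes: list[Any]

# One row per C function: what it returns and what it takes, so ctypes
# can marshal arguments down into C and the result back up into Python.
exported_functions = [
    # ncclResult_t ncclGetVersion(int* version)
    Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),
]

lib = ctypes.CDLL("libnccl.so.2")
funcs = {}
for func in exported_functions:
    f = getattr(lib, func.name)
    f.restype = func.restype
    f.argtypes = func.argtypes
    funcs[func.name] = f

version = ctypes.c_int()
funcs["ncclGetVersion"](ctypes.byref(version))
print(version.value)  # e.g. 22205 -> NCCL 2.22.5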

2.3 Layer 3: the P2P NCCL engine

1) Initialization

  1. Instantiate NCCLLibrary → triggers the layer-1 and layer-2 .so loading and binding
  2. Start a ZMQ ROUTER socket (used for out-of-band signaling, not for data transfer; see the sketch after the init code below)
  3. Create CUDA streams for send/recv (self.send_stream, self.recv_stream)
  4. Start the background thread listen_for_requests (handles connection/transfer requests from remote peers)
Initializing the P2P NCCL engine
class P2pNcclEngine:
    def __init__(
        self,
        local_rank: int,
        config: KVTransferConfig,
        hostname: str = "",
        port_offset: int = 0,
        library_path: str | None = None,
    ) -> None:
        # ...
        self.nccl = NCCLLibrary(library_path)
        self.context = zmq.Context()
        self.router_socket = self.context.socket(zmq.ROUTER)
        self.router_socket.bind(f"tcp://{self.zmq_address}")
        self.poller = zmq.Poller()
        self.poller.register(self.router_socket, zmq.POLLIN)
        self.send_stream = torch.cuda.Stream()
        self.recv_stream = torch.cuda.Stream()
        self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True)
        # ...
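
For intuition, here is a minimal standalone version of that signaling loop (hypothetical endpoint and message schema, heavily simplified relative to the engine's real protocol):

import msgpack
import zmq

ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)
router.bind("tcp://127.0.0.1:5555")  # hypothetical address

poller = zmq.Poller()
poller.register(router, zmq.POLLIN)

while True:
    if dict(poller.poll(timeout=1000)).get(router):
        # A ROUTER socket prepends the peer's identity frame, so the reply
        # can be routed back to exactly the peer that sent the request.
        identity, payload = router.recv_multipart()
        msg = msgpack.loads(payload)
        if msg["cmd"] == "NEW":
            ...  # exchange the NCCL uniqueId and init the communicator here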

2) Establishing the NCCL communicator. The two vLLM instances exchange the NCCL uniqueId over ZMQ, then each side calls ncclCommInitRank to build a point-to-point comm.

Initiator: generates a uniqueId, sends it to the remote end over the ZMQ socket, then calls ncclCommInitRank itself with rank=0, world_size=2 to create the communicator.

sender
    def create_connect(self, remote_address: str | None = None):
        # ...
            unique_id = self.nccl.ncclGetUniqueId()
            data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
            sock.send(msgpack.dumps(data))
 
            with torch.accelerator.device_index(self.device.index):
                rank = 0
                with set_p2p_nccl_context(self.nccl_num_channels):
                    comm: ncclComm_t = self.nccl.ncclCommInitRank(2, unique_id, rank)
                self.comms[remote_address] = (comm, rank)

Receiver: reads the 128-byte uniqueId payload from ZMQ, reconstructs it into an ncclUniqueId struct, then calls ncclCommInitRank with rank=1 to create its end of the communicator.

receiver
    def listen_for_requests(self):
        while True:
            # ...
            if data["cmd"] == "NEW":
                unique_id = self.nccl.unique_id_from_bytes(bytes(data["unique_id"]))
                with torch.accelerator.device_index(self.device.index):
                    rank = 1
                    with set_p2p_nccl_context(self.nccl_num_channels):
                        comm: ncclComm_t = self.nccl.ncclCommInitRank(
                            2, unique_id, rank
                        )
                    self.comms[remote_address.decode()] = (comm, rank)
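
The uniqueId round-trip is just raw bytes in a fixed-size struct. A sketch of the serialization helpers (the names mirror the unique_id_from_bytes call above, but the struct definition here is illustrative):

import ctypes

class NcclUniqueId(ctypes.Structure):
    # NCCL defines ncclUniqueId as an opaque 128-byte blob.
    _fields_ = [("internal", ctypes.c_byte * 128)]

def unique_id_to_bytes(uid: NcclUniqueId) -> bytes:
    return bytes(uid.internal)  # what the sender packs into msgpack

def unique_id_from_bytes(data: bytes) -> NcclUniqueId:
    assert len(data) == 128, "ncclUniqueId must be exactly 128 bytes"
    uid = NcclUniqueId()
    ctypes.memmove(uid.internal, data, 128)
    return uid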

3) The actual send/recv transfer

Flow
Prefill (send_sync)                        Decode (listen_for_requests)
─────────────────                          ──────────────────────────

1. save_kv_layer() iterates over every
   request and every attention layer,
   extracting the KV:
   kv_cache = kv_layer[block_ids, ...]

2. send_tensor() → enqueue to send_queue
   (asynchronous in PUT_ASYNC mode)

3. the send_async thread pops the item
   → send_sync():

4. ZMQ sends the PUT metadata
   {cmd:PUT, tensor_id, shape, dtype}  ──→  5. PUT command received

                                            6. torch.empty(shape, device=GPU)  ← allocates a temporary tensor!

                                            7. replies b"0" ──→  8. ACK received

9. NCCL send(tensor, dst=1)           ←→   10. NCCL recv(tensor, src=0)
   (two-sided; both ends synchronize)          recv lands directly in the temp tensor from step 6

                                            11. check the threshold, possibly spill to CPU
                                            12. store into recv_store[tensor_id]

                                            ... later, inside start_load_kv() ...

                                            13. tensor = recv_store.pop(tensor_id)
                                            14. kv_cache_layer[block_ids] = tensor  ← copied into the real KV cache

The tricky part is the receive side: the very first step, before the recv, is a torch.empty allocation of a temporary tensor, and only then does it ACK the prefill side to say the sender may transmit. Every error I hit earlier traced back to here.
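
To see why this bites, it helps to put numbers on the temporary allocation. The dimensions below are hypothetical, just to show the per-layer arithmetic (multiply by layer count and in-flight requests for the real memory pressure):

import torch

# Hypothetical FlashAttention-style layout:
# (2, num_blocks, block_size, num_kv_heads, head_dim)
num_blocks, block_size = 128, 16
num_kv_heads, head_dim = 8, 128
dtype = torch.float16

elems = 2 * num_blocks * block_size * num_kv_heads * head_dim
size_mib = elems * torch.finfo(dtype).bits / 8 / 2**20
print(f"{size_mib:.1f} MiB per layer")  # 8.0 MiB with these numbers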

The listen thread on the D (decode) side
def listen_for_requests(self):
    # ...
    elif data["cmd"] == "PUT":
        tensor_id = data["tensor_id"]
        try:
            with torch.cuda.stream(self.recv_stream):
                tensor = torch.empty(
                    data["shape"],
                    dtype=getattr(torch, data["dtype"]),
                    device=self.device,
                )
            # Send the ACK only after the temporary tensor is allocated.
            self.router_socket.send_multipart([remote_address, b"0"])
            # ...
            if self.buffer_size + tensor_size > self.buffer_size_threshold:
                # Store the tensor in the memory pool (even trickier: this
                # copies it into CPU pinned memory).
                addr = self.pool.store_tensor(tensor)
                tensor = (addr, tensor.dtype, tensor.shape)
                logger.warning("🔴[PUT]Recv Tensor, Out Of Threshold, ...")
            else:
                self.buffer_size += tensor_size

        except torch.cuda.OutOfMemoryError:
            self.router_socket.send_multipart([remote_address, b"1"])
            tensor = None
            logger.warning("🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, data:%s", ...)

        with self.recv_store_cv:
            self.recv_store[tensor_id] = tensor
            self.have_received_tensor_id(tensor_id)
            self.recv_store_cv.notify()
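
What the spill branch amounts to, conceptually (a sketch, not the pool's actual implementation; the real store_tensor also manages addresses inside a preallocated pool):

import torch

def spill_to_pinned(t: torch.Tensor) -> torch.Tensor:
    # D2H copy into page-locked host memory: frees GPU memory, but costs
    # a PCIe transfer now and an H2D copy later when the KV is consumed.
    host = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
    host.copy_(t, non_blocking=True)
    torch.cuda.current_stream().synchronize()  # ensure the copy completed
    return host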

2.4 Layer 4: the KV cache connector

This layer bridges to vLLM's upper-level logic: it extracts and injects KV cache at the attention layers, and is split into a producer and a consumer.

1) Producer side on the prefill instance: save_kv_layer()

For each attention layer's KV cache, it uses block_ids to slice the corresponding KV tensor out of the paged KV buffer, then calls the engine's send_tensor with request_id#layer_name as the tensor_id.

extract_kv_from_layer()
def extract_kv_from_layer(
    layer: torch.Tensor,
    block_ids: torch.Tensor,
) -> torch.Tensor | None:
    if layer.ndim == 3 or layer.shape[1] == 2:  # MLA or FlashInfer
        return layer[block_ids, ...]
    if layer.shape[0] == 2:  # FlashAttention
        return layer[:, block_ids, ...]
    return None

# ...
for request in connector_metadata.requests:
    # ...
    kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
    # ...
    self.p2p_nccl_engine.send_tensor(
        request_id + "#" + layer_name, kv_cache, remote_address
    )

2) Consumer side (the decode instance): start_load_kv()

For every layer of every request, it calls the engine's recv_tensor to fetch the KV tensor, then inject_kv_into_layer writes it into the local paged KV buffer by block_ids (a sketch of that helper follows the snippet below).

start_load_kv()
for request in metadata.requests:
    # ...
    for layer_name in forward_context.no_compile_layers:
        # ...
        kv_cache = self.p2p_nccl_engine.recv_tensor(
            request.request_id + "#" + layer_name, remote_address
        )
        # ...
        inject_kv_into_layer(
            layer, kv_cache, request.block_ids, request.request_id
        )
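
inject_kv_into_layer is the mirror image of extract_kv_from_layer. A sketch inferred from the extraction logic above (the connector's real implementation may differ in details):

import torch

def inject_kv_into_layer(
    dst_layer: torch.Tensor,
    src_tensor: torch.Tensor,
    block_ids: torch.Tensor,
    request_id: str,
) -> None:
    # Scatter the received KV back into the paged buffer at the same
    # block positions the producer extracted it from.
    if dst_layer.ndim == 3 or dst_layer.shape[1] == 2:  # MLA or FlashInfer
        dst_layer[block_ids, ...] = src_tensor
    else:  # FlashAttention layout (2, num_blocks, ...)
        dst_layer[:, block_ids, ...] = src_tensor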