0. introduction

在 vllm 传输 kv cache 的过程中,主要包括 mooncake,nixl,nccl…。其中 nccl connector 的实现过于丑陋,实际推理服务无法稳定/高性能使用。原因参考:0. vllm 如何使用 nccl 传输 kv cache 1. Bugfix for vllm deepseek v3.2 1p1d 2. vllm use vllm-plugin-fl、flagGemms and flagcx run Deepseek v3.2

原本的双边需要 receiver 提前 empty 一块内存来让 sender 有 dst_addr,在高并发/系统过载的推理服务场景,这会进一步和 kv cache 抢占 VRAM/DRAM。因此需要一个类似 mooncake/nixl transfer engine 的单边调用 ibv_post_send 来完成 kv cache 传输,即:nccl/flagcx 内就是 putSignal 的实现。

1. design

主要分为 3 部分:

  • 在 tp group 内给所有 P 到 D 的 rank来创建 FlagcxHeteroComm,在增加 P/D 的动态扩缩容的情况下就不用全局所有节点重启。
  • 给 request 重排成 task/ 切slice,类似 mooncake。
  • 在 pair 内用 flagcxOneSideRegister 完成基础的 rkey,MR,gid 的注册,在实际传输的时候由 receiver 去用写 ZMQ 告诉 sender 们我需要哪些 kvcache block(顺便把 dst_addr+rkey+gid带上),在 sender 们拿到后直接下单边的 putSignal,receiver 根据需要接收的 slice 数量去 waitValue。

1.1 已知问题

a. 解决不同comm无法注册的问题

首先当前的注册存在一个问题就是:prefill 的 rank0 会跟 comm1 的 rank0 注册,也跟 comm2 的 rank1 注册。然后因为 flagcx 内注册发现是自己(commrank)直接返回了就不注册了,rank0 的注册完的 key/Mr/gid 啥的就都不能给 deocde 的 rank1 用。 06a6d8cf227e913de139e9c1f81b4cff.png 于是,在 flagcx 内增加 pr 修复,来支持 asymmetric 的 PD disaggregation TP 配置。0. flagcxOneSideRegister 支持 comm 隔离

b. initRank + OneSideRegister

给不同 tp/pp/ep下的 kv cache的 sender/receiver 建立 comm 以便后续能够注册, 只开 tp 为 common case,这里不讨论,ep 也不会影响 attention 计算后的 kv。目前已知的问题主要是开 pp 时mooncake/nixl存在端口冲突问题。

mooncake 开 PP2 TP2 DP2 则会出现以下情况:

简单来说,mooncake connector 的每个 gpu 使用 ip:port 是通过 base_port(这里假设 9998)+ tp rank 得到。当开了 pp 之后,这里存在不同 pp stage 的 tp0 使用相同的 port。

nixl 开 pp2 tp2 dp2 的时候情况如下: nixl 的 zmq 监听端口计算方式:listen_port = base + vllm_config.parallel_config.data_parallel_rank

Workertp_rank值KV cache 内容
PP stage 0, TP rank 00第 0~N/2 层
PP stage 0, TP rank 11第 0~N/2 层
PP stage 1, TP rank 00 (冲突!)第 N/2~N 层
PP stage 1, TP rank 11 (冲突!)第 N/2~N 层
  • tp_rank 在不同 PP stage 之间是重复的(都是 0 和 1),最终 vllm/v1/engine/core.py 的 content.update() 让后一个 PP stage 的 metadata 把前一个覆盖掉

nixl:https://github.com/vllm-project/vllm/issues/30501 mooncake:https://github.com/vllm-project/vllm-ascend/issues/4244

comm 创建核心待解决问题就是:

当出现 PP 切分了 layer 之后,不同 pp stage 内相同 tp_rank index 需要有隔离的能够传递 metadata 的通道来通知 sender/receiver各自下一个 flagcxCommInitRank

1.2 最终方案

综上,我们的 initRank 和注册逻辑设计为:

init flagcx comm logic
Decode 侧                                            Prefill 侧
──────────────────────────────────────────────────   ──────────────────────────────────────────────
 
[启动阶段] register_kv_caches()
_decode_listener_thread 起来                          _sender_thread 起来
  ROUTER bind: host:port+tp_rank                      ROUTER(frontend) bind: host:port+tp_rank
  等待 "NEW" 握手                                     PULL (backend inproc) 等 worker 回执
                                                     _sender_executor 线程池就绪
 
 
[请求到达] start_load_kv(metadata)
按 (remote_host, remote_port+tp_rank) 分组
每个 path 起一个 _receive_kv 协程
挂 PendingSignalWait 到 _active_signal_waits
 
 
_receive_kv(path, req_blocks, pending_wait):
1. 构造 FlagCXAgentMetadata
   (req_ids, block_ids, base_addr, hostname, port)
2. 创建 REQ socket  (connect → path)
   setsockopt RCVTIMEO = 120s
3. await sock.send(encoded_data)          ───────→  frontend(ROUTER) recv_multipart
                                                    ├─ identity, _, metadata_bytes
                                                    └─ _sender_executor.submit(_sender_worker)
 
                                                    _sender_worker(identity, metadata_bytes, backend_path)
                                                    ├─ decode metadata
                                                    ├─ decode_listener_addr = host:port+tp_rank
                                                    ├─ 若 pair_comms[decode_listener_addr] 不存在:
                                                    │   _create_pair_comm(decode_listener_addr)  ← Prefill 主动发起
                                                    │   ├─ uid = flagcxGetUniqueId()
                                                    │   ├─ DEALER connect → Decode ROUTER
                                                    │   ├─ send({cmd:"NEW", uid:uid_bytes})      ──┐
                                                    │   ├─ flagcxCommInitRank(2, uid, rank=0) ←┐  │
                                                    │   ├─ _register_kv_for_comm(comm)         │  │
                                                    │   │   ├─ flagcxOneSideRegister KV MRs    │  │
                                                    │   │   └─ flagcxOneSideSignalRegister     │  │
                                                    │   ├─ sock.recv()  ← 等 Decode "OK"       │  │
                                                    │   └─ pair_comms[addr]=PairCommInfo       │  │
                                                    │        (comm, my_rank=0, signal_buf,     │  │
                                                    │         signal_counter=0, send_lock)     │  │
                                                    │                                          │  │
                    _decode_listener_thread:  ←─────┼──────────────────────────────────────────┼──┘
                    recv (identity, "NEW", uid) ────┤
                    ├─ flagcxCommInitRank(2, uid,rank=1) ← rendezvous ─────────────────────────┘
                    ├─ _register_kv_for_comm(comm)
                    │   ├─ flagcxOneSideRegister KV MRs
                    │   └─ flagcxOneSideSignalRegister
                    ├─ pair_comms[remote_addr]=PairCommInfo(comm, my_rank=1,...)
                    └─ router.send_multipart([identity, "OK"])  ── 回给 Prefill DEALER
 
                                                    ├─ expected_signal = _send_kv_to_decode(meta)
                                                    │   ├─ 轮询 reqs_need_send 最多 30s
                                                    │   │   等 start_load_kv 把 local_block_ids 填好
                                                    │   │   + send_meta.ready.set()
                                                    │   └─ _send_blocks(send_reqs, meta)
                                                    │       ├─ 摊平为 xfer_list
                                                    │       │   (per layer × per contiguous group × per req)
                                                    │       ├─ with pair_info.send_lock:
                                                    │       │   pair_info.signal_counter += 1
                                                    │       │   expected_signal = signal_counter
                                                    │       │   for i, xfer in enumerate(xfer_list):
                                                    │       │     sig_val = 1 if i==last else 0
                                                    │       │     flagcxPutSignal(..., sig_val)
                                                    │       │     └─ 远端 signal_buffer 原子 +sig_val
                                                    │       └─ return expected_signal
                                                    ├─ finished_sending_reqs += req_ids
                                                    └─ reply = str(expected_signal).encode()


                                                     PUSH(inproc) → backend(PULL)
                                                     frontend.send_multipart([identity, "", reply])

4. reply = await sock.recv()        ←──────────────────┘
   (payload = b"123" 这样的整数;出错回 b"ERR")
 
5. expected_signal = int(reply)
6. pair_info = pair_comms 里唯一那条(由 listener 握手阶段建好)
   pending_wait.comm        = pair_info.comm
   pending_wait.peer_rank   = 1 - pair_info.my_rank
   pending_wait.signal_value = expected_signal
7. pending_wait.ready.set()
8. finished_recving_reqs.update(req_ids)
 
 
[forward 推进] wait_for_layer_load()
  waits = _active_signal_waits.pop_all()
  for w in waits: w.ready.wait(60)                 ← 等所有 ZMQ REP 都回来
  valid = [w for w in waits if w.comm and w.signal_value>0]
  max_wait = max(valid, key=signal_value)          ← counter 单调,等最大值即覆盖其余
  flagcx_stream = adaptor_stream_copy(current_stream())
  flagcxWaitSignal(max_wait.comm, max_wait.peer_rank,
                   sigOff=0, max_wait.signal_value, flagcx_stream)
  └─ GPU 侧 poll signal_buffer ≥ expected,RDMA 落地后继续 forward
 

2 bugfix

  • debug flagcx connector ✅ 2026-04-16
  • 【bug1】看到flagcxWaitSignal 报错raise RuntimeError(f”FLAGCX error: {error_str}”)

    • 确定了 error 码是 1,就是Unhandled device error ⬇️
    • 先增加了 comm 初始化 并发的隔离,未解决,但是代码保留 ⬇️
    • 然后怀疑是 context 用了其他的 cudaDevice 导致的,加了setCurrentDeive之后导致会出现 hang,不会直接报错。问题不在这,因为打印发现 worker 线程都是这样用的,并且和 mooncake 使用的方式是对齐的,问题不在这。❌
    • sanitizer 排查内存问题,
      • sanitizer 跨机给出的 error 分别有:cudaErrorNoKernelImageForDevice (error 209) on cudaGetLastError, CUDA_ERROR_NOT_PERMITTED (error 800) on cuMemCreate,CUDA_ERROR_NOT_SUPPORTED (error 801) on cuMemGetHandleForAddressRange. 三类错误分别是没指定sm90,VMM 不知道为什么分配被拒,不支持ncclCommWindowRegister。❌
      • 发现flagcxOneSideBuildFullMesh内存在下面问题,需要结合 flagcx.cc内flagcxOneSideBuildFullMesh的 for 循环来看:
        时间轴 ──────────────────────────────────────────────────────►
      
        rank0:  [i=0: self↔self] ──完成──► [i=1: while循环开始]
                                             └─ connect(rank1.listen)   TCP 进入 rank1 的 accept queue
                                                accept(rank0.listen)    等待 rank1 来连
      
        rank1:  [i=0: while循环]
                 connect(rank1.self)   TCP 进入 rank1 自己的 accept queue
                 accept(rank1.listen)  ← rank1的accept queue里现在有2个连接:
                                            [A] rank1 自己(self-connect, i=0预期)
                                            [B] rank0 发来的 (i=1来的, 不该这轮消费)
      
        OS 的 accept queue 是 FIFO,但 [A] 和 [B] 谁先到是竞态。如果 rank0 的 TCP connect 先到:
      
        rank1 的 accept() 拿到了 [B] (rank0 的 i=1 连接)
          → recvComm 设置为来自 rank0 的连接 ✓ (recvComm != NULL)
          → while 条件: sendComm==NULL || recvComm==NULL
                       →  recvComm 非 NULL,accept() 不再被调用
      
        rank1 的 connect(self) 还卡在 StateSend/StateConnecting
          → 需要有人 accept() 自己发过去的 QP info
          → 但 while 循环里 recvComm 已非 NULL,不会再调 accept()
          → sendComm 永远是 NULL
          → while(sendComm==NULL || recvComm==NULL) 永远成立
          → rank1 无限循环,sendComm 卡死 ← HANG
      
      问题是,这里在和 mc 讨论中发现 mpich 启动的测试从来不会出现问题,我用 openmpi 会出现问题。就算这里加上 barrier,在 pd 分离的时候依旧会出现 decode 出一样的报错。❌
    • 两边 signal 计数会乱,改成发端统计好有多少signal,通知收端,收端直接只下一次 wait 操作。 ⬇️
    • 在flagcxHeteroWaitSignal内加了D2H的拷贝,把当前收端 current signal 打印出来和 sender 发过来的signal 做对比, 观察到每次连续 runtest的时候第二次测试 recv 端需要 194 个,但是每次卡在 180个左右就再也等不到了,prefill 侧 log 看到 ibrc 打印FLAGCX WARN NET/IB : unable to allocate requests和FLAGCX WARN flagcxRmaProgressThread: op failed peer=1 type=1 res=3 当rma 的 proxyThread 执行的时候,flagcxRmaProgressThread 去调用 ibrc_adaptoe封装的flagcxIbIputSignal,这里当并发大的时候就会返回 flagcxInternalError,然后 flagcxRmaProgressThread 检查不是 flagcxSuccess 就会直接 free 这个 wr。。。。。
        flagcxRmaProgressThread    ← 从 pending 取出 desc
            └─ netAdaptor->iputSignal()   ← IB adaptor 层
                └─ flagcxIbGetRequest()   ← 从 reqs[256] 里找一个 UNUSED slot
                    ← 找到 → 设 type=IPUT, 返回指针存到 desc->request
                    ← 找不到 → "unable to allocate requests", 返回 flagcxInternalError
            └─ 成功后: desc 挂到 inProgress 链表
            └─ 后续循环: netAdaptor->test(desc->request)
                └─ flagcxIbTest → flagcxIbCommonTestDataQp
                    └─ ibv_poll_cq() 收割 CQE
                    └─ events 减到 0 → flagcxIbFreeRequest(r)  ← r->type = UNUSED, slot 回收
                └─ done=1 → rmaDescComplete(desc) → free(desc)
      
      所以问题就是如果256 个 slots 都不是空的时刻,新进来的请求就会被 flagcxIbGetRequest 函数内的for loop 挡在外面,返回的是flagcxInternalError 导致后续的代码直接 free 掉了当前的 desc,因此在这里尝试了不 free,不成功就把当前的 desc 重新enque 到 proxy 链表后,所有

    if (res != flagcxSuccess) { WARN(“flagcxRmaProgressThread: op failed peer=%d type=%d res=%d”, p, (int)desctype, (int)res); __atomic_store_n(&proxyrmaError, 1, __ATOMIC_RELEASE); free(desc); ```

  • 【bug2】tp=2 的 prefill / decode观察到,出现 hang 的时候每一边都另一个 host:ip 服务看不到

    Prefill:flagcx_connector.py:554: Pair comm ready (responder/rank=1) ↔ 10.8.2.169:8999
    Decode:flagcx_connector.py:575: Pair comm ready (initiator/rank=0) ↔ tcp://10.8.2.168:8999
    Prefill:
    (Worker_TP0 pid=167763) [2026-04-15 19:59:53] INFO flagcx_connector.py:510: Registered 96 KV MRs + per-pair signal buffer for pair comm=0x7fb1d4001160 (signal_ptr=0x7fc0315eb400, signal_device=cuda:0, current_device=0)
    (Worker_TP0 pid=167763) [2026-04-15 19:59:53] INFO flagcx_connector.py:554: Pair comm ready (responder/rank=1) ↔ 10.8.2.169:8998
    (Worker_TP1 pid=167764) [2026-04-15 19:59:53] INFO flagcx_connector.py:510: Registered 96 KV MRs + per-pair signal buffer for pair comm=0x7f8c88001160 (signal_ptr=0x7f8c60200000, signal_device=cuda:0, current_device=0)
    (Worker_TP1 pid=167764) [2026-04-15 19:59:53] INFO flagcx_connector.py:554: Pair comm ready (responder/rank=1) ↔ 10.8.2.169:8999
     
    Decode:
    INFO flagcx_connector.py:510: Registered 96 KV MRs + per-pair signal buffer for pair comm=0x7fb4e4000ba0 (signal_ptr=0x7fc3515eb000, signal_device=cuda:0, current_device=0)
    (Worker_TP0 pid=54913) [2026-04-15 19:59:53] INFO flagcx_connector.py:575: Pair comm ready (initiator/rank=0) ↔ tcp://10.8.2.168:8998
    (Worker_TP1 pid=54914) [2026-04-15 19:59:53] INFO flagcx_connector.py:510: Registered 96 KV MRs + per-pair signal buffer for pair comm=0x7f48f8000ba0 (signal_ptr=0x7f48d6200000, signal_device=cuda:0, current_device=0)
    (Worker_TP1 pid=54914) [2026-04-15 19:59:53] INFO flagcx_connector.py:575: Pair comm ready (initiator/rank=0) ↔ tcp://10.8.2.168:8999

    这里的改动思想很简单,就是 sender 的 work 第一次进来去告诉receiver 我的uid,同时 decode 的 listen 线程提前开始等待这个,一起开始调用commInitRank和_register_kv_for_comm。解决这个 bug2 后在后续测试 bug1 的几十次 vllm serve 都没有出现开头就 hang 的问题了。✅

  • 【bug3】优化 flagcxRmaProgressThread性能,现在存在请求多大 size 时性能巨差

3 test

168/169机器需要让 gpu 内核加载这个 ib 的模块(如果没有的话):

lsmod | grep -i peer
modprobe nvidia_peermem

安装 flagos 的东西:

git clone https://github.com/flagos-ai/vllm-plugin-FL.git
pip install --no-build-isolation -e .
 
git clone https://github.com/flagos-ai/FlagGems.git
pip install --no-build-isolation -e .

剩下的参考:vllm-plugin-FL/examples/disaggregated_serving_xpyd/run_flagcx_connector.md