0 前言
1. nvshmemi_ibgda_put_nbi_warp
函数签名
template <bool kAlwaysDoPostSend = false>
__device__ static __forceinline__ void nvshmemi_ibgda_put_nbi_warp(
uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx);- req_rptr ← dst_ptr(上面计算的远端地址,即接收端缓冲区的绝对地址)。
- req_lptr ← src_ptr 或 rdma_x_src_idx / buf_ptr(源数据在本地内存/发送缓冲区的地址)。
- bytes ← num_bytes_per_msg(单条消息的字节长度,由 hidden、meta 等字段计算出来)。
- dst_pe ← dst_rank(远端 PE / rank)。
- qp_id ← dst_expert_local_idx(选择用于和远端通信的 RC QP 的 id,通常按目标本地专家索引选 QP,便于并发)。
- lane_id ← lane_id(warp 中的 lane,0..31)。内部把不同 lane 用来并行构造多条 WQE(把大消息分 chunk 并分配给 lanes)。
- message_idx ← slot_idx(或在 combine/send 场景中传 token_idx - offset 作为消息索引)。该参数被传给 ibgda_submit_requests 用于决定何时调用 doorbell(post send)策略 / 批次逻辑。
传输
首先就是准备QP(这里get RC的操作就是去拿QP,具体见 6. ibgda_get_rc),再会把数据分块,一次 RDMA 操作不能跨越不同的 MR。把总的bytes数按照 3. ibgda_get_lkey_and_rkey的规则来切分。
__device__ static __forceinline__ void nvshmemi_ibgda_put_nbi_warp(...) {
auto qp = ibgda_get_rc(dst_pe, qp_id);
auto remaining_bytes = bytes;
while (remaining_bytes > 0) {
if (lane_id == num_wqes) {
my_chunk_size = min(remaining_bytes,
ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, &my_rkey, qp->dev_idx));
}
// 广播给warp内所有lane
auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes));
remaining_bytes -= chunk_size;
req_lptr += chunk_size;
req_rptr += chunk_size;
++num_wqes;// WQE 数量 +1
}
}接着,开始构造每个WQE,见ibgda_write_rdma_write_wqe实现。
uint64_t base_wqe_idx = 0;
if (lane_id == 0)
base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes); // 预留连续的 WQE 槽位
base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0); // 广播给所有 lanes
if (lane_id < num_wqes) {
auto wqe_idx = base_wqe_idx + lane_id;
auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx); // 获取 WQE 的内存地址
ibgda_write_rdma_write_wqe(
qp,
my_laddr, my_lkey, // 本地地址 + lkey
my_raddr, my_rkey, // 远程地址 + rkey
my_chunk_size, // 传输字节数
wqe_idx, // WQE 索引
&wqe_ptr // WQE 内存地址
);
}构造完毕后,会用syncwarp,同步每个warp。每个warp的第一个线程把所有WQE提交给rdma,详细见ibgda_submit_requests
if (lane_id == 0)
ibgda_submit_requests<kAlwaysDoPostSend>(qp, base_wqe_idx, num_wqes, message_idx);总体流程
┌─────────────────────────────────────────────────────────────┐
│ nvshmemi_ibgda_put_nbi_warp(req_rptr, req_lptr, bytes, ...)│
└────────────────────┬────────────────────────────────────────┘
│
┌────────────▼────────────────────┐
│ 1. 计算分块(Chunking) │
│ - 遍历剩余字节 │
│ - 每个 lane 计算一个 chunk │
│ - 调用 ibgda_get_lkey_and_rkey │
└────────────┬────────────────────┘
│
┌────────────▼──────────────────────┐
│ ibgda_get_lkey_and_rkey │
│ - 查找本地 lkey(constmem.lkeys) │
│ - 查找远程 rkey(constmem.rkeys) │
│ - 计算物理地址 │
│ - 返回 chunk 大小 │
└────────────┬──────────────────────┘
│
┌────────────▼─────────────────┐
│ 2. 预留 WQE 槽位 │
│ - lane 0 调用 atomicAdd │
│ - 广播 base_wqe_idx │
└────────────┬─────────────────┘
│
┌────────────▼───────────────────┐
│ 3. 构造 WQE(并行) │
│ - 每个 lane < num_wqes 填充 │
│ - ibgda_write_rdma_write_wqe │
│ * Control Segment │
│ * Remote Address Segment │
│ * Data Segment │
└────────────┬───────────────────┘
│
┌────────────▼────────────────┐
│ 4. 提交 WQE │
│ - __threadfence() │
│ - atomicCAS ready_idx │
│ - ibgda_post_send │
│ * ibgda_update_dbr │
│ * ibgda_ring_db │
└────────────┬────────────────┘
│
┌────────────▼────────────────┐
│ 5. NIC 硬件执行 RDMA Write │
│ - 读取 WQE │
│ - DMA 从 laddr 读取数据 │
│ - 通过 IB 网络发送 │
│ - 远端 NIC 写入 raddr │
└────────────┬────────────────┘
│
┌────────────▼──────────────┐
│ 6. 完成通知(CQE) │
│ - NIC 写入 CQ │
│ - nvshmemi_ibgda_quiet │
│ 轮询 CQ 等待完成 │
└───────────────────────────┘2. nvshmemi_ibgda_amo_nonfetch_add
设备端通过IBGDA RC QP发起的 “不取值” 的原子加法(atomic operation no fetch add),常用于远程直接访问内存。
- 如果是本地的拷贝,说明rptr是本地的指针,直接在本地就执行加法操作。
- 远程操作的话,找到目标pe的QP,然后拿到需要的rkey和raddr。
- ibgda_reserve_wqe_slots会去拿qp→mvars→tx_wq.resv_head,这个就是发送队列(TX WQ)的“预留”,表示一个warp还没写完的WQE索引。你在就是要写新的WQE就需要去原子地把resv_head 向前推进 1 。并拿回返回值作为新的WQE起始索引。
- 写入AMO请求
ibgda_write_amo_add_wqe,将目标内存的 rptr 地址和加法的 value 传递给 RDMA 设备,触发远程加法操作。 - 最后,提交 wr给RDMA,实际执行加法操作。
__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(
void* rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
if (is_local_copy) {
atomicAdd(static_cast<unsigned long long*>(rptr), value);
} else {
nvshmemi_ibgda_device_qp_t* qp = ibgda_get_rc(pe, qp_id);
__be32 rkey;
uint64_t raddr;
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey, qp->dev_idx);
uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void* wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf), qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
}3. ibgda_get_lkey_and_rkey
输入本地和远程的VA,拿到lkey,rkey,物理地址和chunksize。
4. ibgda_write_rdma_write_wqe
填充对端的字段:
raddr_seg.raddr = HtoBE64(raddr); // 远程物理地址(大端序)
raddr_seg.rkey = rkey; // 远程 rkey填充本地的数据字段:
data_seg.byte_count = HtoBE32(bytes); // 传输字节数
data_seg.lkey = lkey; // 本地 lkey
data_seg.addr = HtoBE64(laddr); // 本地地址填充控制信号的字段:
ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3); // QP 编号 + DS(段数量)
ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; // 完成时更新 CQ
ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE); // 操作码:RDMA Write然后把上面三个结构体写入到WQE memory:
st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));5. ibgda_submit_requests
step1:内存屏障
__threadfence(); // 确保所有 WQE 写入对 NIC 可见step2:原子级更新ready_idx
unsigned long long int* ready_idx = state->use_async_postsend ?
qp->tx_wq.prod_idx : &mvars->tx_wq.ready_head;
while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx)
; // 等待之前的 WQE 都被填充完毕step3:通过CAS循环,可以让多个warp按照顺序提交。然后再去写入doorbell,NIC就会知道有新的WQE需要处理。(doorbell是特殊的内存映射寄存器,MMIO)NIC会从Send Queue取出WEQ执行RDMA操作。
if (!state->use_async_postsend) {
constexpr int kNumRequestInBatch = 4;
if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0)
ibgda_post_send(qp, new_wqe_idx);
}在ibgda_post_send内会去更新 DBREC,再写入 BlueFlame 寄存器。再之后NIC 硬件执行 RDMA Write(DMA 从 laddr 读取数据→过 IB 网络发送→远端NIC写入raddr),完成时NIC写入CQ,外面会用nvshmemi_ibgda_quiet来轮询CQ完成状态。
__device__ static __forceinline__ void ibgda_post_send(nvshmemi_ibgda_device_qp_t* qp, uint64_t new_prod_idx) {
nvshmemi_ibgda_device_qp_management_t* mvars = &qp->mvars;
uint64_t old_prod_idx;
ibgda_lock_acquire(&mvars->post_send_lock);
old_prod_idx = atomicMax(reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.prod_idx), new_prod_idx);
if (new_prod_idx > old_prod_idx) {
// 1. 更新 DBREC(Doorbell Record)
ibgda_update_dbr(qp, new_prod_idx);
// 2. Ring Doorbell(写入 BlueFlame 寄存器)
ibgda_ring_db(qp, new_prod_idx);
}
ibgda_lock_release(&mvars->post_send_lock);
}6. ibgda_get_rc
选择rc的主要因素是pe和qp_id,nvshmem内在 nvshmemi_ibgda_device_state_t 结构体内的 globalmem 结构体内有所有PE,所有RC,所有GPU的QP顺序排好的数组rcs。
__device__ static __forceinline__ nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) {
auto state = ibgda_get_state();
const auto num_rc_per_pe = ibgda_get_state()->num_rc_per_pe;
return &state->globalmem
.rcs[pe * num_rc_per_pe * state->num_devices_initialized + id % (num_rc_per_pe * state->num_devices_initialized)];
}7. ibgda_poll_cq
deepep搬运缩减了nvshmem的nvshmem的轮询cq的实现,逻辑是:消费者索引小于nic上的真实索引的条件下,一直轮训硬件的cq队列,确认该cq条目已经完成并准备好被消费。
- 先取出了一下三个重要变量:cq→cons_idx(消费者索引)、cq→cqe→wqe_counter(真实硬件完成队列)、cq→ncqes是cq队列条目总数。
idx - wqe_counter - 2 < ncqes这里通过uint16的边界65535(-1),来让当前条件变为idx - 1一定完成。外面调用ibgda_poll_cq的时候传的idx本来就是当前wqe的next index,所以idx在这里面通过这个条件就可以知道当前wqe确保已经完成。- memory_fence_cta() 确保所有的内存访问操作都是有序的
__device__ static __forceinline__ void ibgda_poll_cq(nvshmemi_ibgda_device_cq_t* cq, uint64_t idx) {
const auto cqe64 = static_cast<mlx5_cqe64*>(cq->cqe);
const uint32_t ncqes = cq->ncqes;
memory_fence_cta();
if (*cq->cons_idx >= idx)
return;
uint16_t wqe_counter;
do {
wqe_counter = HtoBE16(ld_na_relaxed(&cqe64->wqe_counter));
} while ((static_cast<uint16_t>(static_cast<uint16_t>(idx) - wqe_counter - static_cast<uint16_t>(2)) < ncqes));
*cq->cons_idx = idx;
memory_fence_cta();