ncclP2pSchedule
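函数开头的代码只是在计算后续需要的基础变量(group 大小、本 rank 在 group 内的下标 local、全局 group 编号 group,以及 group 到节点 / 起始 local rank 的映射表)。核心逻辑是两层循环:
- 外层循环按二次序列(三角数对 nGroupsPow2 取模,超出 nGroups 的值被过滤)枚举 groupDelta,由此确定 sendGroup 和 recvGroup;
- 内层循环遍历组内偏移 delta,确定具体的 sendRank 和 recvRank,每个组合写入 p2pSchedule 的一轮。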
static ncclResult_t ncclP2pSchedule(struct ncclComm* comm) {
struct ncclNodeRanks* nodeRanks = comm->nodeRanks; // 按节点存的rank信息
int groupSize = (comm->nNodes > 1) ? ncclParamGroupSize() : comm->maxLocalRanks;// 默认8个rank一个group
int local = comm->localRank % groupSize; // 当前rank在group内的id
int group = comm->localRank / groupSize; // 当前rank的group id
int nGroups = comm->nRanks / groupSize; // 总group数量
int nGroupsPow2 = pow2Up(nGroups); // 假如nGroups是7,那么nGroupsPow2就是8
// 假如前面ncclParamGroupSize()指定的是4个rank一组的话,那么下面的代码就会算出来
// groupToNode[x] : 第x个group在哪个节点
// groupToLocal[x] :第x个group在当前节点上组内的起始rank index
int *groupToNode, *groupToLocal;
for (int n = 0; n < comm->nNodes; ++n) {
int nGroupsInNode = comm->nodeRanks[n].localRanks / groupSize;
for (int g = 0; g < nGroupsInNode; ++g) {
groupToLocal[groupCount] = g * groupSize;
groupToNode[groupCount] = n;
groupCount++;
}
if (n < comm->node) group += nGroupsInNode;
}
/* 核心逻辑 */
uint32_t groupRound = 0, groupDelta = 0; int round = 0;
do {
if (groupDelta < nGroups) { // Filter nonsensical group deltas
int sendGroup = (group + groupDelta) % nGroups;
int recvGroup = (group - groupDelta + nGroups) % nGroups;
int sendNode = groupToNode[sendGroup];
int recvNode = groupToNode[recvGroup];
for (int delta = 0; delta < groupSize; delta++) {
int sendLocal = groupToLocal[sendGroup] + (local + delta) % groupSize;
int recvLocal = groupToLocal[recvGroup] + (local - delta + groupSize) % groupSize;
comm->p2pSchedule[round].sendRank = nodeRanks[sendNode].localRankToRank[sendLocal];
comm->p2pSchedule[round].recvRank = nodeRanks[recvNode].localRankToRank[recvLocal];
comm->p2pSchedule[round].sendNode = sendNode;
comm->p2pSchedule[round].recvNode = recvNode;
round += 1;
}
}
groupRound += 1;
groupDelta = (groupDelta + groupRound) & (nGroupsPow2 - 1); // Quadratic update
} while (groupRound != nGroupsPow2);
}