跳至内容
ZFM
返回

How to configure AstroPaper theme

Table of contents

Open Table of contents

Cutlass开发

前置知识:[[深度学习基础]]、[[cuda开发]]、[[GPU架构代际演进]]、[[GPU存储层级]]

行列基础概念

a[M][N] 这种写法,本质上就是 a[row][col],对应的坐标系就是 a[y][x]

其次,行主序代表x方向连续,列主序代表y方向连续(存储连续)

维度名称坐标变量对应 GPU 索引物理内存行为步进 (Stride)
行方向(水平方向)x (Column)blockIdx.x连续访问 (Coalesced)1
列方向(垂直方向y (Row)blockIdx.y跳跃访问Width

注:行方向和行号不同,行方向是维度方向,而行号代表坐标,由idx.y表示

​ 列方向和列号不同,列方向是维度方向,而列号代表坐标,由idx.x表示

block分工

在高性能计算(HPC)的矩阵乘法实现中,主流算子库(如 cuBLAS、CUTLASS)通常采用以结果矩阵C为导向的并行任务划分策略。

其核心逻辑是将目标矩阵 C 划分为一系列互不重叠的计算单元(Tile),并将每个线程块(Block)与一个特定的C-tile强绑定:

CuTe详解

layout

CuTe Layout 是 NVIDIA Cutlass 库的核心抽象,其本质是一个从多维逻辑坐标空间一维线性内存偏移的映射函数。它巧妙地利用 C++ 模板递归与元编程技术,将复杂的嵌套张量结构通过“逐层降维”的方式,还原为基础的地址偏移。

一个典型的 Layout 由 Shape(形状)Stride(步长) 两个核心部分组成:


1. 核心计算原理:内积映射

Layout 的基本操作是将逻辑坐标向量 P 与步长向量 S 进行内积运算,从而计算出物理偏移量 Offset。

对于一个简单的二维 Layout:

任意坐标 (i, j) 的物理地址计算公式为: $$ Offset = i \cdot s_0 + j \cdot s_1 $$ 示例:

若 Layout 为 Shape: (2, 3)Stride: (3, 1),则逻辑坐标 (1, 2) 对应的偏移量为: $$ 1 \cdot 3 + 2 \cdot 1 = 5 $$ 即该元素在内存中的实际地址为 &A + 5


2. 高阶嵌套与递归降维

CuTe 的强大之处在于支持分级嵌套(Hierarchical Layouts)。对于经过多次映射或重塑的复杂 Layout,其 Shape 和 Stride 在结构上保持严格的一一对应(Congruent)。

复杂嵌套示例:

在这种结构中,内积运算以递归方式展开。计算深层嵌套坐标的偏移时,CuTe 会逐层解析嵌套关系,直到触达最底层的标量步长。

3. Layout 的组合与退化

嵌套 Layout 并非孤立存在,它们可以通过**合并(Flattening)归约(Reduction)**在不同逻辑视图间转换:

  1. 分级合并:上述嵌套 Shape ((2, (2, 2)), (2, (2, 2))) 在逻辑上等价于 ((2, 4), (2, 4))
  2. 维度压缩:进一步合并可视为 (8, 8) 的二维平面,最终在总容量上退化为 (64) 的一维线性空间。
  3. 进制还原:这种映射关系类似于高维进制系统。通过 Stride 的指引,我们可以将一个一维线性索引(如第 35 个元素)反向推导出其在高维空间中的多级嵌套坐标。

在 CuTe Layout 的体系中,将一维线性索引(Logical Index)还原为多维分级嵌套坐标(Hierarchical Coordinate)的过程,本质上是一个基于基数(Radix)的递归解构过程

Layout_LeftLayout_Right 两种情况下的读取优先级和内存跨步(Stride):

场景一:Layout_Left(类似于 Col-Major,列主序)

核心法则:从**最左侧、最内层(Left-most, Inner-most)**的元素开始,向右解析。最左侧的维度变化最快,内存连续。

对于 ((4, 8), (7, 2)),在 Layout_Left 的法则下,读取优先级的顺序是:先 4 -> 再 8 -> 再 7 -> 最后 2

逐步推导 Stride 的过程

  1. 最先读取 4 (Mode 0):它是最左侧的维度,因此它在内存中是连续的。
    • Stride_0 = 1
  2. 接着读取 8 (Mode 1):每当“4”循环完一圈,才轮到“8”进一位。所以跨度是前面维度的总大小。
    • Stride_1 = Stride_0 $\times$ 4 = 1 $\times$ 4 = 4
  3. 接着读取 7 (Mode 2):跨入右边的 Tuple 了。
    • Stride_2 = Stride_1 $\times$ 8 = 4 $\times$ 8 = 32
  4. 最后读取 2 (Mode 3):变化最慢的维度。
    • Stride_3 = Stride_2 $\times$ 7 = 32 $\times$ 7 = 224

结论:如果指定为 Layout_Left,CuTe 自动生成的布局等价于 Layout<Shape<<4, 8>, <7, 2>>, Stride<<1, 4>, <32, 224>>>先读 4。

场景二:Layout_Right(类似于 Row-Major,行主序 / C 语言数组标准)

核心法则:从**最右侧、最内层(Right-most, Inner-most)**的元素开始,向左解析。最右侧的维度变化最快,内存连续。 对于 ((4, 8), (7, 2)),在 Layout_Right 的法则下,读取优先级的顺序被彻底反转:先 2 -> 再 7 -> 再 8 -> 最后 4逐步推导 Stride 的过程

  1. 最先读取 2 (Mode 3):它是最右侧的维度,因此内存连续。
    • Stride_3 = 1
  2. 接着读取 7 (Mode 2)
    • Stride_2 = Stride_3 $\times$ 2 = 1 $\times$ 2 = 2
  3. 接着读取 8 (Mode 1):跨入左边的 Tuple 的最右侧。
    • Stride_1 = Stride_2 $\times$ 7 = 2 $\times$ 7 = 14
  4. 最后读取 4 (Mode 0):变化最慢。
    • Stride_0 = Stride_1 $\times$ 8 = 14 $\times$ 8 = 112 结论:如果指定为 Layout_Right,CuTe 自动生成的布局等价于 Layout<Shape<<4, 8>, <7, 2>>, Stride<<112, 14>, <2, 1>>>根本不是先读 4 或 8,而是先读 2!
还原算法:递归解构(Recursive Deconstruction)

当一个线性索引 L 进入一个 Shape 为 (d_0, d_1, …, d_n) 的 Layout 时:

  1. 左侧优先取余:第 0 维坐标 $$ i_0 = L \pmod{d_0} $$ 这是因为在列主序中,最左边的维度变化最快。
  2. 步进整除:更新索引 $$ L’ = \lfloor L / d_0 \rfloor $$ 用于计算后续维度的坐标。
  3. 递归嵌套:如果 d_0 本身是一个嵌套Shape (s_0, s_1),则将 i_0 作为新的线性索引,重复上述步骤进入下一层括号内部。

实例推导:从线性索引到嵌套坐标

假设我们有一个嵌套 Shape:

$$ Shape: ((2, 4), 8) $$ 我们需要还原线性索引 L = 35 对应的嵌套坐标 $$ ((i_{0,0}, i_{0,1}), i_1) $$

第一层解构(最外层括号):

最外层视作二维结构: $$ d_{outer0} = (2,4) = 8(乘积),d_{outer1} = 8 $$

第二层解构(进入内层括号 (2, 4)):

现在对 $$ L_{inner} = 3 $$ 在 Shape (2, 4) 中进行还原:

最终结果:

线性索引 35 还原后的分级嵌套坐标为:((1, 1), 4)

4. 维度增长的优先级

在 CuTe 中,维度增长的优先级(即谁变动得快)是从左到右的,但这种增长是嵌套发生的。这种“左到右”不仅存在于最外层的括号,也递归地存在于每一个子括号内部。

对于给出的坐标 A((1, (0, 1)), (0, (0, 1))): 我们可以将其视为 A(Row, Col),其中:

在这个层级,Row 整体变动最快。这意味着物理地址上,A(0, 0) 后面紧跟着的是 A(1, 0),而不是 A(0, 1)

换个方式理解:在层次化布局中,坐标的嵌套遵循‘由外向内’的导航逻辑:右侧的分量决定了数据所在的宏观‘街区’(大块)而左侧的分量则在选定的街区内进行微观‘寻址’(细分单元)

img

以上图为例,coord: ((1, 3), (2, 4))所表示的位置可以按这样的方式找到:首先提取坐标各维度的右侧分量组成宏观坐标 (3, 4),如同图中的红色标注,在全局范围内锁定对应的“大块”位置(即第 4 个块行与第 5 个块列的交汇处),随后进入该区块内部,利用左侧分量组成的细分坐标 (1, 2),如绿色标注所示,在局部格点中精准定位到具体的逻辑元素。

5. _的切片原则

取layout时,可以使用_来选择整个维度,并结合数字来取出自己想要的layout:

可以把 _ 直接理解成“这一维我不把它钉死,而是要把它整条保留下来”:其中_,代表固定维度,通常代表结果的维度。数字表示选定维度,有几个数字通常代表结果少了几维;

核心本质就一句话:数字负责定位,_ 负责保留;_ 不是随便写的符号,而是在说“这一维不要取单点,要取完整视图”,所以看一个切片表达式,不用先管它多复杂,只要先数有几个 _,就能立刻知道最后拿到的是“点、线,还是面”。通过切片,可以一次性将需要的数据全部取出,如可以固定thread,将thread需要用的数据全部取出,不需要复杂的映射关系计算

image-20260308222329753

6. layout的查看逻辑

看一个 layout 本质上就是看“下标每增加一步,物理地址会跳到哪里”。如果默认按二维来看,想看某一列里有哪些元素,就把列固定住,让行的各层下标按顺序变化,然后用 stride 算出它们相对起点的偏移;想看某一行也是同理,只不过换成固定行、让列变化。

在Layout<Shape <Shape <_2,_2,_2>,Shape <_2,_2,_2>>,

​ Stride<Stride <_1,_16,_4>,Stride<_8,_2,_32>>>下

比如你想看 Row 内部是怎么跳的,固定列为 0: $$ Row(0,0,0) \to 偏移 0\ Row(1,0,0) \to 偏移 1(Stride 1)\ Row(0,1,0) \to 偏移 16(Stride 16)\ Row(0,0,1) \to 偏移 4(Stride 4)\ $$

结论: 在行方向上,它是先走 1 步,再跳 4 步,最后猛跳 16 步。

说明虽然逻辑上都只是“行坐标在变”,但物理地址并不是老老实实连续增长,而是会按 stride 规则跳到不同位置;如果列不是 0,再额外加上这一列本身带来的固定偏移就行。所以不管是二维变二维,还是一维拆成多维,本质都一样:逻辑坐标的递增是平滑的(0, 1, 2, 3…),但地址的跳跃是由 Stride 决定的

一般来说,对于实际的开发,可以将layout统一作为二维理解,任何 layout 都看成:行维度 × 列维度,这两个维度可以是单层stride,也可以是嵌套的跳跃


CuTe Layout 概念详解


一、Layout 兼容性(Compatibility)

什么叫”兼容”?

两个 Shape 兼容,需要满足两个条件:

  1. size 相等
  2. A 的所有坐标都是 B 的合法坐标

直觉上:A 能”放进” B 里。

24  compatible with  (4,6)    ✅  因为 4×6=24,且0~23都是(4,6)的合法坐标
24  compatible with  ((2,2),(3,2))  ✅  因为 2×2×3×2=24
(4,6)  NOT compatible with  24   ❌  虽然size相等,但(4,6)的坐标是2D的,24只接受1D坐标

💡 关键:兼容是单向的。A 兼容 B,不代表 B 兼容 A。这就像”正方形是矩形,但矩形不是正方形”。


二、Layout 接受多种坐标

这是最重要的概念。一个 Layout 同时接受多种坐标系

以 Shape (3,(2,3)) 为例,size = 3×2×3 = 18,它同时接受三种坐标:

1-D坐标2-D坐标自然坐标(Natural)
0(0,0)(0,(0,0))
1(1,0)(1,(0,0))
3(0,1)(0,(1,0))
6(0,2)(0,(0,1))
16(1,5)(1,(1,2))

同一行的三个坐标完全等价,映射到同一个内存偏移。

为什么要有这么多坐标系?

同一块内存,按需解读,零开销!


三、坐标映射:怎么从任意坐标到自然坐标?

规则:列主序(colex order),从右往左数进位(和我们日常从左往右的”字典序”相反)。

以 Shape (3,(2,3)) 举例,用2D坐标 (row, col) 遍历,col 先变:

col=0  col=1  col=2  col=3  col=4  col=5
(0,0)  (1,0)  (2,0)  (0,1)  (1,1)  (2,1) ...
  0      1      2      3      4      5   ...  ← 1D坐标

1D 坐标 6 → 2D 坐标 (0,2) → 自然坐标 (0,(0,1))


四、Index 映射:从自然坐标到内存偏移

就是内积:自然坐标 × Stride,各分量相乘后求和。

(3,(2,3)):(3,(12,1)) 为例:

自然坐标 (i,(j,k)) → 偏移 = i*3 + j*12 + k*1


五、Layout 操作速查

Sublayout(取子布局)

Layout a = (4,(3,6)):(1,(4,12))

layout<0>(a)4:1        // 取第0个mode
layout<1>(a)(3,6):(4,12)  // 取第1个mode(还是个tuple)
layout<1,0>(a)3:4        // 取第1个mode的第0个子mode

select vs take

// select:按索引挑选,可以不连续
select<1,3>(a)   // 取第1和第3个mode

// take:取连续范围 [begin, end)
take<1,3>(a)     // 取第1、第2个mode

group & flatten(分组与展平)

Layout a = (2,3,5,7):(1,2,6,30)

group<0,2>(a)((2,3),5,7):((1,2),6,30)  // 把mode 0~1 打包成一个嵌套mode
flatten(...)(2,3,5,7):(1,2,6,30)      // 展平回去

用处:把矩阵重新解读成向量,或把向量重新解读成矩阵,不移动数据


Coalesce 详解

考虑两个相邻的 mode:s0:d0s1:d1

规则1:s0:d0 ++ _1:d1 => s0:d0

size=1 的 mode 无论 stride 是多少,坐标永远是0,贡献永远是 0 * d1 = 0。直接丢掉。

_3:_5  ++  _1:_99  =>  _3:_5

规则2:_1:d0 ++ s1:d1 => s1:d1

同上,前面那个 size=1 的也丢掉。

规则3:s0:d0 ++ s1:(s0*d0) => s0*s1:d0

当后一个 mode 的 stride = 前一个 mode 的 size × stride 时,可以合并。

直觉:这说明两个 mode 在内存里是连续排列的,就像列主序矩阵。

_2:_1  ++  _4:_2  =>  _8:_1

验证:s0=2, d0=1, s1=4, d1=2s0*d0 = 2*1 = 2 = d1 ✅,合并为 8:1

规则4:没法合并,保持原样

_2:_1  ++  _4:_3  =>  (_2,_4):(_1,_3)

例子

Layout<Shape<_2, Shape<_1, _6>>,
       Stride<_1, Stride<_6, _2>>>

展平(flatten)成一系列相邻的 mode:

_2:_1,  _1:_6,  _6:_2

第一步_2:_1 ++ _1:_6 → 规则1:后面 size=1,丢掉 _1:_6 → 剩下 _2:_1

第二步_2:_1 ++ _6:_2 → 检查规则3:s0*d0 = 2*1 = 2 = d1 → 合并为 _12:_1

结果:_12:_1


By-mode Coalesce(保持维度)

有时候我们需要保持 Layout 的维数。比如我有个2D layout,希望结果还是2D。

auto a = Layout<Shape<_2, Shape<_1,_6>>,
                Stride<_1, Stride<_6,_2>>>{};

coalesce(a, Step<_1,_1>{})  // 结果是 (_2,_6):(_1,_2)

Step<_1,_1> 的意思是:“把这个 layout 当成2D的,分别对每个维度做 coalesce”

最终合并两维:(_2,_6):(_1,_2)

Coalesce = 把一个复杂的 Layout 化简成最少 mode 数、等价的简单 Layout。就像化简分数一样,值不变,形式更简单,计算更高效。

CuTe Layout Composition

核心思想

还记得我们说过 Layout 是整数到整数的函数吗?有了这个视角,Composition 就很自然:

R = A ∘ B,即 R(c) = A(B(c))

就是普通的函数复合。复合的结果还是一个 Layout,即:A,B为layout,A∘B输入为两个layout,Composition输出也为一个layout,不过该layout复合了AB的layout,隐去了A 的逻辑坐标,使得新layout输入B 的逻辑坐标可以直接得到物理地址

A ∘ B 就是 “让 B 当坐标生成器,让 A 当地址计算器”

结果是一个能直接映射真实内存地址的新 Layout(也就是得到 B 能直接映射真实地址的 layout)

而由于 stride 恒正,B 的逻辑坐标从 0 开始单调递增,对应到 A 里的物理地址也是从小到大走,所以 A∘B ,B取值越小,取出的永远是 A 物理地址偏移越小的tile,而不是中间某块或末尾某块,故一般需要和A∘(B, B*)结合使用,详见下文


怎么计算 Composition?

关键性质:左分配律

A ∘ (B₀, B₁, ...) = (A ∘ B₀, A ∘ B₁, ...)

所以可以把 B 拆成各个 sublayout,分别和 A 复合,最后拼起来。

这样问题化简为:A(多模)和 s:d(单个整数mode)怎么复合?

计算 A ∘ s:d 的两步走:A 是多维的时候

A = (6,2):(8,2)

现在要算 A ∘ 4:3,即从 A 中每隔3步取,共取4个

R(0) = A(0)  = 0
R(1) = A(3)  = 24
R(2) = A(6)  = 2
R(3) = A(9)  = 26

问题来了:这4个结果能写成一个 Layout 吗?怎么找到它的 shape 和 stride?


Step 1:除法 /——“每隔 d 步”

(6,2) / 3 的含义是:把步长3从整个 shape 里”除出去”,从左往右逐维处理。

过程:

(6,2) / 3 = (2,2)

对比几个例子感受规律

(6,2) / 2  →  第0维 6÷2=3,消耗完  →  (3,2)
(6,2) / 3  →  第0维 6÷3=2,消耗完  →  (2,2)
(6,2) / 6  →  第0维 6÷6=1,消耗完  →  (1,2)
(6,2) / 12 →  第0维 6÷6=1,还剩2,第1维 2÷2=1  →  (1,1)

当 d > 第0维的 size 时,才会”溢出”到第1维。


stride 怎么变?

除法只影响被消耗到的那一维的 stride,规则是乘以”残余的 d”。

(6,2):(8,2) / 3:
  d=3 全部被第0维消耗(6÷3=2)
  第0维 stride × 3 = 8×3 = 24
  第1维不动
  → (2,2):(24,2)

(6,2):(8,2) / 12:
  d=12,第0维 size=6,先消耗6,剩余 d=12÷6=2
  第0维 stride × 6 = 8×6 = 48,shape变为1
  剩余 d=2 进入第1维,2÷2=1
  第1维 stride × 2 = 2×2 = 4,shape变为1
  → (1,1):(48,4)

Step 2:取模 %——“共取 s 个”

Step 1 告诉我们”每隔3步的无限序列长什么样”,但我们只要前4个

(2,2):(24,2) % 4 的含义:从这个无限序列里截取前4个元素

shape (2,2) 总共有 2×2=4 个元素,恰好等于 s=4,所以直接保留:

(2,2):(24,2) % 4  =  (2,2):(24,2)

更直接地说:

(6,2) % 3 = (3,1)    → 只需要3个,第0维取3,第1维取1
(6,2) % 6 = (6,1)    → 只需要6个,第0维取满6,第1维取1
(6,2) % 12 = (6,2)   → 需要12个,全取

规律:s 先填满第0维,填满了再溢出到第1维


Composition 的实际用途

用途1:把1D layout 重塑成矩阵

20:2  ∘  (5,4):(4,1)

“把连续步长为2的20个元素,重新解读成5×4的行主序矩阵”

= (20:2 ∘ 5:4,  20:2 ∘ 4:1)
= (5:8, 4:2)          ← 都是trivial情况:s:d ∘ s':d' = s':(d*d')
= (5,4):(8,2)

结果:第一维步长8,第二维步长2,正是行主序5×4矩阵的访问模式。

用途2:By-mode Composition(按维度复合)

Tiler 分别对每个维度施加不同的 sublayout:

// a 是一个 (12,(4,8)):(59,(13,1)) 的layout
auto tiler = make_tile(Layout<_3,_4>{},   // 对第0维:3:4
                       Layout<_8,_2>{});  // 对第1维:8:2

auto result = composition(a, tiler);
// 等价于:
make_layout(composition(layout<0>(a), 3:4),
            composition(layout<1>(a), 8:2));

含义:从 a 的第0维中每隔4步取3个,从第1维中每隔2步取8个,得到一个 3×8 的子块。

用途3:Shape 作为 Tiler(stride=1 的特殊情况)

auto tiler = make_shape(Int<3>{}, Int<8>{});
// 等价于 make_tile(3:1, 8:1)
// 含义:取第0维前3个,取第1维前8个(连续子块)

Composition = “用 B 描述的访问模式去索引 A”,结果是一个新 Layout,直接表达了这种复合访问。 这是 CuTe 里 tiling、partitioning 等所有高级操作的基础。

Complement补集

Complement 的核心目的

Complement 到底做什么?

Layout complement(Layout A, Shape cotarget)

3 个铁律(post-conditions)

  1. R 的大小不会超过 cotarget(不会越界)
  2. R 的 stride 是递增的、正的 → 结果唯一
  3. R 和 A 完全没有重叠(不踩到已经选走的元素)

例子

例1
A = 4:1(连续 4 个元素)
cotarget = 24(总共 24 个元素)
complement = 6:4
意思:把 A 重复 6 次,每次间隔 4,就把 24 个位置全覆盖了!

例2
A = 6:4(每 4 个跳一个)
cotarget = 24
complement = 4:1
把 A 的“空洞”正好用连续 4 个填满。

例3
A = (4,6):(1,4)(标准的 4×6 矩阵)
cotarget = 24
complement = 1:0(啥都不用加,已经满了)

例4
A = (2,2):(1,6)
cotarget = 24
complement = (3,2):(2,12)

complement1

A∘B 的输入输出分析

在学习 product 和 Division 之前,我们需要对复合映射 A∘B 的定义域与值域建立更严谨的理解。

基本结构

A∘B 本质上是一个复合函数 A(B(input)),既然是函数,自然具有对应的定义域与值域:

Layout的注意点

CuTe Layout Algebra是纯粹的代数操作,像是composition,B 的输出直接作为 A 的输入按 stride 继续计算,没有任何隐式的边界检查或取模处理:

$$ A(B(\text{input})) $$ 这里有一点需要特别注意:所有的 layout 都不会对输入做任何合法性检查。 CuTe 的 layout 计算本质上就是一个纯粹的数学公式: $$ index = \sum (coord_d \times stride_d) $$ 无论是 composition 还是单独的 layout,当你输入一个超出其定义域的值时,它不会报错,而是会代入公式继续按照 stride 规律计算并返回一个结果——只不过这个结果是一个越界的非法物理地址。这种行为贯穿所有 layout 操作,其后果只有在运行时通过 cuda-memcheck 才能捕获,可能表现为 Shared Memory 越界访问或非法内存地址错误。因此,保证输入合法性完全由使用者自己负责,因此每次取用数据,最好用size()检查定义域,cosize()检查最大允许物理地址。

根据 B 的值域与 A 的定义域的大小关系,存在三种情况:

B 的值域小于 A 的定义域 遍历 B 的所有输出,无法覆盖 A 的完整定义域,A 的部分物理地址永远不会被访问到。

B 的值域等于 A 的定义域 B 的输出与 A 的定义域恰好一一对应,整个复合函数是严格的单值函数,每个物理地址被访问恰好一次,是最理想的状态。

B 的值域大于 A 的定义域 B 的输出超出 A 的定义域范围后,A 继续代入公式计算,产生越界的非法物理地址。


结合 A* 来看

对于 A*∘B∘C∘... 这条复合链,A* 始终作为最后一层翻译器,将任意输入转换为以 A-tile 为单位的起始偏移。无论前面经过多少层复合,只要最终输出经过 A*,结果永远是合法的 tile 起始偏移,A* 对前面整条复合链完全透明。

而这条复合链能否不重不漏地覆盖所有坐标,唯一的约束条件就是整条链是否构成单值函数,等价于链上每一步都没有发生遗漏或越界。

因此 (A, A*∘B∘C∘...) 的本质可以归纳为:A 负责 tile 内部的寻址,后面这条复合链无论有多长多复杂,只要保持单值函数的性质,就能精确覆盖所有 tile 的起始偏移。两者组合后即可不重不漏地访问整个空间。

换句话说:形如 (A, A*∘B∘C∘...) 的 layout,只要满足两个条件:其一,A*∘B∘C∘... 整条复合链是单值函数;其二,最终输出经过 A* 的翻译。那么这个 layout 就能保证 tile 内的每一个元素都被不重不漏地映射到对应的物理地址,并方便地通过下标访问。

这一分析是理解后续 product 和Division规则的基础

Division


Division到底想解决什么问题?

对于一个tile布局,divide目的就是切片,最终得到一个layout,我们可以通过这个layout,用简单的下标形式获取每个tile里每个值的真实物理地址,而且对于这个tile,我想要在这个tile上面进行新的layout布局,也可以通过divide继续获得直达的物理地址映射,如此一直下去

所以 logical_divide(A, B) 的结果,应该同时包含:

  1. tile 内部的布局
  2. tile 之间的布局

公式(divide到底干了什么)

$$ A \oslash B := A \circ (B, B^*) $$

第一步:A ∘ B是什么

$$ A \circ B $$ 就是:

在大布局 A 上,取出由 B 描述的 tile,就是 composition,准确来说,

A ∘ B 就是 “让 B 当坐标生成器,让 A 当地址计算器”

结果是一个能直接映射真实内存地址的新 Layout(也就是得到 B 能直接映射真实地址的 layout)


第二步:B*

B*B 的 complement,它的作用是标记”第几个 tile”,也就是将A分成若干个B,每个 B* 取值对应一个 stride 偏移,表示当前 tile 相对于第一个 tile 的起始位置偏移量。


第三步:(B, B*)

把这两个拼起来: $$ (B, B^*) $$ 就是一个两层结构,结合stride:

然后再喂给 A: $$ A \circ (B, B^)=(A \circ B,A \circ B^) $$ 这就是 divide,他借助了A的中转,可以直接通过下标获取各个tile(由B*确定)里各个值(由B确定)的直达物理地址,最后的结果是二维的,准确的说是2-mode,第一个mode管一个tile内部的取值,第二个mode管不同tile的首地址位置

为什么用A∘(B,B*)而不是A∘B?

由 Complement 的定义可以知道,A∘B 只建立了一个 tile 和 A 之间的映射关系,它自然无法覆盖 A 的所有元素。而 (B, B*) 配合使用则能将 A 完整切分,由于(B, B*)的size和A相等,所以映射函数满足一一对应关系,由此可以用”第 n 个 tile 中第 m 个元素”的方式不重不漏地索引到每一个元素。因此,A∘(B, B*) 才是真正完整的映射关系,它让 A 的每个物理地址都有对应的逻辑坐标可以访问到。


实现

template <class LShape, class LStride,
          class TShape, class TStride>
auto logical_divide(Layout<LShape,LStride> const& layout,
                    Layout<TShape,TStride> const& tiler)
{
  return composition(layout, make_layout(tiler, complement(tiler, size(layout))));
}

logical_divide相当于把上述的三步一起做了,得到的直接是tile对于物理地址的layout


例子: $$ A = (4,2,3):(2,1,8) $$


$$ tiler B = 4:2 $$

$$ B^* = (2,3):(1,8) $$ 得到: $$ (B, B^) = (4,(2,3)):(2,(1,8)) $$ $$ A \circ (B, B^) $$

结果: $$ ((2,2),(2,3)):((4,1),(2,8)) $$ **为什么tile内部的排列形式变了:**由于做了Composition,相当于直接跳过了A取真实地址,所以((2,2)):(4,1),由于A本身就是对物理地址的真实排列做了加工的,而原来的 tile是按照A的逻辑坐标来顺序取值的,所以真实的地址当然会和原来的不同,相当于把A的加工也给还原回去了。 尽管底层物理排列发生了变化,但在编程接口层,CuTe 依然维持了逻辑上的易用性。开发者仍可沿用 (4, (2,3)) 这种直观方式进行取值,系统会自动将线性索引 4 解构为对应的多维坐标 (2,2),从而在享受物理寻址效率的同时,屏蔽了底层的地址转换细节


对于 2D 布局 A: $$ A = (9,(4,8)):(59,(13,1)) $$ 如果按一个 2D 的 tiler 去切: $$ B = <3:3,\ (2,4):(1,8)> $$ 这里的意思是:

按照上面同样的三步,得到 2D divide 的结果:

其中Rest项更多的是表示当前是第几个tile,合起来就是: $$ ((TileM,RestM), (TileN,RestN)) $$ 这其实很“规整”,因为它保留了原来矩阵的 M / N 语义。

所以 logical_divide 的优点之一就是:

虽然它切了 tile,但它没有把原来的维度语义打乱。

M 还是 M,N 还是 N。 只是每个维度内部被拆成“tile 内部”和“tile 外部”。


对于不同的函数,改变的只是结果的表示形式

对于logical_divide结果是: $$ ((TileM,RestM), (TileN,RestN)) $$

对于zipped_divide结果是: $$ ((TileM,TileN), (RestM,RestN)) $$ 也就是:

变成:


对于tiled_divide结果是: $$ ((TileM,TileN), RestM, RestN, …) $$ 适合你想单独保留每个 tile 维度编号的时候。


对于flat_divide结果是: $$ (TileM, TileN, RestM, RestN, …) $$ 适合某些需要最简单 rank 结构的场景。


Layout Shape : (M, N, L, ...)
Tiler Shape  : <TileM, TileN>

logical_divide : ((TileM,RestM), (TileN,RestN), L, ...)
zipped_divide  : ((TileM,TileN), (RestM,RestN,L,...))
tiled_divide   : ((TileM,TileN), RestM, RestN, L, ...)
flat_divide    : (TileM, TileN, RestM, RestN, L, ...)
((TileM,TileN), (RestM,RestN,L,...))

Product

基本概念回顾

在理解乘积之前,先回顾几个关键概念:


逻辑乘积(logical_product)的定义

数学定义

logical_product(A, B) 的结果同样是一个二模(2-mode)布局,定义为: $$ A \otimes B := (A, A^* \circ B) $$ 其中:

C++ 实现

template <class LShape, class LStride,
          class TShape, class TStride>
auto logical_product(Layout<LShape,LStride> const& layout,
                     Layout<TShape,TStride> const& tiler)
{
  return make_layout(
    layout,
    composition(complement(layout, size(layout) * cosize(tiler)), tiler)
  );
}

Product的本质与A*范围

为了便于理解,我们将 layout 记为 A,tiler 记为 B,那么complement操作的 cotarget 就是 size(A) * cosize(B)

先理解 A* 的职责

A* ∘ B 的含义是:按照 B 规定的顺序,依次取出每个 A-tile 在内存中的首地址偏移。换句话说,取多少个 tile、按什么顺序取,完全由 B 说了算

因此 A* 必须考虑最坏的情况——当 B 索引到它值域中最大的那个 tile 编号 n 时,A* 也要能正确响应。即便 B 并不会把 [0, n] 范围内的每一个 tile 都取到(中间可能有”跳过”的情况),A* 仍然必须按照最大可能的需求来预留空间,而不能心存侥幸。

再理解 product 的本质

product(A, B) 并不是在一块已经固定好的内存里重新分配现有的元素,而是:以 A 的排布样式为模板,根据 B 所需的副本数量和顺序,构造出多份互不重叠的 A 的复制体。

这些复制体会优先”填入” A 原本布局中未被实际占用的空洞(即 A 内部因步长等原因留下的间隙),空洞填满之后,再向后开辟新的连续空间。

本质上,product 是在定义一套新的 layout 映射关系,而不是搬动或重新分配底层数据。

最终结论

两者相乘,得到的正是:一个足以容纳这些复制体、并且尽量利用空洞的 codomain 目标范围——既充分利用了 A 内部的空洞,又保证了在最坏情况下不会越界。


一维示例详解

基本示例:A = (2,2):(4,1),B = 6:1

A 的含义

B 的含义

计算过程

第一步:求补集

complement(A=(2,2):(4,1), N=6×4=24)
→ A* = (2,3):(2,8)

A* 表示在 24 个元素的空间中,A 重复出现的起始偏移布局:

第二步:组合

composition(A*=(2,3):(2,8), B=6:1)
→ (2,3):(2,8)

因为 B=6:1 是连续的 6 个,而 A* 已经恰好描述了 6 个 tile(2×3=6),所以结果不变。

第三步:拼接

logical_product(A, B) = ((2,2),(2,3)):((4,1),(2,8))

结果是一个 rank-2 布局:

product1.png

改变 tile 排列顺序:B = (4,2):(2,1)

logical_product(A=(2,2):(4,1), B=(4,2):(2,1))

product2.png


二维乘积与 rank-sensitive 变换

朴素方式的问题

直接用 logical_product 对二维 Layout 操作时,A 的行列放一起,R =A* ∘ B的行列也放一起,形如: $$ ((A_r,A_c),(R_r,R_c)) $$ 但我们更想要行维度放一起,列维度放一起,因此就有了下面的不同product:

blocked_product(块状乘积)

$$ ((A_r,R_r),(A_c,R_c)) $$

先走 tile 内部,再走 tile 编号,所以 tile 会保持成“块”

blocked_product 还会自动对 Mode 0 进行 coalesce(合并),进一步简化布局。

productblocked2d.png

raked_product(耙状/交织乘积)

$$ ((R_r,A_r),(R_c,A_c)) $$

先走 tile 网格,再走 tile 内部,所以 tile 会被打散成交错分布

productraked2d.png


zipped_product / tiled_product / flat_product

这三种变体是在 logical_product 的基础上对输出模式进行重排,以满足不同的使用场景:

Layout Shape : (M, N, L, ...)
Tiler Shape  : <TileM, TileN>

logical_product : ((M,TileM), (N,TileN), L, ...)
zipped_product  : ((M,N), (TileM,TileN,L,...))
tiled_product   : ((M,N), TileM, TileN, L, ...)
flat_product    : (M, N, TileM, TileN, L, ...)
变体适用场景
logical_product基础操作,需要手动处理模式
zipped_product需要明确区分”tile 内”和”tile 间”两个视角
tiled_product常见的分块迭代场景,tile 间维度保持独立
flat_product需要完全线性化的简单情形

下面我把这份 CuTe Tensor 教程重新写一遍,不按原文硬翻,而是按“为什么需要它 → 它到底是什么 → 怎么创建 → 怎么访问 → 怎么切片 → 怎么分块 → 为什么 GEMM 里离不开它”这条真正适合学习的主线来讲。核心依据是 CuTe 的 TensorLayoutLayout Algebra 文档。


Tensor

Tensor = 数据来源 + 索引规则

更准确地说:

Tensor = Engine + Layout

其中:

这样一来,算法就不用关心底层存储细节了。你写的仍然是“我要访问 (m,n)”,至于它是在 gmem、smem、寄存器,还是某种特殊迭代器上生成出来的数据,Tensor 帮你统一掉。

Tensor 不是“存数据的多维数组”这么简单。 它本质上是一个“用 Layout 解释数据”的视图/容器。


Tensor 和 Layout 的关系

因此你可以把 Tensor 看成:“带数据的 Layout”

Engine

Engine 可以理解成一个对“迭代器/数组/指针”的轻量包装。

它最核心提供的是:

using iterator
using value_type
using reference
iterator begin()

也就是说,Engine 本质上只需要能告诉 Tensor:

通常用户不用手写 Engine,你直接 make_tensor(...),CuTe 会自动帮你构造合适的 engine。常见的有:

所以,Engine 的职责不是搞复杂抽象,它只是帮 Tensor 统一“数据入口”。


Tensor 怎么创建

这里分成两大类:


Nonowning Tensor

最常见的形式就是:

Tensor A = make_tensor(ptr, layout);
Tensor A = make_tensor(make_gmem_ptr(A), layout);
Tensor A = make_tensor(make_smem_ptr(s), layout);

意思是:

ptr 指向的数据,再配上 layout 这套解释规则,构造一个 Tensor 视图。

其中CuTe 提供:

一维 view

float* A = ...;
Tensor tensor_8 = make_tensor(A, make_layout(Int<8>{}));

带 stride 的一维 view

Tensor tensor_8d2 = make_tensor(A, 8, 2);

二维 global memory Tensor

Tensor gmem_8sx16d = make_tensor(make_gmem_ptr(A),
                                 make_shape(Int<8>{},16));//默认 LayoutLeft

shared memory Tensor

Layout smem_layout = make_layout(make_shape(Int<4>{},Int<8>{}));
__shared__ float smem[decltype(cosize(smem_layout))::value];
Tensor smem_4x8_col = make_tensor(make_smem_ptr(smem), smem_layout);

这个例子里最该注意的是:

decltype(cosize(smem_layout))::value

cosize 是 CuTe 的函数,计算一个 Layout 的 codomain 大小,也就是这个 Layout 实际需要占用多少内存空间

关键点:cosize 返回的不是普通的 int,而是一个类型!因为 shape 和 stride 都是静态的(Int<4>, Int<8>),所以 CuTe 在编译期就能算出结果,返回一个 Int<32> 类型的对象decltype 是 C++ 关键字,用来在编译期提取一个表达式的类型,而不实际执行它,因此:

decltype(cosize(smem_layout))  // → Int<32>(返回值的类型)

最后通过::value取出值:

decltype(cosize(smem_layout))::value
//   ↓               ↓           ↓
// 拿到类型      Int<32>       取出32
__shared__ float smem[decltype(cosize(smem_layout))::value];//定义shared memory大小

Owning Tensor:自己拥有数据

如何创建:

Tensor rmem = make_tensor<float>(Shape<_4,_8>{});

make_tensor_like

Tensor rmem_4x8_like = make_tensor_like(rmem_4x8_pad);

Owning 和 Nonowning 的本质区别

// make_tensor:传入了已有指针 → 非拥有型(视图)
make_tensor(ptr, layout)        // ptr 是已有内存,我只是看它

// make_tensor<T>:没有指针,自己分配 → 拥有型
make_tensor<float>(shape)       // 自己创建内存

因此,Owning 创建大多数用于寄存器创建,他会自动销毁并回收内存


Tensor 访问元素

Tensor A = make_tensor<float>(
  Shape<Shape<_4,_5>, Int<13>>{},
  Stride<Stride<_12,_1>, _64>{});

Tensor B = make_tensor(b_ptr, make_shape(13,20));

这个 A 的 shape 是:

((4,5), 13)

意味着它第一大 mode 里又嵌了两个子 mode。 所以你既可以写:

A[make_coord(make_coord(m0,m1), n)]

也可以写:

A(m,n)

前者是自然层级坐标,后者是兼容坐标。


Tensor tiling

很多 Layout 的分块运算也能直接用在 Tensor 上:

composition(Tensor, Tiler)
logical_divide(Tensor, Tiler)
zipped_divide(Tensor, Tiler)
tiled_divide(Tensor, Tiler)
flat_divide(Tensor, Tiler)

但是Layout 可以做 product,Tensor 不行。

因为 product 可能把 codomain 扩得更大,导致访问跑到原来 Tensor 有效边界之外,造成内存越界


切片_(Slicing)

切片(Slicing)和普通访问的核心区别在于:普通访问 A(coord) 返回一个元素,而切片 A(_, 5) 返回一个子 Tensor(视图)。触发切片的关键不是用没用 operator(),而是你有没有传入 _

_ 可以直接理解成”这一维我不把它钉死,而是要把它整条保留下来”,类似 Matlab/Fortran 里的 :。数字表示固定坐标,_ 表示保留的维度:


切片本质上同时做两件事:

切片不是单纯”少看几维”,而是同时修改了起始指针和布局(Layout)

  1. 算出新 Tensor 的起点:把固定坐标(数字部分)喂给原 Layout,得到一个偏移量,加到原 iterator 上,新 Tensor 的数据起点就指向原 Tensor 中对应切片起始位置的地址;
  2. 构造新 Tensor 的 Layout:把传了 _ 的那些 mode 对应的 Layout mode 重新拼成一个新 Layout。

所以切片返回的不是”拷贝后的数组”,而是一个 view新 Tensor = 新起点 + 新布局


结果的 rank 取决于你在哪个层级写 _

这是最容易出错的地方。_ 保留的是你当前写出来的那个层级的 mode,而不是简单地”保留几个数字”。切片不仅关心”留了什么元素”,还关心”以哪个层级把这些元素留下来的”。

以 shape 为 ((3,2), (2,5,2)) 的 Tensor A 为例:

写法含义结果 shape说明
A(2, _)第一大 mode 固定为 2,第二大 mode 整体保留(2,5,2)保留完整的第二大 mode
A(_, 5)第一大 mode 整体保留,第二大 mode 固定坐标 5((3,2))第一大 mode 作为整体的 1 个 mode 被保留
A(make_coord(_,_), 5)第一大 mode 的两个子 mode 分别保留,第二大 mode 固定坐标 5(3,2)内部两个子 mode 被展开为 2 个独立 mode

A(_, 5)A(make_coord(_,_), 5) 包含的元素完全相同,但 rank 和 shape 不同——前者 rank 为 1,后者 rank 为 2。因此,结果 Tensor 的 rank,等于切片坐标中 _ 的个数(更准确地说,是当前层级保留了几个 mode,结果就有几个对应的 mode)。


核心本质就一句话:数字负责定位,_ 负责保留。

看一个切片表达式,不用先管它多复杂,只要先数有几个 _,就能立刻知道最后拿到的是点、线还是面。通过切片,可以一次性将所需数据全部取出——例如固定 thread 坐标,将该 thread 需要用的数据整块取出,完全不需要复杂的映射关系计算,这也是 local_tilelocal_partitionTiledMMA 等高层抽象的底层基础。


Partition:本质上是“先divide,再切片”

Inner partition:

Tensor A = make_tensor(ptr, make_shape(8,24));  // (8,24)
auto tiler = Shape<_4,_8>{};                    // (4,8)

Tensor tiled_a = zipped_divide(A, tiler);       // ((_4,_8),(2,3))

这个结果是典型的divide操作:

((_4,_8),(2,3))

意思是:

如果某个 CTA 想拿其中一个 tile,比如 (blockIdx.x, blockIdx.y) 对应那个 tile,就写:

Tensor cta_a = tiled_a(make_coord(_,_), make_coord(blockIdx.x, blockIdx.y));

也就是:

**每个 CTA 负责一整块数据。**如果掌握了divide操作,应该不难理解


Outer partition:取出每个tile相同位置的元素

还是同一个 tiled_a

Tensor thr_a = tiled_a(threadIdx.x, make_coord(_,_)); // (2,3)

这次操作完全反过来了:

所以 thr_a 的 shape 是 (2,3),因为它遍历的是所有 tile。


Thread-Value partition:通过线程选值

普通 partition 的思路是:

  1. 先按几何形状切 tile
  2. 再让线程从 tile 里拿自己的那部分

但有时候线程拿数据的方式根本不是简单矩形。 比如 MMA 指令规定:

这个分配模式常常非常“奇怪”,不是简单地“thread 0 拿第一行”。

这时 CuTe 用的是:

先构造一个 TV-layout:(thread, value) -> 数据坐标

image-20260311143653021

例如:

auto tv_layout = Layout<...>{};   // (T,V) -> (M,N)
Tensor A = make_tensor<float>(Shape<_4,_8>{}, LayoutRight{});
Tensor tv = composition(A, tv_layout);   // (8,4)
Tensor v  = tv(threadIdx.x, _);          // 每个线程拿 4 个值

第一步:你先定义一个“线程-值分配表”

tv_layout 规定:

tv_layout.png

第二步:以这张表为输入,来取相应值

tv = composition(A, tv_layout)

使用A∘B ,把原来按 (m,n) 看的数据,变成按 (thread,value) 来看。

第三步:每个线程切出自己那一行

v = tv(threadIdx.x, _)

于是这个线程直接拿到“我该处理的那几个值”。

所以 TV partition 的本质是:

先显式描述“线程和数据的对应关系”,再让 Tensor 自动给你每个线程的局部视图。

TV Layout具体例子

对于TV Layout,对于了解其作用很有必要,因此,单开一节,用例子讲透:

SM80_8x8x4_F64F64F64F64_TN为例:

对于SM80_8x8x4_F64F64F64F64_TN的布局要求,画出相应的布局要求如下(更多的布局信息详见[[MMA不同架构数据排布]]): ![[SM80_8x8x4_F64F64F64F64_TN.png]] 其在Cute中对应的源码定义如下(include/cute/atom/mma_traits_sm80.hpp):

// (T32,V1) -> (M8,N8)
using SM80_8x4 = Layout<Shape<Shape<_4, _8>, _1>, Stride<Stride<_8, _1>, _0>>;
// (T32,V2) -> (M8,N8)
using SM80_8x8_Row = Layout<Shape<Shape<_4, _8>, _2>, Stride<Stride<_16, _1>, _8>>;
template <>
struct MMA_Traits<SM80_8x8x4_F64F64F64F64_TN>
{
  using ValTypeD = double;
  using ValTypeA = double;
  using ValTypeB = double;
  using ValTypeC = double;

  using Shape_MNK = Shape<_8, _8, _4>;
  using ThrID     = Layout<_32>;
  using ALayout   = SM80_8x4;
  using BLayout   = SM80_8x4;
  using CLayout   = SM80_8x8_Row;
};

在写 TVLayout 之前,你必须先在大脑中建立一个坐标系转换的概念。TVLayout 本质上是一个降维映射函数。它的作用是:

底层逻辑前提(极度重要): 无论是 A、B 还是 C 矩阵,在 CuTe 的逻辑视角里,它们首先是一个多维张量(比如 C 是 $8 \times 8$),然后默认按照列主序(Column-Major)展平为一维空间。TVLayout 就是要算出:任意一个 $(T, V)$ 组合,在这个展平后的一维空间里,到底排在第几个位置?

第一阶段:正向映射(解析 SM80_8x4 与 SM80_8x8_Row)

就像之前说的,在 CuTe 中,所有的 TVLayout 都是一个将 $(T, V)$ 坐标映射为 一维内存偏移量 (1D Offset) 的函数。结合矩阵的主序(Majorness),这个 1D Offset 就能还原回逻辑坐标 $(Row, Col)$。 对于 Warp 级别的 MMA,线程总数固定为 32。所以线程坐标 $T$ 永远在 $[0, 31]$ 之间。 设 $T$ 的一维 ID 拆解后为 $(t_0, t_1)$,其代数关系固定为:

模式A 矩阵要求B 矩阵要求PTX 后缀
NTM-major (列主序)N-major (行主序).col.row
TNK-major (行主序)K-major (列主序).row.col
NNM-majorK-major.col.col
TTK-majorN-major.row.row
N向下连续,T向右连续
1. A / B 矩阵的 Layout (SM80_8x4)

代码定义:

// (T32,V1) -> (M8,N8)  <-- 注:此处的注释 N8 有误导性,实际用于 A 是 8x4,用于 B 是 4x8
using SM80_8x4 = Layout<Shape<Shape<_4, _8>, _1>, 
                        Stride<Stride<_8, _1>, _0>>;

输入:$(t_0, t_1)$ 以及 $V=0$(因为 _1 代表只有一个值)。 一维偏移公式:$Offset = t_0 \times 8 + t_1 \times 1 + V \times 0 = t_0 \times 8 + t_1$。 【应用到 Matrix A(左下角)】

【应用到 Matrix B(右上角)】

2. C 矩阵的 Layout (SM80_8x8_Row)

代码定义:

using SM80_8x8_Row = Layout<Shape<Shape<_4, _8>, _2>, 
                            Stride<Stride<_16, _1>, _8>>;

输入:$(t_0, t_1)$ 以及 $V \in [0, 1]$(因为是 _2,每个线程负责两个值)。 一维偏移公式:$Offset = t_0 \times 16 + t_1 \times 1 + V \times 8$。 为了方便寻找规律,我们提取公因数 8:$Offset = (2 \times t_0 + V) \times 8 + t_1$。 【应用到 Matrix C(右下角)】


第二阶段:如何看着图片手写 TV Layout

假设你现在只有右上角的 B 矩阵图片,你要如何从零写出 SM80_8x4Step 1: 确定矩阵维度和期望的连续性 (主序) 看右上角的 B 矩阵图:它有 4 行,8 列($4 \times 8$)。 结合我们上面讲的逻辑,虽然它承担着 $K$ 维连续的重任,但从这张 $4 \times 8$ 的图片表象来看,数据的一维 Offset 增长方向是横向的。我们据此得出,用于逆推的目标一维索引公式等效于 行主序 (Row-Major):$Offset = Col + Row \times 8$。 Step 2: 从图片中提取 $(Row, Col)$ 与 $T$ 的代数关系 我们观察图片中的数字规律:

Step 3: 将关系代入目标一维索引公式 将 $Row = t_0$ 和 $Col = t_1$ 代入 Step 1 的公式:$Offset = t_1 + t_0 \times 8$ 为了对应 CuTe 要求的输入格式,我们按 $(t_0, t_1)$ 的顺序重写等式:$Offset = t_0 \times 8 + t_1 \times 1$ Step 4: 提取 Stride,写出代码 在 CuTe 的标准模板 Shape<Shape<_4, _8>, _V>

组合起来得到代码: Layout<Shape<Shape<_4, _8>, _1>, Stride<Stride<_8, _1>, _0>> 这与官方源码分毫不差。

第三阶段:实例验证

为了确保理论 100% 成立,我们用代码推导的坐标去图片里“查字典”,如果严丝合缝,逻辑就无懈可击。 验证 1:Matrix A 中的 T23

验证 3:Matrix C 中的 T5 (包含寄存器 V 切片)


CuTe partition 机制的底层架构与执行流解析

在 CuTe 的编程模型中,partition 并非单纯的数据结构拆分,而是一次基于代数复合与降维的编译期坐标变换。其核心逻辑可以拆解为以下四个递进的阶段:

1. 逻辑映射层的确立 (TVLayout / TVCLayout)

首先,我们需要定义线程与逻辑数据之间的映射关系。

2. 物理排布与代数复合 (Layout Composition)

底层内存(如 Shared Memory 或 Global Memory)本身拥有其基础的物理排布规则(即输入的底层 Layout),这里面包含了复杂的错位(Swizzle)防冲突机制或跨步(Stride)规则。这些规则都是以 $(M, N, K)$ 逻辑坐标作为输入的。

由于 TVLayout 的输出正好是逻辑坐标,因此可以直接将其作为算子,叠加(Compose)在底层物理 Layout 之上。这就形成了一个高阶的复合映射函数:直接从 $(T, V)$ 穿透到真实的物理内存地址。

3. 线程切片与静态降维 (Partition / Slicing)

这是发生质变的一步。当我们调用 partition 时,实际上是将当前执行的线程 ID(如 threadIdx.x)作为常量,硬编码代入到上述的复合函数中。

这一步在数学上称为切片(Slicing)或取陪集(Coset),其工程意义在于彻底抹除了空间中的“线程维度”。原本需要二维甚至三维坐标系来描述的寻址逻辑,在此刻坍缩为一个只与数据量大小(Size/Value)相关的极简一维函数。

4. 一维顺序执行 (1D Sequential Access)

无论最终是读取(Load)还是存储(Store),经过 partition 吐出的结果(如 Tensor Fragment),在当前线程的视角下,本质上退化成了一个“包含所有目标物理地址的 1D List”。

此时,复杂的线程并行概念已经完全消解。线程只需要按照这个“List”的 Size,通过一维索引 (i) 进行最简单的顺序迭代计算。底层编译器会利用提前算好的代数偏移量,直接生成极其高效的内存访问指令,实现榨干 Tensor Core 的极致性能。

总结:对于 partition 的理解,可以这样梳理:首先定义 TVLayout,即描述线程 T 与寄存器 V 的 ID 映射关系,通过这个 layout 可以确定当前线程需要从以 MNK 逻辑排布的内存中取哪个地址的数据;与此同时,TVCLayout 则进一步描述当前线程 ID 所计算出的结果应存放在哪个位置。值得注意的是,swizzle 等操作本身以 MNK 逻辑排布的内存作为输入,完成从逻辑地址到物理地址的映射,因此 TVLayout 可以直接叠加在其之上。partition 的本质,就是将传入的底层 layout(如经过 swizzle 变换后的物理内存布局)作为基础,在其上叠加一层 TVLayout,再根据当前线程的 ID 进行切片,最终得到该线程所需访问的全部数据。无论是读还是写,partition 后得到的结果都相当于一个地址列表——其中记录了该线程需要读取或存放的所有物理地址。通过 () 运算符按序访问时,线程的概念已经完全消解,剩下的仅是按 size 顺序排列的一组物理地址,线程只需依次读写即可。


Tensor algorithms

copy

表面上,copy(src, dst) 就是把源张量拷到目标张量。 但在 CuTe 里,这件事远不只是“挨个赋值”:

所以 copy 的真正目标不是“拷贝字节”,而是:

在逻辑上把一个 Tensor 的元素,按同一组逻辑坐标,搬到另一个 Tensor 中;至于底层怎么搬,交给 CuTe 根据类型决定。

其有两个主要重载:

copy(src, dst);
copy(copy_atom, src, dst);

第一个版本让 CuTe 自己根据 src/dst 的类型选默认实现。 第二个版本允许你显式指定 Copy_Atom,也就是告诉 CuTe:这次别自己猜了,我指定你用哪种拷贝原子。


copy 不是“当前线程自己做完就结束”

copy 的并行性和同步语义,取决于参数类型。它有可能:

所以你绝不能想当然地认为:

“我一调用完 copy,后面立刻读 dst 就一定安全。”

不一定。 如果这次 copy 是多线程协作的,那你往往需要 __syncthreads(); 如果这次 copy 底层走的是 cp.async,那还需要额外的 async fence / wait。

copy_if

假设你把一个 41×55 的矩阵切成 4×8 tile。 那边界 tile 一定有“残缺块”——并不是每个 tile 都满。 如果你还像整齐 tile 那样无脑去 load/store,就会越界。

CuTe 对这个问题的标准处理不是去构造“10 个满 tile + 1 个残 tile”这种复杂逻辑,而是:

仍然把它当成规则 tile 网格,只不过给每个元素附一个 predicate,越界元素不做操作。

这就是 copy_if 的意义。它和 copy 一样拿 src/dst,但会额外拿一个和它们同 shape 的 predicate Tensor。predicate 非零的位置才执行拷贝。

template <class PrdTensor, class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE void copy_if(PrdTensor const& pred, Tensor<SrcEngine, SrcLayout> const& src, Tensor<DstEngine, DstLayout>& dst)
{
  using SrcType = typename SrcEngine::value_type;
  using DstType = typename DstEngine::value_type;

  CUTE_UNROLL
  for (int i = 0; i < size(dst); ++i)
  {
    if (pred(i))
    {
      dst(i) = static_cast<DstType>(static_cast<SrcType>(src(i)));
    }
  }
}

gemm

CuTe 把“乘法+归约”这件事统一成一组 按 mode 解释的张量算法。 于是 gemm 不只覆盖普通矩阵乘,还统一覆盖:


V / M / N / K

cute 规定:

CuTe 统一采用

也就是说,B 不是“按线性代数课本的二维写法”来表述,而是始终把 归约维 K 放在右边

这样做的好处是:

五种 gemm 语义

向量乘加:(V) x (V) => (V)

公式是: $$ C_v += A_v B_v $$


外积:(M) x (N) => (M,N)

$$ C_{mn} += A_m B_n $$

一个长度为 M 的向量和一个长度为 N 的向量,直接张成一个 M×N 矩阵。


矩阵乘:(M,K) x (N,K) => (M,N)

$$ C_{mn} += A_{mk} B_{nk} $$

相当于对每个 K,调用一次外积,再把结果累加。


batched 外积:(V,M) x (V,N) => (V,M,N)

每个 v 自己对应一组 A/B(向量),产生自己的一个 (M,N) 结果。

你可以把 V 想成 batch 维。 于是:


batched 矩阵乘:(V,M,K) x (V,N,K) => (V,M,N)

对每个 batch 的 (M,K)(N,K) 做 GEMM。它会对每个 K 调用batched 外积操作


gemmcopy

gemmcopy 一样,会根据 Tensor 参数类型自动派发到合适实现;而且也支持额外传一个 MMA_Atom 来覆盖默认选择。

这背后的统一设计思想是:

高层算法接口保持不变,低层执行原子可替换。


axpby

axpby 的定义是: $$ y = \alpha x + \beta y $$

写成矩阵方式: $$ C = \alpha \cdot A + \beta \cdot C $$

fill

fill 把输出 Tensor 的每个元素都写成同一个标量值。

clear

clearfill(0) 的专门版本:把输出 Tensor 全部置零。

CuTe MMA

CuTe MMA Atoms

CuTe 通过一对结构体将每条 MMA 指令暴露给通用 CUDA C++ 代码:

Operation 结构体负责描述 PTX 指令的物理接口——它定义指令所需的参数,几乎不依赖任何上层软件抽象(无 Layout、Tensor 或自定义数值类型),仅关注指令本身的输入输出。不同的 Operation 结构体以其对应 MMA 指令的功能命名。

MMA_Traits 结构体以 Operation 类型为模板参数进行特化,提供该 Operation 的元信息,包括:逻辑计算类型、操作的逻辑形状,以及操作内部线程与数据的 Layout 描述。CuTe 为每种受支持的 Operation 提供了对应的 Traits 特化实现。

两者共同构成一个 “Atom”(原子),将线程与数据布局的复杂性与 PTX 指令的调用点彻底解耦。Traits 所暴露的信息只与单次 MMA 操作的语义相关,与其实际运行的硬件粒度无关。

CuTe 目前支持以下硬件级别的 MMA 原子:

硬件级别代表架构
单线程(FMA 指令)通用
Quad-pairVolta
单 WarpAmpere
WarpgroupHopper

这一设计的核心价值在于语义与实现的分离:上层代码只需面向统一的 Atom 接口编程,无需关心底层指令在不同架构上的调度差异。

Operation结构体

Operation结构体的命名

例如,下面Volta部分会引用include/cute/arch/mma_sm70.hpp中定义的SM70_8x8x4_F32F16F16F32_NT Operation结构体。

• “SM70”指Volta。
• “8x8x4”指M=8、N=8、K=4,即quadpair执行的MMA操作维度(见下文)。这反映在PTX中为.m8n8k4.
• “F32F16F16F32”指四个矩阵操作数A、B、C、D的元素类型。MMA计算D = C + A * B,所以从左到右读取类型:D是F32(float),A是F16(half),B是F16(half),C是F32(float)。这反映在PTX指令名中为.f32.f16.f16.f32
• “NT”表示PTX指令为A输入M-major(不转置,列主序)和B输入N-major(转置,行主序)。这反映在PTX指令名中为.col.row.

内容

Operation结构体有四个公开类型别名:DRegistersARegistersBRegistersCRegisters
例如,SM70_8x8x4_F32F16F16F32_NT结构体定义如下:

using DRegisters = float[8];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[2];
using CRegisters = float[8];

这显示了每个线程将向PTX指令传递的每个矩阵A、B、C、D的值的数量。对于这个Operation,每个线程为C和D各传递8个F32值(因此float[8]),为A和B各传递4个F16值(因此uint32_t[2];指令把两个16位F16值打包进每个32位uint32_t)。

fma静态成员设备函数
Operation结构体定义一个公开的static void fma函数。它用CUTE_HOST_DEVICE宏标记,添加__host__ __device__注解。不同的Operation根据PTX MMA指令定义不同参数数量的fma。实现用宏保护PTX指令的使用,如果宏未定义则引发assert。这保证即使PTX指令不可用,使用该Operation的Atom的测试和示例仍能编译。

Traits

内容
MMA_Traits特化定义以下公开类型别名。

ValTypeD:D矩阵的逻辑计算类型
ValTypeA:A矩阵的逻辑计算类型
ValTypeB:B矩阵的逻辑计算类型
ValTypeC:C矩阵的逻辑计算类型
Shape_MNK:MMA操作的逻辑MxNxK形状
ThrID:单个MMA操作内的逻辑线程映射(指定线程、quadpair、warp或warpgroup视图)
ALayout:(thread,value)对到MxK A矩阵坐标的映射
BLayout:(thread,value)对到NxK B矩阵坐标的映射
CLayout:(thread,value)对到MxN C矩阵坐标的映射

例子
SM70_8x8x4_F32F16F16F32_NT Operation的MMA_Traits特化位于头文件include/cute/atom/mma_traits_sm70.hpp。它看起来像这样:

template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NT>
{
  using ValTypeD = float;
  using ValTypeA = half_t;
  using ValTypeB = half_t;
  using ValTypeC = float;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID   = SM70_QuadPair;
  using ALayout = SM70_8x4_Col;
  using BLayout = SM70_8x4_Col;
  using CLayout = SM70_8x8_32b;
};

Volta

本节以及后续章节展示如何构造MMA原子。

Volta架构实现了一个HMMA指令,由8个线程组成的一个quadpair(QP)协作共享数据并执行8x8x4(fp32或fp16)的矩阵乘累加。(由于一个warp有32个线程,它会跨4个QP执行一个16x16x4的MMA。)

我们首先看看如何把HMMA指令的ISA线程和数据分区语义编码到Traits结构体中。HMMA NT指令的线程-数据布局如下:

HMMA.8x8x4.NT.png

类型
HMMA NT使用的类型为:

  using ValTypeD = float;
  using ValTypeA = half_t;
  using ValTypeB = half_t;
  using ValTypeC = float;

MMA_Traits的其余部分都以这些类型为单位描述。

形状
HMMA NT的形状为8x8x4:

  // MMA的逻辑形状
  using Shape_MNK = Shape <_8,_8,_4>;

线程ID
如果warp的32个线程逻辑索引为[0 … 31],则上图包含线程[0,1,2,3]U[16,17,18,19]。这些线程构成第0个quadpair。我们可以写一个线程映射,把MMA的8个逻辑线程id [0,1,2,3,4,5,6,7]映射到warp中quadpair的线程索引[0,1,2,3]U[16,17,18,19]。布局函数有4个元素步长为1,另外2个步长为16。于是我们写出表示quadpair的布局:

  // (逻辑线程id) -> (线程索引)
  using ThrID = Layout<Shape <_4, _2>,
                       Stride<_1,_16>>;

这个布局函数将MMA操作的逻辑线程id [0,8) 映射到warp中quadpair的线程索引 [0,4)U[16,20)。

累加器映射
让我们精确看看QP内8个线程如何映射到A、B和C矩阵。对于C和D矩阵,上图进一步拆解如下。左边是整个QP视图,右边是仅线程0拥有的值。

HMMA.8x8x4.quadpair.C.png

这个单指令级别视图的元信息正是我们想在CuTe中编码的内容。具体来说,上图中的QP级别视图对应SM70_F32F16F16F32的四个MMA traits。这些结构体包含元素类型、Shape_MNK以及我们上面构造的ThrID映射。现在来看CLayout——累加器的线程-数据布局。CLayout的任务是在(logical_thr_id, logical_val_id)与C矩阵的(m, n)坐标之间建立映射,以便后续构建更复杂的布局和操作(如16x16x4 WMMA)。

我们从上图开始构造CLayout。和任何CuTe布局一样,它是Shape与对应Stride的组合。先看形状。我们知道HMMA使用8个线程,每个线程拥有8个值。因此映射的形状在两个模式上大小必须为8:

  // (T8,V8) -> (m,n)
  using CLayout = Layout<Shape <_8, _8>,
                         Stride<_?, _?>;  // Stride待填充

不是C矩阵的逻辑8x8形状,而是8线程×8值的形状。现在我们要把它映射到(m,n)坐标。因为CuTe布局返回索引而非坐标,我们选择列主序编码(m,n)坐标:

(logical_thr_id, logical_val_id) -> (m, n) == m + n * M

现在开始构造CLayout的stride。先看线程间的stride。注意:

我们发现模式可以转录为布局。8个线程的位置为:

  using CLayout = Layout<Shape <Shape <_2,  _2, _2>, _8>,
                         Stride<Stride<_1, _16, _4>, _?>;

用完全相同的方法构造logical value id模式的stride:

  // (T8,V8) -> (m,n)
  using CLayout = Layout<Shape <Shape <_2, _2,_2>, Shape <_2,_2, _2>>,
                         Stride<Stride<_1,_16,_4>, Stride<_8,_2,_32>>>;

这样就完成了!我们可以验证每个(tid,vid)坐标都能可靠映射到正确的编码(m,n)坐标。

对于该CLayout

输入 (Input)

它的输入是一个二元组(Logical Thread ID, Logical Value ID)

  • Logical Thread ID (tid):逻辑线程编号(在这个例子中是0到7)。
  • Logical Value ID (vid):该线程持有的寄存器(数据)编号(在这个例子中是0到7)。

输出 (Output)

它的输出是一个线性索引 (Linear Index)

这个索引代表了该数据在 8x8 矩阵中的一维偏移量,通常映射到二维坐标 (m, n):

$$ Index = m + n \times 8 $$ 通过这个输出,程序就能知道:“第 i 个线程的第 j 个寄存器值,对应矩阵里第 m 行、第 n 列的元素。”

如果是F16累加器,布局要简单得多。每行累加器(m, :)由单个线程持有,布局为:

  using CLayout = Layout<Shape <_8,_8>,
                         Stride<_1,_8>>;

A和B布局映射
A和B矩阵布局取决于源是否转置。下图显示NT和TN转置情况下A和B矩阵的线程ID到数据所有权映射。

HMMA.8x8x4.quadpair.AB.png

先看TN布局的A矩阵(图右侧)。同样是8个逻辑线程,但每个线程只拥有4个元素。ALayout的形状为Shape<_8, _4>。stride方面,我们需要(m, k) == m + k * M的映射。沿M模式看,从(T0, V0)(T1, V0)步长为1(对所有8线程)。沿K模式,从(T0, V0)(T0, V1)步长为8(对4个值)。因此A布局为:

  // (T8,V4) -> (m,k)
  using ALayout = Layout<Shape <_8,_4>,
                         Stride<_1,_8>>;

B布局类似,只是我们为了方便写成(N,K)而非(K,N)。最终B布局与A相同:

  // (T8,V4) -> (n,k)
  using BLayout = Layout<Shape <_8,_4>,
                         Stride<_1,_8>>;

NT情况的布局更复杂(图左侧)。沿A的M模式,先看到T0的4个值,再看到T4的4个值。因此M模式有两段子stride。对于K模式,简单递增thr_id,val_id不变,步长为8(对4线程)。于是A布局为:

  // (T8,V4) -> (m,k)
  using ALayout = Layout<Shape <Shape <_4,_2>,_4>,
                         Stride<Stride<_8,_4>,_1>>;

B布局相同。NN和TT转置则是我们已见两种布局的简单组合。

Hopper

现在我们可以来看Hopper架构首次引入的更大GMMA操作(Group MMA)。这些MMA指令以128线程(4个warp)为粒度运行,统称为warpgroup。

线程ID
Hopper GMMA中线程ID基于简单1D连续布局,因此ThrID非常简单:

using ThrID = Layout<_128, _1>;

累加器映射
GMMA中累加器分层映射,从核心矩阵开始逐步构建整个C矩阵tile。我们先看核心矩阵(这里只考虑fp16累加器,fp32扩展后面会看到,非常简单)。

每个核心矩阵布局如下图所示:

gmma_coremat_cd_fp16.png

和Volta例子一样,线程ID只是逻辑的,四个warp的归属不重要。

然后GMMA先沿M模式垂直平铺核心矩阵,再沿N模式重复该列核心矩阵,构成完整的MxN tile。平铺如下图:

gmma_wg_n_slice.png

有了这张图,我们就可以开始为SM90_64x128x16_F16F16F16F16_TN原子构建CLayout。同样,我们在(logical_thr_id, logical_val_id) -> (m, n)坐标空间建立映射。

先跟随前几个线程和值。我们立刻看到它们沿N模式以值对和4线程排列,于是:

// (T128,V4) -> (M64,N8)
using CLayout = Layout<Shape <Shape <  _4, ...>, Shape < _2, ...>>,
                       Stride<Stride<_128, ...>, Stride<_64, ...>>>;

完成第一个8x8核心矩阵:4个线程沿M模式重复8次:

// (T128,V4) -> (M64,N8)
using CLayout = Layout<Shape <Shape <  _4, _8, ...>, Shape < _2, ...>>,
                       Stride<Stride<_128, _1, ...>, Stride<_64, ...>>>;

下一个核心矩阵时回到T0,但这次是(T0, V2)

// (T128,V4) -> (M64,N8)
using CLayout = Layout<Shape <Shape <  _4, _8, ...>, Shape < _2, _2>>,
                       Stride<Stride<_128, _1, ...>, Stride<_64, _8>>>;

最后,整个模式在M模式重复4次(每个warp一次),从(m,n)=(16,0)=16开始。这使得thrID最后子模式的尺寸为4(4个warp),步长为16。于是最终64x8累加器的CLayout为:

// (T128,V4) -> (M64,N8)
using CLayout = Layout<Shape <Shape <  _4, _8,  _4>, Shape < _2, _2>>,
                       Stride<Stride<_128, _1, _16>, Stride<_64, _8>>>;//相当于1个图中那样的 8 列宽的块

GMMA指令包含64xN变体(N=[16,32,64,128,256]),该64x8模式重复,给每个线程更多值。从(m,n)=(0,8)=512开始,非常容易在CLayout中处理。例如64x128的CLayout为:

// (T128,V64) -> (M64,N128)
using CLayout = Layout<Shape <Shape <  _4, _8,  _4>, Shape < _2, _2,  _16>>,
                       Stride<Stride<_128, _1, _16>, Stride<_64, _8, _512>>>;//相当于16个图中那样的 8 列宽的块

A和B布局映射
直接从共享内存消费A和B的GMMA原子比较特别。GMMA Descriptor是在整个共享内存tile上构造的,而不是按线程分区。也就是说每个线程都能看到整个tile,tile不会被重排序以便构造Descriptor。在ALayout形式中可以表示为:

// (T128,V64x16) -> (M64,K16)
using ALayout = Layout<Shape <_128, Shape <_64,_16>>,
                       Stride<  _0, Stride< _1,_64>>>;//输入为(m, k),因此第二个shape需要二维

即所有线程都映射到(m,k)=(0,0)=0元素,值(和值的形状)保持不变。GMMA Descriptor构造器随后可以检查该数据的(M,K)布局,创建合适的Descriptor或报错。

TiledMMAs

TiledMMA 的出发点

在 CuTe 中,单个 MMA_Atom(比如 Volta 的 SM70_8x8x4_F32F16F16F32_NT)只能处理非常小的矩阵块(8×8×4)。但实际 GEMM 需要更大的 tile(比如 16×16×4、32×32×4),而且希望线程对数据的访问更“连续”(比如共享内存或寄存器中连续存放),这样才能获得更好的加载性能和更少的 bank conflict。

TiledMMA 的核心目的就是:

  1. 把多个 MMA_Atom 组合(replicate)成更大的计算单元;
  2. 通过 Layout of Atoms 决定这些原子在线程间的排列方式;
  3. 通过 Tile 参数 对 M/N/K 三个模式分别进行复制置换(permutation),让线程看到的数据布局更友好。

这就是 make_tiled_mma 函数的全部意义:把小原子“拼”成大图案,同时还能“重排”坐标


原理讲解

TiledMMA mma = make_tiled_mma(
    BaseAtom{},                    // 1. 基础原子(Operation + Traits)
    Layout_of_Atoms{},             // 2. 原子在线程间的排列(决定线程复制)
    Tile_MNK{}                     // 3. 对 M/N/K 模式分别做“复制或置换”
);
TiledMMA mma = make_tiled_mma(
    SM70_8x8x4_F32F16F16F32_NT{},
    Layout<Shape<_2,_2>, Stride<_2,_1>>{},           // 2×2 n-major 原子排列
    Tile<Layout<Shape<_4,_4,_2>, Stride<_1,_8,_4>>,  // ← 对 M 模式的置换
          _32,                                       // 对 N 模式的身份(不改变)
          _4>{}                                      // 对 K 模式的身份
);

这个 Layout<Shape<_4,_4,_2>, Stride<_1,_8,_4>> 到底在做什么?

  1. 先看形状Shape<_4,_4,_2> 说明最终 M 模式的大小是 4×4×2 = 32(和我们想要的 32×32×4 tile 匹配)。

  2. 再看步长(Stride)Stride<_1,_8,_4>置换规则

    • 第一个 _4 用 stride 1;
    • 第二个 _4 用 stride 8;
    • 第三个 _2 用 stride 4。
  3. 实际映射表(文档原表)

旧 m 坐标(0~31):  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
新 m 坐标(重排后): 0  1  2  3  8  9 10 11 16 17 18 19 24 25 26 27  4  5  6  7 12 13 14 15 20 21 22 23 28 29 30 31

直观效果

这就是置换的威力!它让后续共享内存布局或寄存器加载变得连续,极大提升性能。


举几个例子:

SM70_8x8x4_F32F16F16F32_NT为例:

MMA_Atom mma = MMA_Atom<SM70_8x8x4_F32F16F16F32_NT>{};
print_latex(mma);

HMMA.8x8x4.NT_Atom.png

上面等价于:

TiledMMA mma = make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{},
                              Layout<Shape<_1,_1,_1>>{},   // Atom布局
                              Tile<_8,_8,_4>{});           // Tiler

我们可以用4个这样的quadpair MMA创建一个类似WMMA的对象:

TiledMMA mma = make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{},
                              Layout<Shape <_2,_2>, Stride<_2,_1>>{});   // 2x2 n-major Atom布局

HMMA.8x8x4.NT_2x2.png

这个TiledMMAMMA_Atom跨线程复制(可以看到之前没用到的T4、T8、T12线程)。每个C矩阵象限都是新quadpair的原子分区模式的副本,复制遵循(2,2):(2,1)布局。

上面现在表示16x16x4 MMA,我们可以立即把“tile size”扩展到32x32x4:

TiledMMA mma = make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{},
                              Layout<Shape <_2,_2>, Stride<_2,_1>>{}, 
                              Tile<_32,_32,_4>{});

HMMA.8x8x4.NT_2x2_32x32x4.png

这个TiledMMA把前一个TiledMMA跨值(而非线程)复制。

继续,我们看到T0从A矩阵收到8个值,坐标分散。我们希望把它们排在一起,于是对M模式做置换:

TiledMMA mma = make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{},
                              Layout<Shape <_2,_2>, Stride<_2,_1>>{},
                              Tile<Layout<Shape <_4,_4,_2>, Stride<_1,_8,_4>>, _32, _4>{});

HMMA.8x8x4.NT_2x2_32Mx32x4.png

这个置换只影响M模式(A和C相应变化),让所有线程在A矩阵的m坐标上连续访问,非常适合共享内存或寄存器布局设计。

sgemm

矩阵乘法的两阶段执行:数据搬运与计算

矩阵乘法的执行分为**数据搬运(Load)矩阵计算(Compute)**两个阶段。值得注意的是,这两个阶段对线程的组织方式完全不同——同一批线程在不同阶段会以不同的布局方式参与工作。


问题规模与 CTA 划分

以一个 256×256 的输出矩阵 C 为例。我们规定每个 CTA(Block)负责 C 的 128×128 输出块的计算与写回,同时负责把对应的 A 128×32 和 B **128×32 **搬运到共享内存。因此,整个矩阵需要 (2, 2) = 4 个 CTA 协同完成全部计算。

对于每个 CTA,其所负责的 128×128 输出块决定了它需要从全局内存中读取的数据范围:


第一阶段:数据搬运(Load)

由于共享内存(Shared Memory)空间有限,无法一次性容纳完整的 128×32 数据块,因此将其沿 K 维度拆分为 4 个 128×8 的小块,每个小块称为一个 Tile。CTA 循环搬运 4 次,每次处理一个 Tile,共完成整段数据的搬运。

在此阶段,每个 Block 内开启 256 个线程,采用 32×8 的线程布局。对于每次单个 Tile(128×8)的搬运任务,线程与数据的对应关系为:每个线程负责搬运 (4, 1) 大小的数据片段。32×8 的线程网格覆盖整个 128×8 的 Tile,从而实现并行搬运。B 矩阵的搬运方式与此完全对称。


第二阶段:矩阵计算(Compute)

数据搬运完成后,进入计算阶段。此时线程布局切换为 16×16,与搬运阶段的布局完全不同。

在此布局下,CTA 负责的 128×128 输出块被均匀分配给 256 个线程,每个线程负责计算 8×8 大小的输出子块。

由于共享内存每次只能容纳一个 Tile 的数据,计算同样需要循环 4 次,与 Load 阶段的循环一一对应:

每次迭代 = Load 一个 Tile → 基于当前 Tile 计算 8×8 的部分结果 → 累加到线程私有的寄存器变量中

后续每次 Load 都会用新数据覆盖共享内存中上一轮的内容,但累加器中的中间结果已被安全保存在寄存器里。4 次循环结束后,每个线程的累加器中存储的便是其所负责的 8×8 输出块的最终结果(共 64 个值)。

最终,每个线程将自己独立计算得到的 8×8 结果写回全局内存,完成整个计算流程。

69db998aa8c197520b0b975b34c1d83a

local_tile

  // Represent the full tensors
  Tensor mA = make_tensor(make_gmem_ptr(A), select<0, 2>(shape_MNK), dA); // (M,K)
  Tensor mB = make_tensor(make_gmem_ptr(B), select<1, 2>(shape_MNK), dB); // (N,K)
  Tensor mC = make_tensor(make_gmem_ptr(C), select<0, 1>(shape_MNK), dC); // (M,N)

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
  Tensor gA      = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
  Tensor gB      = local_tile(mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
  Tensor gC      = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)

其中local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{});等价于:

Tensor gA = local_tile(mA, select<0,2>(cta_tiler), select<0,2>(cta_coord));//gA(128,8,512)
//                    (M,K)    (BLK_M, BLK_K)              (m, k)
// 含义:将 mA 按照 (BLK_M, BLK_K) 进行 zipped_divide,
// 再用 select<0,2>(cta_coord) = (m, _) 在rest-mode选取指定tile:
// 其中 m 指定当前 CTA 在 M 维度上的分块索引(固定),
// _ 表示 K 维度全选(保留所有分块),
// 最终结果的 shape 为 (BLK_M, BLK_K, k),
// 即当前 CTA 负责的 A 子块在 M 方向已定位,K 方向保留完整分块序列。

再具体点就是:

Tensor gA_mk = zipped_divide(mA, select<0,2>(cta_tiler));//((BLK_M, BLK_K), (m, k))
Tensor gA    = gA_mk(make_coord(_,_), select<0,2>(cta_coord));

local_partition

auto tA = make_layout(make_shape(Int<32>{}, Int<8>{})); // (m,k) -> thr_idx
Tensor tAgA = local_partition(gA, tA, threadIdx.x);//gA(BLK_M, BLK_K, k) tAgA(THR_M,THR_K,k)
// Partition sA (BLK_M, BLK_K) by the rows of tC
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K)
Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K)

其与local_tile的区别是:

具体过程

local_tile 解释

local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}) 的语义是:

对张量 mA(整个 A 矩阵)按照 select<0,2>(cta_tiler) 所指定的 tile 尺寸进行 zipped_divide,将矩阵在对应维度上切分为若干 tile 块;随后在2-mode索引维度上,根据 select<0,2>(cta_coord) 取出对应位置的子张量,从而返回当前 CTA 所负责的所有 tile。

以具体形状为例:zipped_divide 后张量变为 ((BLK_M, BLK_K), (m, k)),其中:

Step<_1, X, _1> 控制哪些维度参与切分(X 表示该维度不参与),最终通过 cta_coord 索引选出当前 CTA 对应的所有 tile。本质上,local_tile全局矩阵到 tile 粒度的划分,输入是整张矩阵。


local_partition 解释

local_partition(sA, tC, threadIdx.x, Step<_1, X>{}) 的语义是:

对张量 sA(一个 tile 大小的 shared memory 片段)按照 select<0>(tC) 所指定的线程布局尺寸进行 zipped_divide,将 tile 在对应维度上切分为若干线程负责的小块;随后将 threadIdx.x 映射到 tC 的多维线程坐标,并取其第 0 维的索引,从1-mode索引维度中选出当前线程负责的那部分数据。

以具体形状为例:zipped_divide 后张量变为 ((16), (THR_M, BLK_K)),其中:

Step<_1, X> 同样控制参与切分的维度,最终通过线程坐标索引取出当前线程在 tile 内应读取的数据。

local_partition(gA, tA, threadIdx.x) 的语义是:

对张量 gA(全局内存中当前 CTA 负责的所有 tile 窗口)按照 tA 所指定的线程布局尺寸进行 zipped_divide,将 tile 在对应维度上切分为若干线程负责的小块;随后将 threadIdx.x 映射到 tA 的多维线程坐标,从1-mode索引维度中选出当前线程负责的那部分数据。

以具体形状为例:zipped_divide 后张量变为 ((32, 8), (THR_M,THR_K,k)),其中:

本质上,local_partitiontile 内部到线程粒度的划分,输入是单个 tile,输出是每个线程独占的数据视图。


两者对比

local_tilelocal_partition
输入整张矩阵单个 tile(shared memory)
划分粒度Tile 级别(CTA 粒度)线程级别(Thread 粒度)
索引依据cta_coord(CTA 坐标)threadIdx.x(线程 ID)
输出当前 CTA 负责的 tile当前线程负责的数据片段

Cute支持将一个线程以多种布局方式来看待,从而能实现搬运和计算两种布局,来更方便的使用分配线程,其中由于线程是单维度定义,所以threadIdx.x就可以索引到所有线程

sgemm_1.cu最简实现

#include <cute/tensor.hpp>
using namespace cute;

__global__ void gemm_kernel(
  int M, int N, int K, float const* A, int ldA, float const* B, int ldB, float* C, int ldC, float alpha, float beta)
{
  // 1. 全尺寸 Tensor
  Tensor mA = make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(1, ldA)); // (M,K) m-major
  Tensor mB = make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(1, ldB)); // (N,K) n-major
  Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(1, ldC)); // (M,N)

  // 2. CTA 分块
  auto cta_tiler = make_shape(Int<128>{}, Int<128>{}, Int<8>{}); // (BLK_M, BLK_N, BLK_K)
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) 只取当前 CTA 的块
  Tensor gA      = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (128,8,k)
  Tensor gB      = local_tile(mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (128,8,k)
  Tensor gC      = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (128,128)

  // 3. SMEM 分配 + Tensor
  __shared__ float smemA[128 * 8];
  __shared__ float smemB[128 * 8];
  Tensor sA = make_tensor(make_smem_ptr(smemA), make_layout(make_shape(128, 8))); // m-major
  Tensor sB = make_tensor(make_smem_ptr(smemB), make_layout(make_shape(128, 8))); // n-major

  // 4. 线程分区
  auto tA = make_layout(make_shape(Int<32>{}, Int<8>{})); // 32×8 线程
  auto tB = make_layout(make_shape(Int<32>{}, Int<8>{}));
  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // 16×16 线程

  Tensor tAgA = local_partition(gA, tA, threadIdx.x); // 每个线程拿 (4,1,k),只有带g的有整体tile"视角",即能取到所有批次的tile
  Tensor tAsA = local_partition(sA, tA, threadIdx.x);
  Tensor tBgB = local_partition(gB, tB, threadIdx.x);
  Tensor tBsB = local_partition(sB, tB, threadIdx.x);

  // 5. 计算分区
  Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (16,8)
  Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step<X, _1>{}); // (16,8)
  Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1, _1>{}); // (16,16)
  Tensor tCrC = make_tensor_like(tCgC); // 寄存器累加器

  clear(tCrC); // 清零

  // 6. Mainloop
  for (int k_tile = 0; k_tile < size<2>(tAgA); ++k_tile)
  {
    //相同布局视角才能互相copy,如tA才能相互copy,而不同的内存位置可以相互copy,只要布局视角相同就行,如gA和sA
    copy(tAgA(_, _, k_tile), tAsA); // gmem → smem
    copy(tBgB(_, _, k_tile), tBsB);
    cp_async_fence();
    cp_async_wait<0>();
    __syncthreads();
    //同时计算时也一定是同样布局视角tC
    gemm(tCsA, tCsB, tCrC); // SMEM GEMM
    __syncthreads();
  }

  // 7. Epilogue(axpby)
  axpby(alpha, tCrC, beta, tCgC);
}

// 主机端启动
void gemm_nt(
  int m, int n, int k, float alpha, float const* A, int ldA, float const* B, int ldB, float beta, float* C, int ldC)
{
  auto cta_tiler = make_shape(Int<128>{}, Int<128>{}, Int<8>{});
  dim3 grid(ceil_div(m, 128), ceil_div(n, 128));
  dim3 block(256); // 16×16=256 线程
  gemm_kernel<<<grid, block>>>(m, n, k, A, ldA, B, ldB, C, ldC, alpha, beta);
}

make_tiled_copymake_tiled_mma

sgemm_1 里,local_partition 已经能做两件事:

  1. tile 按线程划分,用于搬运
  2. tile 按线程划分,用于计算

也就是说,sgemm_1 的思路是:

先给一个线程布局 tA / tB / tC,再用 local_partition 把 tensor 切给线程。

这个方法足够简单,但是它有两个问题:

第一个问题:它只表达“线程怎么分”,不表达“指令怎么吃”

例如:

local_partition 本身并不关心这个,它只负责:

“把数据按线程布局切开”

它不负责:

“这个切法是否正好适配某条 copy 指令 / mma 指令”


第二个问题:搬运布局和计算布局,本来就不是一回事

例如同一个 CTA:

所以 sgemm_2 引入了两个新东西:

它们的本质都是:

不再只是“按线程切 tensor”,而是“按指令的吃法 + 线程布局”一起切 tensor。


make_tiled_copy 解释

先看这句:

TiledCopy copyA = make_tiled_copy(
  Copy_Atom<UniversalCopy<uint128_t>, TA>{},
  Layout<Shape<_32,_8>>{},
  Layout<Shape<_4,_1>>{}
);

这句不是在切 tensor。 它只是先构造出一个“搬运规则对象” copyA

所以你要先记住:

make_tiled_copy(...) 的输出不是某个子张量,而是一个“如何按 copy 指令去分块”的规则。


语义

make_tiled_copy(copy_atom, thr_layout, val_layout) 的语义是:

先定义一个最小搬运原子 copy_atom,再定义这个原子如何在线程维度值维度上重复铺开,从而形成一个 TiledCopy 分区器。这个分区器之后可以作用在 source tensor 和 destination tensor 上,分别得到当前线程该读什么、该写什么。


第一步:定义 copy atom

Copy_Atom<UniversalCopy<uint128_t>, TA>{}

它表示:

一条底层 copy 指令,按 uint128_t 这个粒度去搬数据。

如果 TA=float,那:

这一步的目的,是先规定:

单条搬运指令一次吃几个元素

这是 sgemm_1tA/tB 完全没有表达的东西。


第二步:定义线程布局

Layout<Shape<_32,_8>>{}

它表示:

在这套 copy 规则里,线程被看成一个 32 x 8 的二维布局。

也就是 256 个线程。

这一步的目的,是规定:

这些 copy atom 由哪些线程来执行,以及线程在 tile 中怎么排

它对应的,其实就是 sgemm_1 里的 tAtB,但这里只是“线程排布”这一层。


第三步:定义每个线程携带的值布局

Layout<Shape<_4,_1>>{}

它表示:

每个线程,不是只搬 1 个值,而是搬一个 4x1 的小块。

因为前面 copy atom 已经规定了一条指令一次搬 4 个 float,所以这里的 4x1 正好和它对上。

这一步的目的,是规定:

每个线程在一次 copy 里,拿哪几个逻辑元素

于是前两步合起来,就从“线程怎么排”上升成了:

线程怎么排,并且每个线程一次按什么打包方式搬数据


第四步:给当前线程取出自己的 copy 角色

ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);

这句的语义是:

从整套 TiledCopy 规则里,取出当前线程 threadIdx.x 对应的那一份局部规则。

它和 local_partition(..., threadIdx.x) 的共同点是:

但区别是:

这一步的目的,是把“整个 CTA 的 copy 规则”缩小到:

当前线程该怎么搬


第五步:把规则作用到 source / destination tensor 上

Tensor tAgA = thr_copy_a.partition_S(gA);   // (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA);   // (CPY,CPY_M,CPY_K)
Tensor tArA = make_fragment_like(tAsA);     // (CPY,CPY_M,CPY_K)

这里才是真正把 tensor 切开。

partition_S 解释

partition_S(gA) 的语义是:

按照 TiledCopy 规定的 source 读法,把 gA 这个 source tensor 切给当前线程。

partition_D 解释

partition_D(sA) 的语义是:

按照 TiledCopy 规定的 destination 写法,把 sA 这个 destination tensor 切给当前线程。

这和 local_partition 最大的不同就在这:

因为有些硬件 copy 指令,读端和写端的模式未必相同。


具体形状

gA : (128, 8, k) 为例:

所以总覆盖范围是:

因此,对当前线程来说,partition_S(gA) 的结果可以具体看成:

tAgA ≈ (4, 1, 1, k)

其中:

对应地:

tAsA ≈ (4, 1, 1)
tArA ≈ (4, 1, 1)

也就是说:

每个线程每次从 gmem 读 4 个值,先放到寄存器,再写到 smem。


本质

所以 make_tiled_copy 的本质是:

先定义 copy 指令原子,再定义线程如何重复这个原子,最后分别对 source 和 destination 做分区。

它不是单纯“线程切块”,而是:

按 copy 指令的吃法来切块。


make_tiled_mma 解释

再看这句:

TiledMMA mmaC = make_tiled_mma(
  UniversalFMA<TC, TA, TB>{},
  Layout<Shape<_16,_16,_1>>{}
);

同样,这句也不是在切 tensor。 它只是先构造出一个“计算规则对象” mmaC

所以你也要先记住:

make_tiled_mma(...) 的输出不是某个子张量,而是一个“如何按 mma 指令去分块”的规则。


语义

make_tiled_mma(mma_atom, atom_layout) 的语义是:

先定义一个最小计算原子 mma_atom,再定义这个原子如何在 CTA 的 M/N/K 逻辑维度上重复铺开,从而形成一个 TiledMMA 分区器。这个分区器之后可以作用在 A、B、C tensor 上,得到当前线程在计算阶段该读取哪些 A/B 数据,以及负责哪些 C 累加器。


第一步:定义 mma atom

UniversalFMA<TC, TA, TB>{}

它表示:

最小计算原子是一条普通的 FMAc += a * b

所以它一次只消费:

也就是一个 1x1x1 的 mma atom

这一步的目的,是先规定:

单条计算指令一次吃几个 A/B/C 元素

这和前面的 copy atom 完全对应,只不过一个是“搬”,一个是“算”。


第二步:定义 atom 的铺排方式

Layout<Shape<_16,_16,_1>>{}

它表示:

把这个 1x1x1 的计算原子,在逻辑上按 16 x 16 x 1 的方式重复排列。

因为 16 * 16 * 1 = 256,所以正好对应一个 block 里的 256 个线程。

这一步的目的,是规定:

哪些线程执行哪些 mma atom,以及这些 atom 在 M/N/K 上怎么铺开

它对应的,其实就是 sgemm_1 里的 tC,但这里不是简单线程布局,而是:

线程 + mma atom 的联合布局


第三步:给当前线程取出自己的 mma 角色

ThrMMA thr_mma = mma.get_slice(threadIdx.x);

它的语义是:

从整套 TiledMMA 规则里,取出当前线程 threadIdx.x 对应的那一份局部 mma 规则。

它和前面的 ThrCopy 完全平行:


第四步:把规则作用到 A / B / C tensor 上

Tensor tCsA = thr_mma.partition_A(sA);   // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(sB);   // (MMA,MMA_N,MMA_K)
Tensor tCgC = thr_mma.partition_C(gC);   // (MMA,MMA_M,MMA_N)
Tensor tCrC = thr_mma.make_fragment_C(tCgC);

这里才是真正把 tensor 切开。

partition_A 解释

partition_A(sA) 的语义是:

按照 TiledMMA 规定的 A 操作数读法,把 sA 切给当前线程。

partition_B 解释

partition_B(sB) 的语义是:

按照 TiledMMA 规定的 B 操作数读法,把 sB 切给当前线程。

partition_C 解释

partition_C(gC) 的语义是:

按照 TiledMMA 规定的 C 输出/累加器布局,把 gC 切给当前线程。

这里和 TiledCopy 的区别很明显:


具体形状

以 CTA tile 为:

TiledMMA 的 atom 布局是 (16,16,1)

那么逻辑上:

又因为 atom 本身是 1x1x1,所以第一维 MMA = 1

于是对当前线程来说,可以把结果具体看成:

tCsA ≈ (1, 8, 8)
tCsB ≈ (1, 8, 8)
tCgC ≈ (1, 8, 8)
tCrC ≈ (1, 8, 8)

也就是说:

这其实和 sgemm_1tC 分出来的 (8,8) 很像,只不过现在最前面多了一层:

“单条 mma 指令的一次消费粒度”


本质

所以 make_tiled_mma 的本质是:

先定义 mma 指令原子,再定义线程如何重复这个原子,最后分别对 A/B/C 做分区。

它不是单纯“线程切块”,而是:

按 mma 指令的吃法来切块。


两者对比

make_tiled_copymake_tiled_mma
解决的问题搬运阶段如何按 copy 指令组织线程和数据计算阶段如何按 mma 指令组织线程和数据
构造输入copy_atom + thr_layout + val_layoutmma_atom + atom_layout
得到的对象TiledCopy 规则对象TiledMMA 规则对象
中间步骤get_slice(threadIdx.x) 得到 ThrCopyget_slice(threadIdx.x) 得到 ThrMMA
作用对象partition_S / partition_Dpartition_A / partition_B / partition_C
第一维含义CPY:单条 copy 指令一次消费的元素组MMA:单条 mma 指令一次消费的元素组
sgemm_1 中替代谁替代 tA / tB + local_partition 那套搬运分区替代 tC + local_partition 那套计算分区
本质按搬运指令切分按计算指令切分

make_tiled_copy / make_tiled_mma 的本质

它们不是直接替代 local_partition 这个函数本身,而是把“分区规则”升级了。

也就是:

Tensor tAgA = local_partition(gA, tA, threadIdx.x);//(THR_M,THR_K,k)
Tensor tAsA = local_partition(sA, tA, threadIdx.x);//(THR_M,THR_K)

含义是:

用普通线程布局 tAgAsA


ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA);// (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA);// (CPY,CPY_M,CPY_K)

含义是:

用“copy 指令感知”的规则切 gAsA

差别就在于:

sgemm_2 多显式保留了一层 CPY,表示“单条 copy 指令的一次消费组”


Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});// (THR_M,BLK_K)
Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step<X, _1>{});// (THR_N,BLK_K)
Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1, _1>{});// (THR_M,THR_N)

含义是:

用普通线程布局 tC 切 A/B/C


ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA);// (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(sB);// (MMA,MMA_N,MMA_K)
Tensor tCgC = thr_mma.partition_C(gC);// (MMA,MMA_M,MMA_N)

含义是:

用“mma 指令感知”的规则切 A/B/C

差别就在于:

sgemm_2 多显式保留了一层 MMA,表示“单条 mma 指令的一次消费组”


sgemm_2.cu最简实现

#include <cute/tensor.hpp>
using namespace cute;

__global__ void gemm_kernel(
  int M, int N, int K,
  float const* A, int ldA,
  float const* B, int ldB,
  float* C, int ldC,
  float alpha, float beta)
{
  // 1. 全尺寸 Tensor(和 sgemm_1 一样)
  Tensor mA = make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(1, ldA)); // (M,K) m-major
  Tensor mB = make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(1, ldB)); // (N,K) n-major
  Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(1, ldC)); // (M,N)

  // 2. CTA 分块(和 sgemm_1 一样)
  auto cta_tiler = make_shape(Int<128>{}, Int<128>{}, Int<8>{}); // (BLK_M, BLK_N, BLK_K)
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);        // (m,n,k)

  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (128,8,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (128,8,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (128,128)

  // 3. SMEM 分配 + Tensor
  // 为了方便和 sgemm_1 对比,这里先故意不改成 padded layout
  __shared__ float smemA[128 * 8];
  __shared__ float smemB[128 * 8];

  Tensor sA = make_tensor(make_smem_ptr(smemA), make_layout(make_shape(Int<128>{}, Int<8>{}))); // (128,8)
  Tensor sB = make_tensor(make_smem_ptr(smemB), make_layout(make_shape(Int<128>{}, Int<8>{}))); // (128,8)

  // 4. 用 TiledCopy 取代 tA / tB
  //
  // 含义:
  // - 线程仍然是 32x8 = 256 个
  // - 每个线程搬 4x1 个 float
  // - copy atom 规定一次按 uint128_t = 128bit 去搬,也就是 4 个 float
  TiledCopy copyA = make_tiled_copy(
      Copy_Atom<UniversalCopy<uint128_t>, float>{},
      Layout<Shape<_32, _8>>{},   // 线程布局
      Layout<Shape<_4,  _1>>{});  // 每线程值布局

  TiledCopy copyB = make_tiled_copy(
      Copy_Atom<UniversalCopy<uint128_t>, float>{},
      Layout<Shape<_32, _8>>{},
      Layout<Shape<_4,  _1>>{});

  // 当前线程在 copyA / copyB 里的角色
  ThrCopy thr_copy_a = copyA.get_slice(threadIdx.x);
  ThrCopy thr_copy_b = copyB.get_slice(threadIdx.x);

  // A/B 的 source / destination 分区
  Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k) = 约 (4,1,1,k)
  Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K)   = 约 (4,1,1)
  Tensor tArA = make_fragment_like(tAsA);   // 寄存器 staging buffer

  Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k) = 约 (4,1,1,k)
  Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K)   = 约 (4,1,1)
  Tensor tBrB = make_fragment_like(tBsB);   // 寄存器 staging buffer

  // 5. 用 TiledMMA 取代 tC
  //
  // 这里用最简单的 UniversalFMA<float,float,float>
  // 它是 1x1x1 的 mma atom
  // 再铺成 16x16x1 的 TiledMMA
  TiledMMA mma = make_tiled_mma(
      UniversalFMA<float, float, float>{},
      Layout<Shape<_16, _16, _1>>{});

  ThrMMA thr_mma = mma.get_slice(threadIdx.x);

  Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) = 约 (1,8,8)
  Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) = 约 (1,8,8)
  Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) = 约 (1,8,8)

  // 和 sgemm_1 不同:这里不是 make_tensor_like,而是 make_fragment_C
  Tensor tCrC = thr_mma.make_fragment_C(tCgC); // 寄存器累加器

  clear(tCrC);

  // 6. Mainloop
  //
  // 和 sgemm_1 最大的区别:
  // 先 gmem -> rmem 预取第 0 块
  copy(copyA, tAgA(_, _, _, 0), tArA);
  copy(copyB, tBgB(_, _, _, 0), tBrB);

  int K_TILE_MAX = size<3>(tAgA);
  // 加了一个“寄存器中转层”tArA,tBrB,当前块在 smem 里计算时,下一块已经开始从 gmem 读到寄存器里了,实现搬运和计算并行
  for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
  {
    // (1) 先把“已经预取到寄存器的当前块”写进 smem
    __syncthreads();      // 等上一轮所有线程都读完旧 smem
    copy(tArA, tAsA);     // rmem -> smem
    copy(tBrB, tBsB);     // rmem -> smem
    __syncthreads();      // 等这一轮所有线程都写完新 smem

    // (2) 同时预取下一块到寄存器
    int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? (k_tile + 1) : k_tile;
    copy(copyA, tAgA(_, _, _, k_tile_next), tArA); // gmem -> rmem
    copy(copyB, tBgB(_, _, _, k_tile_next), tBrB); // gmem -> rmem

    // (3) 在 smem 上做 GEMM
    gemm(mma, tCsA, tCsB, tCrC);
  }

  // 7. Epilogue(和 sgemm_1 一样,都是把寄存器累加器写回)
  axpby(alpha, tCrC, beta, tCgC);
}

// 主机端启动
void gemm_nt(
  int m, int n, int k,
  float alpha,
  float const* A, int ldA,
  float const* B, int ldB,
  float beta,
  float* C, int ldC)
{
  dim3 grid(ceil_div(m, 128), ceil_div(n, 128));
  dim3 block(256); // 仍然是 256 线程
  gemm_kernel<<<grid, block>>>(m, n, k, A, ldA, B, ldB, C, ldC, alpha, beta);
}

TiledMMA 分区:A/B/C 的计算视图

这句也和 sgemm_2 一样:

ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCsA    = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB    = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tCgC    = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)

这里最关键的是最后一维:

MMA_K 不是“元素个数”,而是“当前 tile 在 K 方向上被分成了多少个小计算波次”


Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)

你可以把 sgemm_sm70 的数据路径写成:

gmem -> rmem(copy staging) -> smem -> rmem(mma fragment) -> gemm -> rmem(accumulator)

具体对应:

sgemm_sm70最简实现

#include <cute/tensor.hpp>
using namespace cute;

__global__ __launch_bounds__(256) void gemm_kernel_sm70(
  int M, int N, int K,
  float const* A, int ldA,
  float const* B, int ldB,
  float* C, int ldC,
  float alpha, float beta)
{
  // 1. 全尺寸 Tensor
  Tensor mA = make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(1, ldA)); // (M,K) m-major
  Tensor mB = make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(1, ldB)); // (N,K) n-major
  Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(1, ldC)); // (M,N)

  // 2. CTA 分块
  auto cta_tiler = make_shape(Int<128>{}, Int<128>{}, Int<8>{}); // (BLK_M, BLK_N, BLK_K)
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);        // (m,n,k)

  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (128,8,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (128,8,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (128,128)

  // 3. SMEM
  __shared__ float smemA[128 * 8];
  __shared__ float smemB[128 * 8];

  Tensor sA = make_tensor(make_smem_ptr(smemA), make_layout(make_shape(Int<128>{}, Int<8>{}))); // (128,8)
  Tensor sB = make_tensor(make_smem_ptr(smemB), make_layout(make_shape(Int<128>{}, Int<8>{}))); // (128,8)

  // 4. TiledCopy(NT)
  TiledCopy copyA = make_tiled_copy(
      Copy_Atom<UniversalCopy<uint128_t>, float>{},
      Layout<Shape<_32, _8>>{},   // Thr layout 32x8 m-major
      Layout<Shape<_4,  _1>>{});  // Val layout  4x1 m-major

  TiledCopy copyB = make_tiled_copy(
      Copy_Atom<UniversalCopy<uint128_t>, float>{},
      Layout<Shape<_32, _8>>{},   // Thr layout 32x8 n-major
      Layout<Shape<_4,  _1>>{});  // Val layout  4x1 n-major

  ThrCopy thr_copy_a = copyA.get_slice(threadIdx.x);
  Tensor tAgA        = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
  Tensor tAsA        = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K)
  Tensor tArA        = make_fragment_like(tAsA);   // (CPY,CPY_M,CPY_K)

  ThrCopy thr_copy_b = copyB.get_slice(threadIdx.x);
  Tensor tBgB        = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
  Tensor tBsB        = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K)
  Tensor tBrB        = make_fragment_like(tBsB);   // (CPY,CPY_N,CPY_K)

  // 5. 先预取第 0 个 k_tile:gmem -> rmem
  copy(copyA, tAgA(_, _, _, 0), tArA);
  copy(copyB, tBgB(_, _, _, 0), tBrB);

  // 6. TiledMMA
  TiledMMA mma = make_tiled_mma(
      UniversalFMA<float, float, float>{},
      Layout<Shape<_16, _16, _1>>{}); // 16x16x1

  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
  Tensor tCsA    = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
  Tensor tCsB    = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
  Tensor tCgC    = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)

  // 7. SM70 新增:计算输入 fragment + 累加器
  Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
  Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)

  clear(tCrC);

  // 8. 先把当前大块写进 smem:rmem -> smem
  copy(tArA, tAsA);
  copy(tBrB, tBsB);
  __syncthreads();

  // 9. 先加载 k_block = 0:smem -> reg(fragment)
  copy(tCsA(_, _, 0), tCrA(_, _, 0));
  copy(tCsB(_, _, 0), tCrB(_, _, 0));

  int K_TILE_MAX  = size<3>(tAgA);
  int K_BLOCK_MAX = size<2>(tCrA);

  // 10. PIPELINED MAIN LOOP
  for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
  {
    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
    {
      if (k_block == K_BLOCK_MAX - 1)
      {
        // 这次tile快结束的时候,下一整个 k_tile:rmem -> smem
        __syncthreads();
        copy(tArA, tAsA);
        copy(tBrB, tBsB);
        __syncthreads();
      }

      // 下一波 k_block:smem -> reg(fragment)
      int k_block_next = (k_block + 1) % K_BLOCK_MAX;
      copy(tCsA(_, _, k_block_next), tCrA(_, _, k_block_next));
      copy(tCsB(_, _, k_block_next), tCrB(_, _, k_block_next));

      if (k_block == 0)
      {
        // 这次tile刚开始的时候,load下一整个 k_tile:gmem -> rmem,注意tCrA和tArA不是一块空间
        int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? (k_tile + 1) : k_tile;
        copy(copyA, tAgA(_, _, _, k_tile_next), tArA);
        copy(copyB, tBgB(_, _, _, k_tile_next), tBrB);
      }

      // 当前波次:reg(fragment) -> gemm
      gemm(mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
    }
  }

  // 11. Epilogue
  axpby(alpha, tCrC, beta, tCgC);
}

void gemm_nt_sm70(
  int m, int n, int k,
  float alpha,
  float const* A, int ldA,
  float const* B, int ldB,
  float beta,
  float* C, int ldC)
{
  dim3 grid(ceil_div(m, 128), ceil_div(n, 128));
  dim3 block(256);
  gemm_kernel_sm70<<<grid, block>>>(m, n, k, A, ldA, B, ldB, C, ldC, alpha, beta);
}

对于TiledCopy和ThrCopy,对于TiledMMA和ThrCopy

关于两者,查看源码可知,其拥有的常用函数功能如下:

template <class MMA_Atom, class AtomLayoutMNK, class PermutationMNK = Tile<Underscore, Underscore, Underscore>>
struct TiledMMA : MMA_Atom
{
  // ===================================================================
  // get_slice / get_thread_slice
  // 出发点:把 threadIdx.x 映射成 TiledMMA 内部的线程坐标 (ThrV, ThrM, ThrN, ThrK)
  // 原理:使用 thr_layout_vmnk_ 把线程索引转成扁平坐标,返回一个 ThrMMA 对象
  // 作用:后续所有 partition_* 都依赖这个“线程切片对象”
  // 与 local_partition 的区别:这是 CuTe 推荐的现代方式,会自动生成 (thr, val) 分布
  // 例子:在 sgemm_2.cu 中写 auto thr_mma = mma.get_thread_slice(threadIdx.x);
  // ===================================================================
  template <class ThrIdx, __CUTE_REQUIRES(is_integral<ThrIdx>::value)>
  CUTE_HOST_DEVICE constexpr auto get_slice(ThrIdx const& thr_idx) const;

  // get_thread_slice 和 get_slice 完全一样(只是命名习惯统一)
  // 文档里两个名字都保留,是为了和 TiledCopy 写法一致
  template <class ThrIdx, __CUTE_REQUIRES(is_integral<ThrIdx>::value)>
  CUTE_HOST_DEVICE constexpr auto get_thread_slice(ThrIdx const& thr_idx) const;
};

template <class TiledMMA, class ThrVMNK>
struct ThrMMA : TiledMMA
{
  // ===================================================================
  // partition_C / partition_A / partition_B
  // 出发点:把已经 tiled 的全局/共享内存 Tensor 按当前线程切成 subtensor
  // 原理:内部调用 thrfrg_C/A/B(TiledMMA 的核心函数),生成 (Thr, Val) 分布
  // 作用:返回该线程负责的数据视图(可直接用于 gemm / copy)
  // 与 local_partition 的区别:这是真正的 (thr, val) 布局,支持任意 MMA 指令
  // 例子:Tensor tCgC = thr_mma.partition_C(gC);   // (MMA, MMA_M, MMA_N)
  // ===================================================================
  template <class CTensor>
  CUTE_HOST_DEVICE constexpr auto partition_C(CTensor&& ctensor) const;

  template <class ATensor>
  CUTE_HOST_DEVICE constexpr auto partition_A(ATensor&& atensor) const;

  template <class BTensor>
  CUTE_HOST_DEVICE constexpr auto partition_B(BTensor&& btensor) const;

  // ===================================================================
  // partition_fragment_C / partition_fragment_A / partition_fragment_B
  // 出发点:把 partition_* 返回的共享/全局视图转成寄存器 fragment
  // 原理:调用 make_fragment_C/A/B(MMA_Atom 提供),生成寄存器友好的 FrgTypeC/FrgTypeA
  // 作用:专门给 gemm 使用,寄存器重用最高、支持 vectorization
  // 与 partition_* 的区别:一个是内存视图,一个是寄存器 fragment
  // 例子:Tensor tCrC = thr_mma.partition_fragment_C(tCgC);
  // ===================================================================
  template <class CTensor>
  CUTE_HOST_DEVICE constexpr auto partition_fragment_C(CTensor&& ctensor) const;

  template <class ATensor>
  CUTE_HOST_DEVICE constexpr auto partition_fragment_A(ATensor&& atensor) const;

  template <class BTensor>
  CUTE_HOST_DEVICE constexpr auto partition_fragment_B(BTensor&& btensor) const;
};

template <class Copy_Atom,
          class LayoutCopy_TV, // (tid,vid) -> coord [Need not be 2D...]
          class ShapeTiler_MN> // coord space
struct TiledCopy : Copy_Atom
{
  // ===================================================================
  // get_slice / get_thread_slice
  // 出发点:同 TiledMMA,获取当前线程的 ThrCopy 切片对象
  // 原理:把 threadIdx.x 映射到 TiledLayout_TV 的坐标
  // 作用:后续 partition_S/D 和 retile 都依赖这个对象
  // 例子:在 sgemm_2.cu 中写 auto thr_copy_a = copyA.get_thread_slice(threadIdx.x);
  // ===================================================================
  template <class ThrIdx, __CUTE_REQUIRES(is_integral<ThrIdx>::value)>
  CUTE_HOST_DEVICE static auto get_slice(ThrIdx const& thr_idx);

  template <class ThrIdx, __CUTE_REQUIRES(is_integral<ThrIdx>::value)>
  CUTE_HOST_DEVICE static auto get_thread_slice(ThrIdx const& thr_idx);
};

template <class TiledCopy, class ThrIdx>
struct ThrCopy
{
  // ===================================================================
  // partition_S / partition_D
  // 出发点:把全局/共享内存 Tensor 按当前线程切成 subtensor
  // 原理:内部调用 tidfrg_S/D,生成 (Thr, (FrgV, FrgX), Rest...) 布局
  // 作用:S = Source(读),D = Destination(写),支持 cp.async 等不同指令
  // 与 local_partition 的区别:这是 CuTe 推荐的 (thr, val) 方式
  // 例子:Tensor tAgA = thr_copy_a.partition_S(gA);
  // ===================================================================
  template <class STensor>
  CUTE_HOST_DEVICE auto partition_S(STensor&& stensor) const;

  template <class DTensor>
  CUTE_HOST_DEVICE auto partition_D(DTensor&& dtensor) const;

  // ===================================================================
  // retile_S / retile_D
  // 出发点:把 partition_* 返回的 (Thr, Val) 布局重排成寄存器友好形状 (Val, Thr, Rest...)
  // 原理:使用 right_inverse(TiledLayout_TV) 把 Thr 模式压到后面,让 FrgV 连续
  // 作用:寄存器 staging 时必须调用,支持 vectorization 和循环展开
  // 什么时候用?只要要把 smem 数据拷贝到寄存器(tArA、tBrB)就必须 retile
  // 与 partition_* 的区别:partition 是线程视图,retile 是寄存器视图
  // 例子:Tensor tArA = ThrCopy::retile_S(tAsA);
  // ===================================================================
  template <class STensor>
  CUTE_HOST_DEVICE static auto retile_S(STensor&& stensor);

  template <class DTensor>
  CUTE_HOST_DEVICE static auto retile_D(DTensor&& dtensor);
};

TiledMMA 部分

1. get_slice(ThrIdx)get_thread_slice(ThrIdx)

// TiledMMA::get_slice
auto thr_vmnk = thr_layout_vmnk_.get_flat_coord(thr_idx);
return ThrMMA<TiledMMA, decltype(thr_vmnk)>{*this, thr_vmnk};

// TiledMMA::get_thread_slice
return get_slice(thr_idx);   // 完全相同!

区别完全一样(文档里故意写两个名字)。

原理:把线程索引 threadIdx.x 映射到 TiledMMA 的内部坐标 (ThrV, ThrM, ThrN, ThrK)thr_layout_vmnk_tiled_product(AtomThrID, AtomLayoutMNK) 生成的)。

例子sgemm_sm70.cu):

TiledMMA mmaC = make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{}, Layout<Shape<_2,_2>>{});
auto thr_mma = mmaC.get_thread_slice(threadIdx.x);  // 每个线程拿到自己的 ThrMMA 切片

2. partition_C / partition_A / partition_B(ThrMMA 中的方法)

// 以 partition_C 为例(其余类似)
auto thr_tensor = make_tensor(ctensor.data(), this->thrfrg_C(ctensor.layout()));
auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_)));
return thr_tensor(thr_vmn, make_coord(_, repeat<rank<1,1>(thr_tensor)>(_)));

原理(三步):

  1. thrfrg_C(M,N)((ThrV,(ThrM,ThrN)), (FrgV,(RestM,RestN)))0t_mma_atom.md 图示);
  2. thr_vmnk 取出当前线程(ThrV, ThrM, ThrN)
  3. 返回该线程专属的 subtensor(寄存器视图)。

区别

例子sgemm_2.cu):

Tensor tCsA = thr_mma.partition_A(sA);   // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(sB);   // (MMA,MMA_N,MMA_K)
Tensor tCgC = thr_mma.partition_C(gC);   // (MMA,MMA_M,MMA_N)

3. partition_fragment_C / partition_fragment_A / partition_fragment_B

template <class CTensor>
auto partition_fragment_C(CTensor&& ctensor) const
{
  return TiledMMA::make_fragment_C(partition_C(ctensor));
}

区别

原理make_fragment_C 会根据 CLayout(MMA,MMA_M,MMA_N) 转成寄存器数组(例如 float[8]),避免 bank conflict。

例子sgemm_2.cu):

Tensor tCrC = thr_mma.partition_fragment_C(tCgC);  // 寄存器 fragment
cute::gemm(mma, tCsA, tCsB, tCrC);                 // 直接用 fragment

TiledCopy 部分

1. get_sliceget_thread_slice

和 MMA 完全一样(源码里 get_thread_slice 直接调用 get_slice),只是返回 ThrCopy 对象。

例子

TiledCopy copyA = make_tiled_copy(...);
auto thr_copy_a = copyA.get_thread_slice(threadIdx.x);

2. partition_S / partition_D(ThrCopy 中的方法)

// partition_S
auto thr_tensor = make_tensor(stensor.data(), TiledCopy::tidfrg_S(stensor.layout()));
return thr_tensor(thr_idx_, _, repeat<rank_v<STensor>>(_));

// partition_D 同理,用 tidfrg_D

原理(与 MMA 类似,但有 Src/Dst 区别):

区别S = Source(读),D = Destination(写)。
这是因为 Copy_AtomValLayoutSrcValLayoutDst 可以不一样(UniversalCopy 相同,cp.async 不同)。

例子sgemm_2.cu):

Tensor tAgA = thr_copy_a.partition_S(gA);  // source(gmem)
Tensor tAsA = thr_copy_a.partition_D(sA);  // dest(smem)
cute::copy(copyA, tAgA, tAsA);

3. retile_S / retile_D(静态方法)

// 两者完全相同(源码里 retile 是静态的)
auto frg_layout_v = ...;   // 把 (ThrV, FrgV) 转成寄存器友好的布局,原来的布局是线程优先(Thr, (FrgV, FrgX), Rest...)
						   //retile 把它翻转成值优先(FrgV, (Thr, FrgX), Rest...),方便读取
return make_tensor(tensor.data(), TiledCopy::retile(tensor.layout()));

区别无区别(只是为了命名对称)。
retile寄存器重排操作,把 (Thr, Val) 转成 (Val, Rest),方便寄存器循环展开和 vectorization。

原理retileupcast + right_inverse(TiledLayout_TV) 把线程布局“压平”,让每个线程的 FrgV 连续存储。

例子

Tensor tArA = ThrCopy::retile_S(tAsA);  // 把 smem 视图转成寄存器连续布局
copy(copyA, tAgA, tArA);

下面是只修正有问题的描述后的版本,结构不变


local_tilelocal_partition的区别

local_tilelocal_partition 本身并不直接引入 Copy/MMA 指令语义下的 (thr, val) 分布

它们只是通用切片工具

它们不会自动引入 CopyAtom / MMAAtom 那种“单条指令吃几个值”的语义,也不会自动生成 FrgV 这一层。


真正产生 (thr, val) 分布的函数

只有下面这些函数才会在 TiledCopy / TiledMMA 这套语义里,把普通 tensor 重新解释成带线程-值分工的布局

函数所属对象产生的布局结构文档出处
thrfrg_C / thrfrg_A / thrfrg_BTiledMMA((ThrV, (ThrM, ThrN)), (FrgV, ...))0t_mma_atom.md
tidfrg_S / tidfrg_DTiledCopy(Thr, (FrgV, FrgX), ...)TiledCopy 源码
partition_C / partition_A / partition_BThrMMA内部先用 thrfrg_* 构造线程-值布局,再按当前线程切片0x_gemm_tutorial.md
partition_S / partition_DThrCopy内部先用 tidfrg_* 构造线程-值布局,再按当前线程切片sgemm_2.cu
get_layoutC_TVTiledMMA/Copy返回纯 (thr, val) → coord 映射,本身不是对某个 tensor 做切片0t_mma_atom.md

这些才是 CuTe 里面带有 Copy/MMA 指令分工语义的现代写法


sgemm_1.cu vs sgemm_2.cu 对比(最直观)

sgemm_1.cu(老式,手动)

auto tC = make_layout(Shape<_16,_16>{});
Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});
Tensor tCrC = make_tensor_like(tCgC);        // 手动
cute::gemm(tCsA, tCsB, tCrC);                // 只能用 UniversalFMA

sgemm_2.cu(现代,推荐)

TiledMMA mma = make_tiled_mma(UniversalFMA<...>{}, Layout<Shape<_16,_16,_1>>{});

auto thr_mma = mma.get_thread_slice(threadIdx.x);
Tensor tCgC = thr_mma.partition_C(gC);            // ← 内部先用 thrfrg_C 构造布局,再切当前线程
Tensor tCrC = thr_mma.partition_fragment_C(tCgC); // ← 进一步得到适合 MMA 的 fragment
cute::gemm(mma, tCsA, tCsB, tCrC);

区别

CuTe TMA Tensor 深度解析


1. 发明背景:TMA 指令为什么需要新的 Tensor 类型

要理解 TMA Tensor 为什么被发明,必须先搞清楚 TMA 指令的本质与普通 CUDA 内存操作的根本区别。

1.1 普通 CUDA 内存操作的工作方式

在传统 CUDA kernel 中,数据搬运是以线程为单位进行的:每条线程自己算出要读/写的元素地址,发出 load/store 指令。CuTe 的普通 Tensor 正是为这种模式量身定制的:

// 普通 CuTe Tensor 的工作模式
Tensor gmem_tensor = make_tensor(ptr, make_shape(M, N), make_stride(1, M));
// 访问 (i, j) 处的元素:
//   layout 计算线性 offset  ->  i*1 + j*M
//   iterator 前进该 offset  ->  ptr + i*1 + j*M
//   最终取得 GMEM 上的真实指针

这整套流程的核心是:逻辑坐标 → 线性地址偏移 → 真实指针

1.2 TMA 指令的工作方式截然不同

TMA(Tensor Memory Accelerator)是 Hopper(SM90)架构上的一组硬件指令,专门负责在 Global Memory 和 Shared Memory 之间高效搬运整块 tile 数据,不再需要每个线程逐元素发指令。

TMA 的关键在于它依赖一个提前在 host 端构造好的 TMA descriptor。这个 descriptor 是一个打包好的描述符,里面包含:

真正在 kernel 中执行 TMA 指令时,你给它的参数是:descriptor 指针、smem 指针、以及 GMEM 视角下的多维坐标

⚠️ 关键区别:TMA 指令不接受 GMEM pointer!

它接受的是「多维坐标」(例如在一个 128×64 的矩阵里,坐标是 (row, col)),再配合 descriptor 里已经存好的基地址和 stride,由硬件自己算出最终地址并完成搬运。

1.3 矛盾在哪里

这就产生了一个根本矛盾:

维度普通 CuTe TensorTMA 需要的
存储内容GMEM 指针不需要 GMEM 指针
layout 输出线性地址偏移(整数)多维坐标(tuple)
最终传给指令真实内存地址坐标 + descriptor
适合的访问模式per-thread 逐元素访问整 tile 批量搬运

简言之:普通 CuTe Tensor 对 TMA 来说是没用的——它输出的是线性地址,而 TMA 根本不吃这个。所以 CuTe 必须发明一种新的 Tensor,使其能直接输出 TMA 所需的多维坐标。这就是本文档的全部目的。


2. 第一步过渡:Tensor 的迭代器不一定是指针

在直接引入 TMA Tensor 之前,文档先做了一个思想热身——打破「Tensor 一定绑定真实内存」的固有印象。

2.1 CuTe Tensor 的本质结构

CuTe Tensor 从设计上看,其实是两个部分的组合:

普通 GMEM Tensor 的 iterator 是一个真实的 GPU 指针,index 它就能得到内存里的元素。但 iterator 的本质要求并不是「必须是指针」,而只需要满足:「能被 layout 计算出的 index 偏移,得到某个结果」

2.2 counting_iterator 的例子

文档用 counting_iterator 作为第一个热身例子:

Tensor A = make_tensor(counting_iterator<int>(42), make_shape(4,5));

// 打印结果:
// counting_iter(42) o (4,5):(_1,4):
//    42   46   50   54   58
//    43   47   51   55   59
//    44   48   52   56   60
//    45   49   53   57   61

这个 Tensor 没有任何真实内存。它的含义是:

所以 Tensor[1][2] = 42 + 1 + 2*4 = 51。这个值不是从内存里读来的,而是「根据规则现算出来的」。

📌 这一步的意义

证明了:CuTe Tensor 的 data 部分不一定是真实内存,也可以是「隐式生成的值」。最重要的是,这种 Tensor 仍然可以像普通 Tensor 一样被 tile、partition、slice——CuTe 的整套工具链对它完全适用。这为后续推广到「坐标 Tensor」铺平了道路。


3. 第二步过渡:从「生成整数」到「生成坐标」

文档接着问了一个关键问题:

🤔 counting_iterator 能生成「整数」;那我们能不能造一个类似的东西,让它生成「TMA 坐标」?

如果可以,那对这个「坐标 Tensor」做 tile/partition/slice,得到的就是每个分块对应的 TMA 坐标——可以直接喂给 TMA 指令,完全不需要额外计算。

这就是整篇文档的核心主线,所有后续的机制都是为了实现这一点而引入的。

维度普通 TensorTMA Tensor(目标)
iterator 里存的GMEM 指针当前 TMA 坐标
layout 输出线性 offset(整数)TMA 坐标增量(tuple)
最终结果指向某元素的地址能传给 TMA 的多维坐标
对 tile/partition 的支持✅(目标)

4. 构造坐标迭代器:ArithmeticTuple 与 ArithmeticTupleIterator

4.1 ArithmeticTuple 是什么

普通的 tuple 只是一个「容器」,把几个值装在一起。ArithmeticTuple 在此基础上额外实现了代数运算——最关键的是 operator+,使得两个 tuple 可以逐元素相加:

// 概念示意
ArithmeticTuple a = (42, 2, 7);
ArithmeticTuple b = (0, 5, 2);
ArithmeticTuple c = a + b;  // -> (42, 7, 9)

// 普通 tuple 不支持 + 操作;
// ArithmeticTuple 支持,且按位(逐元素)相加。

4.2 ArithmeticTupleIterator 是什么

ArithmeticTupleIterator 就是「counting_iterator 的坐标版本」:

文档给出的例子:

ArithmeticTupleIterator citer_1 = make_inttuple_iter(42, Int<2>{}, Int<7>{});
// citer_1 内部存着坐标:(42, 2, 7)

ArithmeticTupleIterator citer_2 = citer_1 + make_tuple(Int<0>{}, 5, Int<2>{});
// 做的事情:(42, 2, 7) + (0, 5, 2) = (42, 7, 9)

print(*citer_2);
// 输出:(42,7,_9)
// (带下划线的是编译期常量,不带的是运行期值)

注意这里的「偏移」不是「指针移动了多少字节」,而是「坐标移动到了哪里」。这就是从「寻址模式」到「坐标模式」的本质转变。

4.3 make_inttuple_iter 的实际含义

调用 make_inttuple_iter(a, b, c, ...) 就是创建一个 ArithmeticTupleIterator,其内部起始坐标是 (a, b, c, ...)。在构造 TMA Tensor 时,通常用 make_inttuple_iter(0, 0) 表示从坐标原点 (0, 0) 开始。


5. 最核心的思想跃迁:Stride 不一定是整数

这是整篇文档最关键的一步,也是 TMA Tensor 能够成立的根本原因。

5.1 普通 Layout 的工作原理回顾

// 逻辑坐标 (i, j),stride (s0, s1)
// layout 计算:i*s0 + j*s1

// 例:stride = (1, 100)
// (i, j) -> i*1 + j*100    <-- 这是一个整数

// 这个整数用来 offset 指针:ptr + (i + 100*j)

整个过程产出「整数 offset」,这很合理——因为指针只能用整数偏移。

5.2 如果 Stride 不是整数,会发生什么

文档提出的核心洞察:Layout 的内积计算,它的数学公式是不变的。变的只是 stride 的类型。如果 stride 不是整数,而是某种「可与整数相乘、多个可以相加」的对象,那内积的结果就不再是整数,而是那种对象的线性组合。

💡 思想跃迁的关键

stride 的类型要求不是「必须是整数」,而是:

  1. 支持「与整数相乘」(scalar multiplication)
  2. 支持「多个 stride 贡献值相加」(addition)

满足这两点,Layout 就能在形式不变的情况下,把输出从「整数」变成「坐标」。

5.3 如果 Stride 换成坐标方向基向量

如果把 stride 换成坐标空间里的基向量 (1@0, 1@1)(稍后解释这个符号的含义),那么内积结果就是:

// stride = (1@0, 1@1),代表两个坐标方向
// 逻辑坐标 (i, j) 的 layout 内积:
//   i * (1@0) + j * (1@1)
// = i@0 + j@1
// = (i, j)    <-- 这是一个二维坐标!

// 如果把 stride 交换为 (1@1, 1@0):
//   i * (1@1) + j * (1@0)
// = i@1 + j@0
// = (j, i)    <-- 坐标的行列顺序对调了

这个结论意义重大:通过设计 stride,我们可以控制「逻辑坐标怎样映射到 TMA 坐标」。这与普通 Tensor 里「通过 stride 控制线性地址排列」的思想完全类似,只不过输出从整数变成了多维坐标。


6. Basis Element:坐标空间中的单位基向量

为了表达「坐标方向」,CuTe 引入了 Basis Element,用 E<...>{} 语法表示。这是让 stride 能产出坐标的核心工具。

6.1 最简单的情形:一维基向量

E<>{}     // 普通的 1(退化为整数,与普通 stride 行为相同)
E<0>{}    // 打印为 1@0,表示沿第 0 个坐标轴方向走 1
E<1>{}    // 打印为 1@1,表示沿第 1 个坐标轴方向走 1
E<2>{}    // 打印为 1@2,表示沿第 2 个坐标轴方向走 1

可以把它们类比为标准基向量:

Basis Element类比的直觉理解
E<0>{} = 1@0类似二维向量 (1, 0),沿第 0 维走 1
E<1>{} = 1@1类似二维向量 (0, 1),沿第 1 维走 1
E<2>{} = 1@2类似三维向量 (0, 0, 1),沿第 2 维走 1

6.2 Scaled Basis:乘以整数系数

基向量可以与整数相乘,表示在该方向上走更大的步长:

5  * E<0>{}  // 打印为 5@0,沿第 0 维走 5
5  * E<1>{}  // 打印为 5@1,沿第 1 维走 5
64 * E<1>{}  // 打印为 64@1,沿第 1 维走 64

// 联系到普通 stride:
// 普通整数 stride=5 表示「线性地址每次增 5」
// 5@1 表示「TMA 坐标第 1 维每次增 5」

6.3 Nested Basis:嵌套坐标方向

Basis element 还可以嵌套,以适配层次化的 TMA 坐标:

E<0,1>{}      // 打印为 1@1@0
              // 含义:在第 0 层坐标的第 1 个子分量方向走 1

E<1,0>{}      // 打印为 1@0@1
              // 含义:在第 1 层坐标的第 0 个子分量方向走 1

5 * E<0,1>{}  // 打印为 5@1@0
              // 含义:在那个嵌套方向走 5

嵌套结构的意义在于:TMA 坐标本身可以是多层次的(比如 mode 0 本身就是一个 tuple)。Nested basis element 就是为了对应这种结构而设计的。

6.4 线性组合:多个基向量相加

多个 basis element 可以线性组合,产出一个复合坐标:

// 示例:多项相加
result = 2*(2@1@0) + 3*(1@1) + 4*(5@1) + 7*(1@0@0)
       = 4@1@0 + 3@1 + 20@1 + 7@0@0
       = ((7, 4), 23)    // 最终拼出一个两层嵌套坐标

// 解读:
//   7@0@0  -> 外层第 0 分量里的第 0 子分量加 7
//   4@1@0  -> 外层第 0 分量里的第 1 子分量加 4
//   3@1    -> 外层第 1 分量加 3
//   20@1   -> 外层第 1 分量再加 20,共加 23

这说明:Basis element 的线性组合,可以构造出任意层次、任意分量的复杂 TMA 坐标。这正是 Layout 内积运算的输出——不再是一个整数,而是一个完整的多维坐标对象。


7. 构造真正的 TMA Tensor:完整例子详解

现在所有组件都齐了,来看文档里最终给出的两个完整 TMA Tensor 例子:

7.1 例子 a:标准映射

Tensor a = make_tensor(
    make_inttuple_iter(0, 0),     // iterator 原点:坐标 (0,0)
    make_shape(4, 5),             // 逻辑 shape:4行 5列
    make_stride(E<0>{}, E<1>{})); // stride:(1@0, 1@1)

// 逻辑坐标 (i,j) 的映射:
// i*(1@0) + j*(1@1) = (i, j)  <-- 直接就是 TMA 坐标 (i,j)

// 打印结果:
// (0,0) (0,1) (0,2) (0,3) (0,4)
// (1,0) (1,1) (1,2) (1,3) (1,4)
// (2,0) (2,1) (2,2) (2,3) (2,4)
// (3,0) (3,1) (3,2) (3,3) (3,4)

这个 Tensor 的每个「元素」不是真实数据,而是「应该传给 TMA 的坐标」。对它切 tile 或 partition,得到的就是对应分块的 TMA 坐标。

7.2 例子 b:坐标轴对调

Tensor b = make_tensor(
    make_inttuple_iter(0, 0),
    make_shape(4, 5),
    make_stride(E<1>{}, E<0>{})); // stride 交换:(1@1, 1@0)

// 逻辑坐标 (i,j) 的映射:
// i*(1@1) + j*(1@0) = (j, i)  <-- TMA 坐标变成 (j,i),行列互换

// 打印结果:
// (0,0) (1,0) (2,0) (3,0) (4,0)
// (0,1) (1,1) (2,1) (3,1) (4,1)
// (0,2) (1,2) (2,2) (3,2) (4,2)
// (0,3) (1,3) (2,3) (3,3) (4,3)

逻辑上的「行列顺序」没变——还是 4×5 的 tensor。但每个位置对应的 TMA 坐标里行列对调了。这完全类比于普通 Tensor 里「通过 stride 控制行优先/列优先」的做法。


8. 逐项解读那串「最吓人的打印」

文档里出现的这串打印是整篇文章最吓人的地方:

ArithTuple(0,_0,_0,_0) o ((_128,_64),2,3,1):((_1@0,_1@1),_64@1,_1@2,_1@3)

现在逐项翻译:

8.1 ArithTuple(0,_0,_0,_0)

这是 Tensor 的 iterator 当前值,即「坐标原点」:

8.2 o(小写字母 o)

这个符号是 CuTe 的标准打印格式,表示:iterator o layout,即「这个 Tensor = 某迭代器 + 某布局」。

8.3 ((_128,_64),2,3,1) — Shape

这是 Tensor 的逻辑 shape,它是分层的:

整个逻辑空间的总元素数 = 128 × 64 × 2 × 3 × 1 = 49152。

8.4 ((_1@0,_1@1),_64@1,_1@2,_1@3) — Stride

这是 stride,每一项都是 basis element,不是整数。对应 shape 的每一个 mode:

Shape ModeStride含义
(_128,_64) 内的 128_1@0沿第 0 个 TMA 坐标维度走 1
(_128,_64) 内的 64_1@1沿第 1 个 TMA 坐标维度走 1
2_64@1沿第 1 个 TMA 坐标维度走 64
3_1@2沿第 2 个 TMA 坐标维度走 1
1_1@3沿第 3 个 TMA 坐标维度走 1

直观理解这串 stride:

整串打印真正的含义不是「这个 Tensor 在显存里怎么排布」,而是:「这个 Tensor 的逻辑坐标,会被翻译成什么样的 TMA 多维坐标」


9. 为什么 CuTe 原有工具链可以直接复用

TMA Tensor 的最大设计亮点,是它与 CuTe 原有的 tile/partition/slice 体系完全兼容。

9.1 CuTe 工具链对 Tensor 的唯一要求

tile、partition、slice 等操作的实现,其实只依赖于:

TMA Tensor 满足这三条——只是 stride 的类型从整数换成了 basis element,但 basis element 完全支持「与整数相乘」和「相加」(这正是它被专门设计成这样的原因)。

9.2 切 tile 的结果

// 对 TMA Tensor 切 tile
Tensor tma_tensor = make_tensor(make_inttuple_iter(0,0,0),
                                make_shape(M, N, K),
                                make_stride(E<0>{}, E<1>{}, E<2>{}));

// 切出 (tile_m, tile_n) 大小的 tile
auto tile = local_tile(tma_tensor, make_shape(tile_m, tile_n), ...);

// tile 仍然是一个 TMA Tensor
// 访问 tile 的任一元素,得到的是该元素对应的 TMA 坐标
// 直接传给 TMA copy 指令即可

9.3 一个具体的 kernel 场景

在 Hopper 上写 GEMM kernel 时,通常的完整流程是:

  1. Host 端:为矩阵 A、B 建好 TMA descriptor
  2. Host 端:构造对应的 TMA Tensor(坐标 Tensor),描述整个 GMEM 的逻辑结构
  3. Kernel 端:通过 CTA tile 划分,得到每个 CTA 对应的 TMA Tensor 分块
  4. Kernel 端:从分块 TMA Tensor 里取到「该 CTA 应该搬运的 tile 的 TMA 坐标」
  5. Kernel 端:用这个坐标 + descriptor + smem pointer 发 TMA load/store 指令

整个过程里,CTA 之间的划分、warpgroup 之间的数据分配,全部复用普通 CuTe 的 tile/partition 工具,没有任何额外的坐标计算逻辑需要手写。

你说得对,过度使用修辞确实会掩盖技术本身的清晰度,尤其是在讨论底层硬件和位运算时,我们需要的是绝对的严谨和直白。我们直接去掉所有花哨的比喻,回归到最朴素的工程逻辑、数学公式和具体的位运算步骤上。

下面是一份完全基于硬件数据、计算公式和代码逻辑的 Swizzle<B, M, S> 详细技术说明。


Swizzle 的基本工作原理

Swizzle 的实质是对一维内存索引(Index)进行位运算(异或,XOR),从而改变数据在物理内存中的存放位置。

它的目的只有一个:当多线程以非连续的跨度(Stride)读取内存时,避免多个请求落入同一个 Shared Memory Bank(引发 Bank Conflict),同时还要保证 128-bit(16 Bytes)向量化读取的内存连续性。

其内部运算逻辑伪代码如下:

C++

// idx 为逻辑元素索引
int target_bits  = (idx >> M) & ((1 << B) - 1);       // 提取准备被替换的低位
int control_bits = (idx >> (M + S)) & ((1 << B) - 1); // 提取用于触发异或的高位
int xor_result   = target_bits ^ control_bits;        // 执行异或
// 最终将 idx 中的 target_bits 替换为 xor_result 得到物理索引

三个模板参数 <B, M, S> 的严格定义与计算方法

为了计算这三个参数,我们必须基于以下两个不变的硬件物理常量:

  1. 物理行宽:Shared Memory 有 32 个 Bank,每个 Bank 宽 4 Bytes,一整行的物理宽度固定为 128 Bytes
  2. 向量化读取要求:为了达到峰值带宽,GPU 采用 128-bit 的向量化加载指令。这要求内存中连续的 16 Bytes 绝对不能被任何机制打乱。

1. 参数 M (MBase):保留不参与异或的低位个数

定义:从第 0 位开始,有 M 个比特位不参与异或运算,原样保留。

计算目标:保证连续的 16 Bytes 在物理内存中依然是连续的。

推导与计算步骤

因为 CuTe 中传入的索引是元素索引(Element Index),所以我们需要算出 16 Bytes 能装下多少个该类型的元素,并求其以 2 为底的对数。

2. 参数 B (BBits):参与异或运算的比特位个数

定义:提取多少个比特位进行 XOR 计算。这决定了内存行被切分为多少个独立的交错块。

计算目标:完整覆盖 128 Bytes 的物理行。

推导与计算步骤

3. 参数 S (SShift):控制位的左移偏移量

定义:控制高位是从低位起点(M)再向左移动 S 位提取的。

计算目标:确保只有当内存地址跨越了一整行(128 Bytes)时,才触发一次异或模式的改变。

推导与计算步骤

我们需要对比“基本块边界”和“物理行边界”在二进制位上的差异。


明确的使用场景判断

在工程中,是否引入 Swizzle 并不取决于“是不是矩阵”,而是取决于数据的读写步长(Stride)特征

1. 应当使用 Swizzle 的场景

2. 不应使用 Swizzle 的场景

矩阵行号 (高位)第 0 列第 1 列第 2 列第 3 列
Row 0Bank 0Bank 1Bank 2Bank 3
Row 1Bank 1Bank 0Bank 3Bank 2
Row 2Bank 2Bank 3Bank 0Bank 1
Row 3Bank 3Bank 2Bank 1Bank 0

经过正确Swizzle的矩阵可以实现行列读取都不导致bank conflict

Tensor Core mma 硬件级空间切分与数据交织法则

详细的各个架构mma的不同布局详见:

一、 核心发现 (The Core Discovery)

在执行底层的 mma.m8n8k4(矩阵乘加)指令时,Warp 内 32 个线程对操作数(矩阵 A 和 B)的读取与分配,并非基于随意的软件逻辑,而是严格遵循基于 Half-Warp(半线程束)的物理空间对称切分契约结论陈述: 对于输入矩阵,硬件沿着进行乘法运算的核心维度(矩阵 A 的 M 维度,矩阵 B 的 N 维度),将其强行“腰斩”为上下/左右两半。

无论数据在内存中是 Row-Major(行主序)还是 Column-Major(列主序),这个基于物理线程 ID 的空间维度分配法则坚如磐石。


二、 深度推导与逻辑拆解:输入端的三维映射

为了证明这一发现的普适性,我们将矩阵 $C = A \times B$ 的输入端拆解如下:

1. 矩阵 A ($8 \times 4$):M 维度(行)的绝对切分

在 $M \times K$ 的坐标系中,$K$ 维度的数据由单线程内部的寄存器组合(如 V0, V1, V2, V3)横向/纵向吞吐。Warp 分组唯一能切分的空间是 M 维度(行)

物理线程组逻辑角色掌管的 M 维度空间对应寄存器行为 (以列主序为例)
T0 - T3低组 (Lower Half-Warp)上半区 ($M = 0 \dots 3$)T0 吞吐第 0 列的 $M_0 \sim M_3$
T16 - T19高组 (Upper Half-Warp)下半区 ($M = 4 \dots 7$)T16 吞吐第 0 列的 $M_4 \sim M_7$

2. 矩阵 B ($4 \times 8$):N 维度(列)的绝对切分

视角切换到 $K \times N$ 坐标系,$K$ 维度依然由内部寄存器消化,此时切分刀刃转向了 N 维度(列)

物理线程组逻辑角色掌管的 N 维度空间对应寄存器行为 (以行主序为例)
T0 - T3低组 (Lower Half-Warp)左半区 ($N = 0 \dots 3$)T0 吞吐第 0 行的 $N_0 \sim N_3$
T16 - T19高组 (Upper Half-Warp)右半区 ($N = 4 \dots 7$)T16 吞吐第 0 行的 $N_4 \sim N_7$

三、 反直觉陷阱:输出端 C 矩阵 ($8 \times 8$) 的二维纠缠

切记:输入端的“完美对称切分”在输出端会瞬间失效。 不能用输入端 $M/N$ 的直觉去推导累加器(Accumulator)的结果分布。

推导过程(以 T0 为例):

  1. T0 握有 $M=0$ 的数据。
  2. 在计算瞬间,硬件不仅让 T0 用自己的 $M=0$ 乘以自己手里的 $N=0 \sim 3$(左半区)。
  3. 硬件的 Crossbar 会把 T16 手里的 $N=4 \sim 7$(右半区)数据硬连线广播 (Broadcast) 给 T0。
  4. 因此,T0 最终计算出的 8 个元素,被打碎并散落在第一象限和第二象限的离散位置(如 $C_{0,0}, C_{0,1}, C_{0,4}, C_{0,5} \dots$)。

结论:输出矩阵 C 的布局是高度碎片化的“交织马赛克(Checkerboard Pattern)”,这是为了在最后将结果写回寄存器堆时,最大化写入带宽并彻底消除写冲突。


分享此帖至: