Computing Gradients of Sinkhorn Iterations Without Backpropagating Through the Forward Pass
Problem Setup
Problem
Note: $\odot$ denotes element-wise multiplication.
- Input matrix: $X \in \mathbb{R}^{n \times n}$.
- $P = \exp(X)$ (element-wise).
- Running Sinkhorn-Knopp iterations on $P$ yields a bistochastic matrix $R = \text{diag}(\alpha) P \text{diag}(\beta)$, where $\alpha, \beta \in \mathbb{R}^n_{>0}$ are scaling factors satisfying:
  - Row-sum constraint: $R \mathbf{1} = \mathbf{1} \implies \alpha \odot (P\beta) = \mathbf{1}$
  - Column-sum constraint: $R^T \mathbf{1} = \mathbf{1} \implies \beta \odot (P^T \alpha) = \mathbf{1}$
- Loss function: $L = f(R)$; let $G = \nabla_R L = \frac{\partial L}{\partial R}$ denote the known gradient with respect to $R$.
Goal
The gradient of $L$ with respect to $X$: $\frac{\partial L}{\partial X}$.
TLDR
Solve the following system with the conjugate gradient (CG) method:
$$\begin{bmatrix} I & R \\ R^T & I \end{bmatrix} \begin{bmatrix} u \\ v \end{bmatrix} = \begin{bmatrix} (G \odot R) \mathbf{1} \\ (G \odot R)^T \mathbf{1} \end{bmatrix}$$
Its solution gives the gradient of $L$ with respect to $X$:
$$\nabla_X L = (G - u \mathbf{1}^T - \mathbf{1} v^T) \odot R$$
Provided the forward Sinkhorn-Knopp iterations have converged sufficiently, this procedure converges to the correct gradient.
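As a sanity check before the derivation, the formula above can be verified on a single small matrix, with a dense least-squares solve standing in for CG. This is a minimal self-contained sketch (not the batched implementation given later); all names in it are illustrative:

```python
import torch

torch.manual_seed(0)
n = 5
X = (torch.rand(n, n, dtype=torch.float64) * 4).requires_grad_()

# Forward pass: Sinkhorn-Knopp normalization of P = exp(X)
R = torch.exp(X)
for _ in range(200):
    R = R / R.sum(0, keepdim=True)  # column normalization
    R = R / R.sum(1, keepdim=True)  # row normalization

G = torch.randn(n, n, dtype=torch.float64)  # a given dL/dR
(R * G).sum().backward()                    # reference gradient by unrolled autograd

# Implicit gradient: solve the (singular) 2n x 2n system; any solution works
Rd = R.detach()
I = torch.eye(n, dtype=torch.float64)
A = torch.cat([torch.cat([I, Rd], dim=1),
               torch.cat([Rd.T, I], dim=1)], dim=0)
rhs = torch.cat([(G * Rd).sum(1), (G * Rd).sum(0)])  # row sums and column sums of G ⊙ R
uv = torch.linalg.lstsq(A, rhs.unsqueeze(1)).solution.squeeze(1)
u, v = uv[:n], uv[n:]
grad_X = (G - u[:, None] - v[None, :]) * Rd           # (G - u 1^T - 1 v^T) ⊙ R

print(torch.allclose(grad_X, X.grad, atol=1e-6))      # expected: True
```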
Derivation
We want $\frac{\partial L}{\partial X}$. By the chain rule:
$$\frac{\partial L}{\partial X} = \frac{\partial L}{\partial R} \cdot \frac{\partial R}{\partial P} \cdot \frac{\partial P}{\partial X}$$
Since $P_{ij} = e^{X_{ij}} \implies \frac{\partial P_{ij}}{\partial X_{ij}} = P_{ij}$, once $\frac{\partial L}{\partial P}$ is known, the final result is $\nabla_X L = \nabla_P L \odot P$.
By implicit differentiation of the Sinkhorn balancing conditions, one can show (derivation omitted) that $\nabla_X L$ is given by:
$$\nabla_X L = (G - u \mathbf{1}^T - \mathbf{1} v^T) \odot R$$
where $u, v \in \mathbb{R}^n$ solve the following linear system, whose right-hand sides are the row sums and column sums of $G \odot R$:
$$\begin{cases} u + R v = (G \odot R) \mathbf{1} \\ R^T u + v = (G \odot R)^T \mathbf{1} \end{cases}$$
Solving the linear system
Rewriting the equations in matrix form:
$$\begin{bmatrix} I & R \\ R^T & I \end{bmatrix} \begin{bmatrix} u \\ v \end{bmatrix} = \begin{bmatrix} (G \odot R) \mathbf{1} \\ (G \odot R)^T \mathbf{1} \end{bmatrix} = b_0$$
Solving this linear system gives $u$ and $v$.
Assembling the gradient
With $u$ and $v$ in hand, substitute them into:
$$\frac{\partial L}{\partial X_{ij}} = (G_{ij} - u_i - v_j) R_{ij} = (G_{ij} - (u_i + v_j)) R_{ij}$$
For each $i, j$, what we need from the solution of the system above is exactly $u_i + v_j$.
Properties
Let $A=\begin{bmatrix} I & R \\ R^T & I \end{bmatrix}$.
1. Multiple solutions
Proof:
Consider the nonzero vector $w = \begin{bmatrix} \mathbf{1} \\ -\mathbf{1} \end{bmatrix}$ (where $\mathbf{1}$ is the all-ones $n$-dimensional column vector) and compute $Aw$:
$$Aw = \begin{bmatrix} I & R \\ R^T & I \end{bmatrix} \begin{bmatrix} \mathbf{1} \\ -\mathbf{1} \end{bmatrix} = \begin{bmatrix} I\mathbf{1} - R\mathbf{1} \\ R^T\mathbf{1} - I\mathbf{1} \end{bmatrix}$$
By the bistochastic properties $R\mathbf{1} = \mathbf{1}$ and $R^T\mathbf{1} = \mathbf{1}$:
$$Aw = \begin{bmatrix} \mathbf{1} - \mathbf{1} \\ \mathbf{1} - \mathbf{1} \end{bmatrix} = \mathbf{0}$$
Since a nonzero vector lies in the null space of $A$, we have $\det(A) = 0$.
Intuition: the row-sum and column-sum constraints of a bistochastic matrix are redundant (for example, since every row sums to 1, each row is determined by its first $n-1$ entries).
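The singularity is also easy to confirm numerically. A small sketch (the Sinkhorn loop below merely produces some converged bistochastic $R$):

```python
import torch

torch.manual_seed(0)
n = 6
I = torch.eye(n, dtype=torch.float64)

# Any converged Sinkhorn output serves as a (numerically) bistochastic R
R = torch.rand(n, n, dtype=torch.float64).exp()
for _ in range(200):
    R = R / R.sum(0, keepdim=True)
    R = R / R.sum(1, keepdim=True)

A = torch.cat([torch.cat([I, R], dim=1),
               torch.cat([R.T, I], dim=1)], dim=0)
w = torch.cat([torch.ones(n, dtype=torch.float64), -torch.ones(n, dtype=torch.float64)])

print((A @ w).abs().max().item())          # ~1e-16: w lies in the null space of A
print(torch.linalg.matrix_rank(A).item())  # 2n - 1 for a strictly positive R
```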
2. Invariant
1) Solution space of $Ax = b$
Since $A$ is singular, for a given vector $b$, if the system is solvable it has infinitely many solutions. The general solution has the form:
$$x = \begin{bmatrix} u \\ v \end{bmatrix} = \begin{bmatrix} u_0 \\ v_0 \end{bmatrix} + k \begin{bmatrix} \mathbf{1} \\ -\mathbf{1} \end{bmatrix} = \begin{bmatrix} u_0 + k\mathbf{1} \\ v_0 - k\mathbf{1} \end{bmatrix}$$
where $\begin{bmatrix} u_0 \\ v_0 \end{bmatrix}$ is a particular solution and $k$ is an arbitrary real scalar.
2) The invariant
Although the solution $x$ contains an indeterminate offset $k$, the quantity we actually need is well defined. That quantity is the matrix $M$:
$$M = u\mathbf{1}^T + \mathbf{1}v^T \quad (\text{i.e. } M_{ij} = u_i + v_j)$$
Uniqueness: substitute the general solution with the free variable $k$ into the expression for $M$:
$$M(k) = (u_0 + k\mathbf{1})\mathbf{1}^T + \mathbf{1}(v_0 - k\mathbf{1})^T$$
Expanding by distributivity:
$$M(k) = u_0\mathbf{1}^T + k(\mathbf{1}\mathbf{1}^T) + \mathbf{1}v_0^T - k(\mathbf{1}\mathbf{1}^T)$$
The $k$-dependent terms cancel:
$$M(k) = u_0\mathbf{1}^T + \mathbf{1}v_0^T = M_{\text{fixed}}$$
Conclusion: every solution $x$ of $Ax=b$ yields the same matrix $M$. That is:
$$M = f(R, b)$$
$M$ is a deterministic function of $R$ and $b$, independent of which particular solution is found; therefore, as long as the solver converges, the correct gradient is obtained.
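The cancellation of $k$ can be checked numerically with arbitrary vectors; `u0`, `v0`, and `k` below are placeholders, not values from an actual solve:

```python
import torch

torch.manual_seed(0)
n, k = 6, 3.7
u0 = torch.randn(n, dtype=torch.float64)
v0 = torch.randn(n, dtype=torch.float64)
ones = torch.ones(n, dtype=torch.float64)

# M from the particular solution vs. M from the shifted solution (u0 + k*1, v0 - k*1)
M0 = torch.outer(u0, ones) + torch.outer(ones, v0)
Mk = torch.outer(u0 + k * ones, ones) + torch.outer(ones, v0 - k * ones)

print(torch.allclose(M0, Mk))  # True: the free offset k drops out of M
```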
3. Reformulation
Write $s_r = (G \odot R)\mathbf{1}$ and $s_c = (G \odot R)^T\mathbf{1}$. Eliminating $u = s_r - Rv$ from the original system:
$$R^T(s_r - Rv) + v = s_c \implies (I - R^T R)v = s_c - R^T s_r$$
This gives a new system $S\tilde{v} = b$ with $S = I - R^T R$ and $b = s_c - R^T s_r$, where $\tilde{v}$ denotes the solution of the new system.
Here $S$ is symmetric positive semidefinite:
1) Symmetry
$$S^T = I - (R^T R)^T = I - R^T R = S \quad \checkmark$$
2) Positive semidefiniteness
For any $x \in \mathbb{R}^n$:
$$x^T S x = \|x\|_2^2 - \|Rx\|_2^2$$
so it suffices to show $\|Rx\|_2 \leq \|x\|_2$.
By row stochasticity ($\sum_j R_{ij} = 1$) and Jensen's inequality:
$$\left(\sum_j R_{ij} x_j\right)^2 \leq \sum_j R_{ij} x_j^2$$
Summing over $i$:
$$\|Rx\|_2^2 = \sum_i \left(\sum_j R_{ij} x_j\right)^2 \leq \sum_i \sum_j R_{ij} x_j^2 = \sum_j x_j^2 \sum_i R_{ij} = \|x\|_2^2$$
Hence $x^T S x \geq 0$.
However, since
$$S\mathbf{1} = \mathbf{1} - R^T(R\mathbf{1}) = \mathbf{1} - R^T\mathbf{1} = \mathbf{0}$$
$S$ is not positive definite.
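These properties, and the fact that the reduced $n \times n$ solve recovers a solution of the original $2n \times 2n$ system, can be verified numerically. A sketch with a dense least-squares solve standing in for CG:

```python
import torch

torch.manual_seed(0)
n = 6
R = torch.rand(n, n, dtype=torch.float64).exp()
for _ in range(200):
    R = R / R.sum(0, keepdim=True)
    R = R / R.sum(1, keepdim=True)

S = torch.eye(n, dtype=torch.float64) - R.T @ R
eigs = torch.linalg.eigvalsh(S)
print(eigs.min().item(), eigs.max().item())  # smallest ~0 (up to round-off), all below 1: PSD but singular
print((S @ torch.ones(n, dtype=torch.float64)).abs().max().item())  # ~0: S 1 = 0

# The reduced solve reproduces both equations of the original system
G = torch.randn(n, n, dtype=torch.float64)
s_r, s_c = (G * R).sum(1), (G * R).sum(0)
v = torch.linalg.lstsq(S, (s_c - R.T @ s_r).unsqueeze(1)).solution.squeeze(1)
u = s_r - R @ v
print(torch.allclose(u + R @ v, s_r), torch.allclose(R.T @ u + v, s_c))  # True True
```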
Algorithm
1) Prepare the right-hand sides
$$s_r = (G \odot R)\mathbf{1}, \quad s_c = (G \odot R)^T\mathbf{1}$$
2) Build the positive semidefinite system
$$S = I - R^T R$$
and
$$b = s_c - R^T s_r$$
3) Solve with CG
$$S \, \tilde{v} = b$$
4) Reconstruct the solution
$$u = s_r - R\tilde{v}$$
$$v = \tilde{v}$$
5) Assemble
$$M_{ij} = u_i + v_j$$
6) Final gradient
$$\nabla_X L = (G - M) \odot R$$
PyTorch implementation
"""
Sinkhorn Backward Pass: n×n Rank-0 (Singular) System with Manual CG
Solves (I - R^T R) ṽ = b without rank-1 correction using manual conjugate gradient
"""
import torch
dtype = torch.float32
batch = 10001
n = 4
iters = 48
EPS = 1e-11
print(f"{n = }")
print(f"{iters = }")
print(f"{batch = }")
def sinkhorn_forward(M, iters=20):
"""Standard Sinkhorn forward pass"""
P = torch.exp(M)
R = P
for _ in range(iters):
R = R / R.sum(-2, keepdim=True)
R = R / R.sum(-1, keepdim=True)
return R, P
def batch_cg_solve_singular(A, b, max_iter=None):
    """
    Manual Conjugate Gradient solver for potentially singular systems.
    A: (batch, n, n) - system matrices
    b: (batch, n) - right hand side
    max_iter: number of CG iterations (defaults to n)
    """
batch_size, n, _ = A.shape
device = A.device
dtype = A.dtype
# CG Initialization
x = torch.zeros_like(b)
r = b.clone()
p = r.clone()
rs_old = torch.einsum("bi,bi->b", r, r)
    # CG Iteration
    # In exact arithmetic CG converges within at most n steps, so default to n iterations
    n_iters = n if max_iter is None else max_iter
    for i in range(n_iters):
Ap = torch.einsum("bij,bj->bi", A, p)
pAp = torch.einsum("bi,bi->b", p, Ap)
alpha = rs_old / (pAp + EPS)
x += torch.einsum("b,bi->bi", alpha, p)
r -= torch.einsum("b,bi->bi", alpha, Ap)
rs_new = torch.einsum("bi,bi->b", r, r)
beta = rs_new / (rs_old + EPS)
p = r + torch.einsum("b,bi->bi", beta, p)
rs_old = rs_new
return x
def sinkhorn_backward_n_rank0(grad_R, R, cg_iters=10):
"""
Rank-0 method: Solve n×n singular system WITHOUT rank-1 correction
Uses manual Conjugate Gradient to solve (I - R^T R) ṽ = b
Algorithm steps:
1. r = (G ⊙ R)1, c = (G ⊙ R)^T 1
2. S0 = I - R^T R (SINGULAR), b = c - R^T r
3. Solve: S0 ṽ = b using manual CG
4. u = r - R ṽ, v = ṽ (CG naturally finds ~zero mean solution)
5. M_{ij} = u_i + v_j
6. ∇_X L = (G - M) ⊙ R
"""
batch_size, n, _ = R.shape
device = R.device
dtype = R.dtype
R_detached = R.detach()
G = grad_R
# Step 1: Prepare RHS
r = (R_detached * G).sum(dim=-1) # shape (batch, n)
c = (R_detached * G).sum(dim=-2) # shape (batch, n)
# Step 2: Build n×n SINGULAR system (no rank-1 correction)
# S0 = I - R^T R
R_T = torch.einsum("bij->bji", R_detached)
RTR = torch.einsum("bij,bjk->bik", R_T, R_detached)
eye = torch.eye(n, device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1)
S0 = eye - RTR # SINGULAR matrix
b = c - torch.einsum("bij,bj->bi", R_T, r)
# Debug: Compute eigenvalues to verify singularity
eigenvalues = torch.linalg.eigvalsh(S0)
min_eig = eigenvalues.min(dim=-1).values
max_eig = eigenvalues.max(dim=-1).values
print("Rank-0 system eigenvalue statistics:")
print(f" Min eigenvalue: {min_eig.min().item():.5f}")
print(f" Max eigenvalue: {max_eig.max().item():.5f}")
print(f" Near-zero eigenvalues exist: {(eigenvalues.abs() < 1e-6).any().item()}")
# Step 3: Solve S0 ṽ = b using manual CG
v_tilde = batch_cg_solve_singular(S0, b, max_iter=cg_iters)
# Step 4: Construct solution
# CG naturally produces minimum-norm solution (~zero mean)
u = r - torch.einsum("bij,bj->bi", R_detached, v_tilde)
v = v_tilde
# Step 5: Assemble M_{ij} = u_i + v_j
M = u.unsqueeze(-1) + v.unsqueeze(-2)
    # Step 6: Final gradient: ∇_X L = (G - M) ⊙ R
grad_X = (G - M) * R_detached
return grad_X
######################################################################
# Test Setup
######################################################################
# Generate random input
dist = torch.distributions.uniform.Uniform(0.0, 4.0)
M = dist.sample((batch, n, n))
M.requires_grad_()
# Forward pass (shared)
R, P = sinkhorn_forward(M, iters)
loss_weight = torch.randn_like(R)
######################################################################
# Method A: Autograd (Reference)
######################################################################
M.grad = None
loss_a = (R * loss_weight).sum()
loss_a.backward()
grad_M_autograd = M.grad.detach().clone()
######################################################################
# Method B: Rank-0 CG (Singular system, manual CG)
######################################################################
grad_R = loss_weight
grad_M_rank0_cg = sinkhorn_backward_n_rank0(grad_R, R, cg_iters=n)
######################################################################
# Comparison
######################################################################
g_ref = grad_M_autograd
g_rank0 = grad_M_rank0_cg
# Compute differences
abs_diff = (g_ref - g_rank0).abs()
rel_diff = abs_diff / (g_ref.abs() + 1e-12)
MAE = abs_diff.mean(dim=(-1, -2))
max_abs_diff = abs_diff.reshape(batch, -1).max(-1).values
mean_rel_diff = rel_diff.mean(dim=(-1, -2))
max_rel_diff = rel_diff.reshape(batch, -1).max(-1).values
print("\n" + "=" * 60)
print("GRADIENT COMPARISON: Autograd vs Rank-0 CG")
print("=" * 60)
print(f"Max MAE: {MAE.max().item():.6e}")
print(f"Max max_abs_diff: {max_abs_diff.max().item():.6e}")
print(f"Max mean_rel_diff: {mean_rel_diff.max().item():.6e}")
print(f"Max max_rel_diff: {max_rel_diff.max().item():.6e}")
print("\n" + "=" * 60)
print("SAMPLE GRADIENTS (first batch)")
print("=" * 60)
print(f"\nAutograd reference:\n{g_ref[0]}")
print(f"\nRank-0 CG method:\n{g_rank0[0]}")
print(f"\nAbsolute difference:\n{abs_diff[0]}")
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
Triton implementation
Implementation details
In the Triton kernel, each thread block handles a tile of `tilesize` independent $n \times n$ systems. Every CG iteration has to compute $S p = (I - R^T R) p$.
The most straightforward implementation precomputes $R^T R$ at kernel entry and reuses it inside the CG loop:
# original implementation (tune_triton_new.py)
RTR = tl.dot(RT, R, input_precision="tf32")  # precompute once
for _ in range(n):
    Sp = p - tl.dot(RTR, p)  # one matvec per CG iteration
However, this turns out to be slow: the matmul pattern here is unlike a GEMM, and since it runs only once at kernel entry it cannot keep the Tensor Core pipeline filled, so utilization is low. In practice it measured roughly 2x slower than the original implementation that works on the $2n \times 2n$ system.
Note that $S p$ can instead be decomposed into two consecutive matvecs, with no need to materialize $R^T R$:
$$S p = (I - R^T R) p = p - R^T \underbrace{(R p)}_{\text{intermediate result}}$$
That is, first compute $q = Rp$ ($n \times 1$), then $R^T q$ ($n \times 1$).
The corresponding Triton implementation:
@triton.jit
def matvec_S(R, x):
Rx = tl.dot(R, x, input_precision="ieee") # q = Rx
RT = R.permute(0, 2, 1)
RTRx = tl.dot(RT, Rx, input_precision="ieee") # R^T q
return x - RTRx
The CG loop calls it directly and no longer holds RTR:
for _ in range(n_stream):
Sp = matvec_S(R, p)
pSp = tl.sum(p * Sp, ...)
...
Compared with the baseline that works on the $2n \times 2n$ system, this yields roughly a 1.4x speedup under the settings in the code below.
Code
from icecream import ic
import torch
import einops as ein
import triton
import triton.language as tl
from tqdm import trange
import time
# TMA descriptors require a global memory allocation
def alloc_fn(size: int, alignment: int, stream: int | None):
return torch.empty(size, device="cuda", dtype=torch.int8)
triton.set_allocator(alloc_fn)
dtype = torch.float32
EPS = tl.constexpr(1e-10)
def sinkhorn_forward(M, iters=20):
P = torch.exp(M)
R = P
for _ in range(iters):
R = R / R.sum(-2, keepdim=True)
R = R / R.sum(-1, keepdim=True)
return R, P
@triton.jit
def matvec_S(R, x):
"""
S = I - R^T R, perform S @ x WITHOUT materializing RTR.
Computes: x - R^T (R x) using two matvecs.
R: (tilesize, n, n)
x: (tilesize, n, 1)
returns: (tilesize, n, 1)
"""
Rx = tl.dot(R, x, input_precision="ieee") # (tilesize, n, 1)
RT = R.permute(0, 2, 1)
RTRx = tl.dot(RT, Rx, input_precision="ieee") # (tilesize, n, 1)
return x - RTRx
@triton.autotune(
configs=[
triton.Config({"tilesize": tilesize}, num_stages=1, num_warps=num_warps)
for tilesize in [1, 2, 4, 8, 16, 32, 64]
for num_warps in [1, 2, 4, 8]
],
key=[],
)
@triton.jit
def sinkhorn_bwd_implicit_cg_kernel(
seqlen,
out,
dout,
res,
out_stride_0,
out_stride_1,
out_stride_2,
dout_stride_0,
dout_stride_1,
dout_stride_2,
res_stride_0,
res_stride_1,
res_stride_2,
n_stream: tl.constexpr,
tilesize: tl.constexpr,
):
out_desc = tl.make_tensor_descriptor(
out,
shape=[seqlen, n_stream, n_stream],
strides=[out_stride_0, out_stride_1, out_stride_2],
block_shape=[tilesize, n_stream, n_stream],
)
dout_desc = tl.make_tensor_descriptor(
dout,
shape=[seqlen, n_stream, n_stream],
strides=[dout_stride_0, dout_stride_1, dout_stride_2],
block_shape=[tilesize, n_stream, n_stream],
)
res_desc = tl.make_tensor_descriptor(
res,
shape=[seqlen, n_stream, n_stream],
strides=[res_stride_0, res_stride_1, res_stride_2],
block_shape=[tilesize, n_stream, n_stream],
)
seq_off = tl.program_id(0) * tilesize
R = out_desc.load([seq_off, 0, 0])
RT = R.permute(0, 2, 1)
dR = dout_desc.load([seq_off, 0, 0])
# Step 1: s_r = (G ⊙ R) 1, s_c = (G ⊙ R)^T 1
RdR = R * dR
s_r = tl.sum(RdR, axis=-1).expand_dims(-1) # (tilesize, n, 1)
s_c = tl.sum(RdR, axis=-2).expand_dims(-1) # (tilesize, n, 1)
# Step 2: b = s_c - R^T s_r
b = s_c - tl.dot(RT, s_r, input_precision="ieee") # (tilesize, n, 1)
# Step 3: CG to solve (I - R^T R) x = b
# Key optimization: do NOT precompute RTR
# Instead, each matvec_S call does: x - R^T(Rx) (two n×1 matvecs).
x = tl.zeros((tilesize, n_stream, 1), dtype=tl.float32)
r = b - matvec_S(R, x) # residual = b - S x = b (since x=0)
p = r
r_normsq = tl.sum(r * r, axis=1, keep_dims=True)
for _ in range(n_stream):
Sp = matvec_S(R, p)
pSp = tl.sum(p * Sp, axis=1, keep_dims=True)
alpha = r_normsq / (pSp + EPS)
x += alpha * p
r -= alpha * Sp
r_new_normsq = tl.sum(r * r, axis=1, keep_dims=True)
beta = r_new_normsq / (r_normsq + EPS)
p = r + beta * p
r_normsq = r_new_normsq
# Step 4: u = s_r - R x, v = x
u = s_r - tl.dot(R, x, input_precision="ieee") # (tilesize, n, 1)
v = x # (tilesize, n, 1)
# Step 5: M_ij = u_i + v_j => M = u 1^T + 1 v^T
# u: (tilesize, n, 1), v^T: (tilesize, 1, n)
v_T = v.reshape(tilesize, 1, n_stream)
M_mat = u + v_T # broadcast -> (tilesize, n, n)
# Step 6: grad = (G - M) ⊙ R
res_tile = (dR - M_mat) * R
res_desc.store([seq_off, 0, 0], res_tile)
def sinkhorn_bwd_implicit_cg(
out: torch.Tensor,
dout: torch.Tensor,
repeat: int,
):
seqlen = out.size(0)
n_stream = out.size(1)
ic(seqlen)
ic(n_stream)
res = torch.empty_like(out)
def grid(META): # META is the dict passed in @triton.autotune
return (triton.cdiv(seqlen, META["tilesize"]), 1, 1)
    # GPU warmup: run a few large matmuls before timing the kernel
    a = torch.randn(8192, 8192, device="cuda")
    for _ in trange(4):
        _ = a @ a
# Compile the kernel by running warmup
sinkhorn_bwd_implicit_cg_kernel[grid](
seqlen,
out,
dout,
res,
out.stride(0),
out.stride(1),
out.stride(2),
dout.stride(0),
dout.stride(1),
dout.stride(2),
res.stride(0),
res.stride(1),
res.stride(2),
n_stream,
)
torch.cuda.synchronize()
# start
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(repeat):
sinkhorn_bwd_implicit_cg_kernel[grid](
seqlen,
out,
dout,
res,
out.stride(0),
out.stride(1),
out.stride(2),
dout.stride(0),
dout.stride(1),
dout.stride(2),
res.stride(0),
res.stride(1),
res.stride(2),
n_stream,
)
    # end: record the stop event, then wait for all kernels to finish before reading the timer
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
# Print timing results
print(f"Kernel execution time ({repeat = }): {elapsed_time_ms:.3f} ms")
print(f"Average time per iteration: {elapsed_time_ms / repeat:.3f} ms")
return res
def main():
seqlen = 65536
n_stream = 16
iters = 100
repeat = 512
######################################################################
# Variable
######################################################################
dist = torch.distributions.uniform.Uniform(0.0, 4.0)
device = torch.device("cuda")
M = dist.sample((seqlen, n_stream, n_stream)).to(device)
M.requires_grad_()
######################################################################
# Shared forward + one shared loss weight
######################################################################
R, P = sinkhorn_forward(M, iters)
loss_weight = torch.randn_like(R)
######################################################################
# Method A: Autograd
######################################################################
loss_a = (R * loss_weight).sum()
loss_a.backward()
grad_M_autograd = M.grad.detach().clone()
######################################################################
# Method B: Implicit differentiation (n×n system, no RTR materialization)
######################################################################
grad_R = loss_weight
grad_M_implicit = sinkhorn_bwd_implicit_cg(R, grad_R, repeat=repeat)
######################################################################
# Compare
######################################################################
g1 = grad_M_autograd
g2 = grad_M_implicit
abs_diff = (g1 - g2).abs()
rel_diff = abs_diff / (g1.abs() + 1e-12)
print("Comparison of gradients dL/dM")
print("--------------------------------")
def format_list(ls):
return [f"{x:.2e}" for x in ls]
MAE = abs_diff.mean(dim=(-1, -2)).tolist()
max_abs_diff = abs_diff.reshape(seqlen, -1).max(-1).values.tolist()
mean_rel_diff = rel_diff.mean(dim=(-1, -2)).tolist()
max_rel_diff = rel_diff.reshape(seqlen, -1).max(-1).values.tolist()
print(f"Max MAE = {max(MAE)}")
print(f"Max max_abs_diff = {max(max_abs_diff)}")
print(f"Max mean_rel_diff = {max(mean_rel_diff)}")
print(f"Max max_rel_diff = {max(max_rel_diff)}")
print("\nGrad (autograd) sample:\n", g1[0, :3, :3])
print("\nGrad (implicit) sample:\n", g2[0, :3, :3])
assert max(MAE) < 1e-7, f"Intolerable difference: MAE = {max(MAE)}"
if __name__ == "__main__":
main()