# Block Causal Linear Attention
Temporal extension of [[SANA]]'s linear attention to sequential processing (video frames or image tiles). Enables constant O(D^2) attention-state memory regardless of sequence length.
## How Standard Linear Attention Works
SANA uses ReLU kernel linear attention:
```
O_i = phi(Q_i) · S / (phi(Q_i) · Z)

where:
  S = sum_j phi(K_j)^T V_j   # running KV sum, shape [D, D]
  Z = sum_j phi(K_j)^T       # normalizer,     shape [D, 1]
  phi(x) = ReLU(x)
```
Key insight: S and Z are cumulative sums shared across all queries. Computing them once costs O(N D^2) instead of the O(N^2 D) of full softmax attention.
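The equivalence of the two orderings can be checked numerically. A minimal sketch in numpy (toy sizes; `phi` is the ReLU kernel from above) computes attention both ways, quadratic and linear, and confirms they match:

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 64, 16                       # sequence length, head dim (toy sizes)
Q, K, V = (rng.standard_normal((N, D)) for _ in range(3))
phi = lambda x: np.maximum(x, 0.0)  # ReLU kernel

# Quadratic form: O(N^2 D) -- materializes the full attention matrix
A = phi(Q) @ phi(K).T               # [N, N]
out_quad = (A @ V) / A.sum(axis=1, keepdims=True)

# Linear form: O(N D^2) -- S and Z computed once, shared by all queries
S = phi(K).T @ V                         # [D, D]
Z = phi(K).T.sum(axis=1, keepdims=True)  # [D, 1]
out_lin = (phi(Q) @ S) / (phi(Q) @ Z)

assert np.allclose(out_quad, out_lin)
```

The linear form never builds the N x N matrix, which is exactly what makes constant-state streaming over tiles possible.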
## Extension to Causal (Temporal) Processing
For video or tile sequences, enforce causality: frame/tile t may attend only to frames/tiles 0..t:
```
For tile t:
  S_t = S_{t-1} + sum_{j in tile_t} phi(K_j)^T V_j
  Z_t = Z_{t-1} + sum_{j in tile_t} phi(K_j)^T
  O_i = phi(Q_i) · S_t / (phi(Q_i) · Z_t)   # for queries i in tile t
```
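The recurrence above can be verified against a reference that lets tile t attend over the concatenation of tiles 0..t. A sketch with toy sizes (tile boundaries and shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
T, L, D = 4, 8, 16                  # tiles, tokens per tile, head dim
phi = lambda x: np.maximum(x, 0.0)  # ReLU kernel

Qs = [rng.standard_normal((L, D)) for _ in range(T)]
Ks = [rng.standard_normal((L, D)) for _ in range(T)]
Vs = [rng.standard_normal((L, D)) for _ in range(T)]

# Recurrent form: carry S, Z across tiles; each tile also sees itself
S = np.zeros((D, D))
Z = np.zeros((D, 1))
outs = []
for Q, K, V in zip(Qs, Ks, Vs):
    S = S + phi(K).T @ V
    Z = Z + phi(K).T.sum(axis=1, keepdims=True)
    outs.append((phi(Q) @ S) / (phi(Q) @ Z))

# Reference: tile t attends over tiles 0..t explicitly
for t in range(T):
    Kc = np.concatenate(Ks[: t + 1])
    Vc = np.concatenate(Vs[: t + 1])
    A = phi(Qs[t]) @ phi(Kc).T
    ref = (A @ Vc) / A.sum(axis=1, keepdims=True)
    assert np.allclose(outs[t], ref)
```

Only S and Z cross tile boundaries; per-token keys and values from earlier tiles are never revisited.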
## Memory: O(D^2) Constant
- S is a D x D matrix (2240 x 2240 ≈ 5M entries for SANA 1.6B's hidden dim)
- Z is a D x 1 vector
- Does NOT grow with number of tiles/frames
- Compare: a standard KV cache is O(N x D) and grows linearly with sequence length
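The gap is easy to quantify. A back-of-envelope sketch in fp16, using the full hidden dim for illustration (real implementations keep per-head, per-layer states, and the 16x1024 token count is assumed, not from the source):

```python
D = 2240                 # SANA 1.6B hidden dim
bytes_fp16 = 2

# Causal state: S [D, D] plus Z [D, 1] -- independent of tile count
state = (D * D + D) * bytes_fp16
print(f"causal state: {state / 2**20:.1f} MiB")  # -> causal state: 9.6 MiB

# Standard KV cache for 16 tiles of 1024 tokens each (illustrative)
N = 16 * 1024
kv = 2 * N * D * bytes_fp16                      # K and V tensors
print(f"KV cache:     {kv / 2**20:.1f} MiB")     # -> KV cache:     140.0 MiB
```

The KV cache keeps growing with every tile; the causal state does not.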
## Application to [[Temporal Tiling]]
Instead of video frames, treat image tiles as the temporal sequence:
- Raster-scan order - top-left to bottom-right
- Each tile: encode via [[DC-AE]] → 32x32x32 latent (for 1024px tile)
- Process tile through SANA DiT with causal linear attention
- Update running sums S and Z
- Next tile inherits global context from all previous tiles
## Overlap Handling
For overlapping tiles:
- Tokens in the overlap zone appear in both tiles' latents
- Use position encoding to distinguish them: RoPE over (tile_row, tile_col, local_h, local_w)
- Blend overlap latents with linear weights before VAE decode
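The linear-weight blend in the last step can be sketched in 1D as a cross-fade over the shared columns (`blend_1d` is a hypothetical helper; real tiling blends along both spatial axes):

```python
import numpy as np

def blend_1d(left: np.ndarray, right: np.ndarray, overlap: int) -> np.ndarray:
    """Stitch two latent strips that share `overlap` columns, cross-fading
    linearly over the shared region (hypothetical helper, 1D for clarity)."""
    ramp = np.linspace(0.0, 1.0, overlap)  # 0 -> keep left, 1 -> keep right
    mixed = left[..., -overlap:] * (1.0 - ramp) + right[..., :overlap] * ramp
    return np.concatenate(
        [left[..., :-overlap], mixed, right[..., overlap:]], axis=-1
    )

a = np.ones((4, 32))         # tile A latent strip
b = np.zeros((4, 32))        # tile B latent strip
out = blend_1d(a, b, overlap=8)
assert out.shape == (4, 56)  # 32 + 32 - 8 shared columns
```

The linear ramp guarantees the two tiles agree exactly at the seam endpoints, which is what suppresses visible boundaries on smooth gradients.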
## From SANA-Video Implementation
SANA-Video specs:
- 2B params, 720p, up to 1 minute, 16 FPS
- 36 s latency for a 5 s 720p clip (vs 1897 s for Wan-2.1-14B = 52x faster)
- Causal Mix-FFN: caches the last frame to keep the temporal convolution causal
- LTX-VAE for video encoding
The same mechanism applied to tiles enables:
- Processing a 4K image (4096x4096) as 16 tiles of 1024x1024
- Each tile sees context from all previous tiles
- No seam artifacts on smooth gradients (metal, skin)
- Memory: same as a single tile plus the small S, Z cache
## Comparison with Other Temporal Approaches
| Approach | Memory | Quality | Speed |
|---|---|---|---|
| Independent tiles + overlap avg | O(1) | Poor (seams) | Fast (parallel) |
| Full self-attention across tiles | O(N^2) | Best | Very slow |
| AnimateDiff temporal modules | O(N) KV cache | Good | Medium |
| Block Causal Linear Attention | O(D^2) constant | Good | Fast (linear) |
| SyncDiffusion (gradient sync) | O(1) + grad cost | Good | Slow |
## Implementation Notes
```python
# Pseudocode for causal tile processing (phi = ReLU feature map)
S = torch.zeros(D, D)  # running sum of phi(K)^T V
Z = torch.zeros(D, 1)  # running sum of phi(K)^T
denoised_tiles = []

for tile in raster_scan(image):
    tile_latent = dc_ae.encode(tile)  # [1, 32, 32, 32]
    noise = torch.randn_like(tile_latent)

    # Denoise with causal context from all previous tiles
    for step in scheduler.timesteps:
        x_t = scheduler.add_noise(tile_latent, noise, step)
        # Linear attention reads S, Z as the accumulated context
        pred = model(x_t, step, text, causal_state=(S, Z))

    # After the tile is denoised, fold its keys/values into the running
    # sums (K_tile, V_tile taken from the model's attention layers)
    S += phi(K_tile).T @ V_tile
    Z += phi(K_tile).T.sum(dim=-1, keepdim=True)
    denoised_tiles.append(pred)

# Stitch and decode
full_latent = stitch_with_blending(denoised_tiles)
output = dc_ae.decode(full_latent)
```