SANA-Denoiser - Our Architecture Design
This design repurposes the [[SANA]] 1.6B DiT as an image restoration model. It combines SANA's efficient linear attention with [[Paired Training for Restoration]] and [[Temporal Tiling]] via [[Block Causal Linear Attention]].
Why SANA for Restoration
| Property | SANA 1.6B | Step1X-Edit (RealRestorer) | FLUX-dev |
|---|---|---|---|
| Params | 1.6B | ~15B | 12B |
| Attention | Linear O(N) | Quadratic O(N^2) | Quadratic O(N^2) |
| VAE compression | 32x ([[DC-AE]]) | 8x | 8x |
| Tokens at 1024px | 1024 | 16384 | 4096 |
| Tokens at 4K | 16384 | 262144 (!) | 65536 |
| Speed (1024px) | 1.2s | ~15s | 23s |
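SANA has roughly 7.5x to 9x fewer parameters than the alternatives, produces 4x to 16x fewer tokens at the same resolution, and its attention cost grows linearly with token count. For restoration, where high-resolution processing is required, this is decisive.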
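The token counts in the table can be reproduced from the image side length and an effective downsampling factor per model. The sketch below assumes one token per latent position for SANA (32x) and Step1X-Edit (8x), and an extra 2x2 patching step for FLUX-dev (16x effective). These factors are inferred from the table rather than taken from model code, and `token_count` is a hypothetical helper.

```python
# Effective image-to-token downsampling per side (assumptions inferred from the table).
EFFECTIVE_DOWNSAMPLE = {
    "SANA 1.6B": 32,    # DC-AE 32x, one token per latent position
    "Step1X-Edit": 8,   # 8x VAE, one token per latent position (assumed)
    "FLUX-dev": 16,     # 8x VAE followed by 2x2 patching (assumed)
}


def token_count(image_px: int, downsample: int) -> int:
    """Number of tokens for a square image of side `image_px`."""
    side = image_px // downsample
    return side * side


for name, factor in EFFECTIVE_DOWNSAMPLE.items():
    print(name, token_count(1024, factor), token_count(4096, factor))
```

Because token count scales with the square of the side length, a 4x larger side gives 16x more tokens for every model; the linear-attention cost then grows 16x, while quadratic attention grows 256x.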
Architecture Changes (Minimal)
1. Input Conditioning: Channel Concat
degraded → DC-AE.encode → condition_latents [B, 32, H, W]
target → DC-AE.encode → latents [B, 32, H, W]
x_t = (1-σ)*noise + σ*latents [B, 32, H, W]
model_input = concat([x_t, condition_latents], dim=1) [B, 64, H, W]
projection = Conv2d(64, 32, 1) # 1x1 conv, ~2K params
# Identity init for noise channels, zero init for condition channels
# At step 0: model = pretrained T2I behavior
# Condition signal learned gradually during fine-tuning
model(projection(model_input), timestep, text_embeddings)
Total new parameters: 2,048 weights (64 x 32 x 1 x 1 conv kernel), plus 32 biases if the conv keeps its bias. Of these, only the 1,024 weights on the condition channels start at zero; the rest form an identity map. Compare: ControlNet = ~800M.
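At a single spatial location, a 1x1 convolution is a matrix-vector product over channels, so the initialization described above can be checked without a deep-learning framework. The following is a minimal sketch in plain Python with small hypothetical channel counts (4 + 4 instead of 32 + 32). `init_projection`, `project_pixel`, and `noisy_latent` are illustrative names, not functions from the SANA codebase.

```python
def init_projection(noise_ch: int, cond_ch: int):
    """Weights [out][in] for a 1x1 conv over concat([x_t, condition]).

    Identity on the noisy-latent channels, zeros on the condition channels.
    """
    in_ch = noise_ch + cond_ch
    weight = [[1.0 if i == o else 0.0 for i in range(in_ch)] for o in range(noise_ch)]
    bias = [0.0] * noise_ch
    return weight, bias


def project_pixel(weight, bias, channels):
    """Apply the 1x1 conv at one spatial location (matrix-vector product)."""
    return [sum(w * c for w, c in zip(row, channels)) + b for row, b in zip(weight, bias)]


def noisy_latent(noise, latent, sigma: float):
    """Interpolation used in this note: x_t = (1 - sigma) * noise + sigma * latents."""
    return [(1.0 - sigma) * n + sigma * z for n, z in zip(noise, latent)]


# Toy per-pixel channel vectors (hypothetical values).
noise = [0.5, -1.0, 0.25, 2.0]
latent = [1.0, 0.0, -0.5, 0.75]
cond = [3.0, -2.0, 1.5, 0.1]

x_t = noisy_latent(noise, latent, sigma=0.3)
weight, bias = init_projection(noise_ch=4, cond_ch=4)
projected = project_pixel(weight, bias, x_t + cond)  # list concat = channel concat
```

With this initialization `projected` equals `x_t` regardless of `cond`, so the first forward pass matches the pretrained model and the condition pathway only becomes active as its weights move away from zero.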
2. Text Conditioning for Degradation Type
The prompt describes what to restore:

- "Remove gaussian noise, restore sharp details"
- "Remove JPEG compression artifacts"
- "Enhance this low-light image"
- "Clean and restore this image"

Leverages SANA's Gemma-2-2B text encoder for degradation-type understanding.
3. Temporal Tiling for High-Resolution
For images > training resolution (e.g., 4K product photos):
4096x4096 image
↓ split into overlapping 1024px tiles (raster scan)
↓ each tile: DC-AE encode → 32x32x32 latent
↓ denoise with Block Causal Linear Attention
↓ (running sum S, Z from previous tiles = global context)
↓ stitch latents with linear blending in overlap
↓ DC-AE decode full stitched latent
4096x4096 restored image
Memory: a constant O(D^2) cache (the running sums S and Z) plus one tile latent, independent of the number of tiles, so any resolution can be processed.
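The split and stitch steps can be illustrated in one dimension; the 2D case applies the same weights along each axis. This is a sketch under assumptions: the overlap width (4 latent positions here) and the ramp-shaped, normalized blending weights are illustrative choices, and `tile_starts`, `blend_weights`, and `stitch_1d` are hypothetical helpers rather than the contents of `temporal_tiling.py`.

```python
def tile_starts(length: int, tile: int, overlap: int):
    """Start offsets of overlapping tiles that together cover [0, length)."""
    if not 0 <= overlap < tile:
        raise ValueError("overlap must satisfy 0 <= overlap < tile")
    if length <= tile:
        return [0]
    stride = tile - overlap
    starts = list(range(0, length - tile, stride))
    starts.append(length - tile)  # last tile sits flush with the far edge
    return starts


def blend_weights(tile: int, overlap: int):
    """Weights that ramp linearly over `overlap` positions at both tile edges."""
    ramp = max(overlap, 1)
    return [min(i + 1, tile - i, ramp) / ramp for i in range(tile)]


def stitch_1d(length: int, tiles, starts, overlap: int):
    """Weighted average of tile values; weights are normalized per position."""
    acc = [0.0] * length
    wsum = [0.0] * length
    for values, start in zip(tiles, starts):
        weights = blend_weights(len(values), overlap)
        for i, (val, w) in enumerate(zip(values, weights)):
            acc[start + i] += w * val
            wsum[start + i] += w
    return [a / w for a, w in zip(acc, wsum)]


# Toy example: an "identity restoration" must stitch back to the input.
length, tile, overlap = 40, 16, 4
signal = [float(i) for i in range(length)]
starts = tile_starts(length, tile, overlap)
tiles = [signal[s:s + tile] for s in starts]
stitched = stitch_1d(length, tiles, starts, overlap)
```

Normalizing by the accumulated weight keeps the result well defined even when the final tile overlaps its neighbor by more than the nominal overlap.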
Training Strategy
Phase 1: LoRA (fast iteration)
- Rank 32, target: attn.to_q/k/v/out + input projection conv
- 512px, 10K steps, DIV2K + Flickr2K synthetic degradation
- Evaluate: does it learn to denoise at all?
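
To get a feel for how small the Phase 1 update is, the number of trainable parameters LoRA adds to one linear layer is `rank * (d_in + d_out)`. The sketch below uses a hypothetical hidden size of 2048 purely for illustration (it is not the SANA width), and `lora_param_count` is an illustrative helper.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the LoRA factors A [rank, d_in] and B [d_out, rank]."""
    return rank * (d_in + d_out)


HIDDEN = 2048  # hypothetical hidden size, for illustration only
RANK = 32      # rank from the Phase 1 plan

# Four targeted projections per attention block: to_q, to_k, to_v, out.
per_block = sum(
    lora_param_count(HIDDEN, HIDDEN, RANK) for _ in ("to_q", "to_k", "to_v", "out")
)
print(per_block)
```

In the standard LoRA setup the B factor starts at zero, so, like the projection conv, the adapted model initially behaves exactly like the pretrained one.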
Phase 2: Full Fine-Tune (if LoRA insufficient)
- Unfreeze all transformer params + projection
- VAE stays frozen
- Gradient checkpointing for memory
- Curriculum: 512px → 1024px
Phase 3: Temporal Tiling (inference-only first)
- No retraining needed at first: linear attention has a causal form based on running sums, so the pretrained weights can be reused as-is
- Just implement the tile loop + S, Z accumulation
- If quality insufficient: fine-tune with multi-tile samples
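
The tile loop with S, Z accumulation can be sketched as follows. This is a single-head, plain-Python illustration under stated assumptions: a ReLU feature map is assumed for the linear attention kernel, tokens inside a tile attend to each other and to all earlier tiles, and `block_causal_linear_attention` and `random_tile` are hypothetical names. It is not the project's implementation.

```python
import random


def relu(vec):
    return [max(0.0, x) for x in vec]


def block_causal_linear_attention(tiles, dim: int, eps: float = 1e-6):
    """tiles: list of (queries, keys, values), each a list of `dim`-vectors.

    Returns one output list per tile. Only S (dim x dim) and Z (dim) are
    carried between tiles, so the cache size does not depend on tile count.
    """
    S = [[0.0] * dim for _ in range(dim)]  # running sum of phi(k) v^T
    Z = [0.0] * dim                        # running sum of phi(k)
    outputs = []
    for queries, keys, values in tiles:
        # Fold the current tile's keys and values into the running sums.
        for k, v in zip(keys, values):
            phi_k = relu(k)
            for a in range(dim):
                Z[a] += phi_k[a]
                for b in range(dim):
                    S[a][b] += phi_k[a] * v[b]
        # Queries see the current tile plus everything before it.
        tile_out = []
        for q in queries:
            phi_q = relu(q)
            denom = sum(phi_q[a] * Z[a] for a in range(dim)) + eps
            tile_out.append(
                [sum(phi_q[a] * S[a][b] for a in range(dim)) / denom for b in range(dim)]
            )
        outputs.append(tile_out)
    return outputs


def random_tile(rng, tokens: int, dim: int):
    """Random (queries, keys, values) for a toy tile."""
    def make():
        return [[rng.uniform(-1.0, 1.0) for _ in range(dim)] for _ in range(tokens)]
    return make(), make(), make()


rng = random.Random(0)
tiles = [random_tile(rng, tokens=6, dim=4) for _ in range(3)]
outputs = block_causal_linear_attention(tiles, dim=4)
```

The first tile sees only itself, which matches ordinary linear attention on that tile; later tiles additionally receive global context through S and Z. Whether pretrained bidirectional weights tolerate this asymmetry is the open question that the fine-tuning fallback addresses.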
Dataset
Source: DIV2K (800) + Flickr2K (2650) = 3450 clean images

Degradation: 5-8 variants per image = 17K-28K pairs
| Degradation | Params | Prompt |
|---|---|---|
| Gaussian noise | σ=10,15,25,35,50 | "Remove gaussian noise sigma {σ}" |
| JPEG | q=15,25,40 | "Remove JPEG artifacts quality {q}" |
| Blur | k=3,5,7,9 | "Remove blur, restore sharpness" |
| Downscale | 2x,3x,4x | "Upscale and restore details" |
| Combined | 2-3 random | "Restore this degraded image" |
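
The table maps naturally onto a small sampling routine that picks a degradation, a parameter value, and the matching prompt. The sketch below is an assumption-laden illustration: `DEGRADATIONS`, `sample_degradation`, and the 20% probability of a combined degradation are hypothetical, the `{σ}` placeholder is spelled `{sigma}`, and the image-processing functions themselves are left out.

```python
import random

# One entry per single-degradation row of the table.
DEGRADATIONS = {
    "gaussian_noise": {"param": "sigma", "values": [10, 15, 25, 35, 50],
                       "prompt": "Remove gaussian noise sigma {sigma}"},
    "jpeg": {"param": "q", "values": [15, 25, 40],
             "prompt": "Remove JPEG artifacts quality {q}"},
    "blur": {"param": "k", "values": [3, 5, 7, 9],
             "prompt": "Remove blur, restore sharpness"},
    "downscale": {"param": "factor", "values": [2, 3, 4],
                  "prompt": "Upscale and restore details"},
}
COMBINED_PROMPT = "Restore this degraded image"


def sample_degradation(rng: random.Random, combined_prob: float = 0.2):
    """Return (steps, prompt); steps is a list of (name, parameter value)."""
    if rng.random() < combined_prob:
        names = rng.sample(sorted(DEGRADATIONS), rng.choice([2, 3]))
        steps = [(n, rng.choice(DEGRADATIONS[n]["values"])) for n in names]
        return steps, COMBINED_PROMPT
    name = rng.choice(sorted(DEGRADATIONS))
    spec = DEGRADATIONS[name]
    value = rng.choice(spec["values"])
    return [(name, value)], spec["prompt"].format(**{spec["param"]: value})


rng = random.Random(0)
steps, prompt = sample_degradation(rng)

# Dataset size implied by 5-8 variants per clean image.
CLEAN_IMAGES = 800 + 2650
pair_range = (CLEAN_IMAGES * 5, CLEAN_IMAGES * 8)
```

Keeping the parameter value in the prompt for noise and JPEG lets the text encoder carry severity information, while blur and downscale use fixed prompts as in the table.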
Evaluation Targets
| Benchmark | Metric | Target | SOTA Reference |
|---|---|---|---|
| SIDD val | PSNR | > 38 dB | NAFNet: 40.3 |
| SIDD val | SSIM | > 0.95 | NAFNet: 0.96 |
| DIV2K (σ=25) | PSNR | > 30 dB | SwinIR: 30.9 |
| Urban100 (σ=25) | PSNR | > 29 dB | SwinIR: 29.5 |
| Temporal tiling | Seam PSNR | > 40 dB | MultiDiffusion baseline |
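
For reference, PSNR is `10 * log10(MAX^2 / MSE)`. The sketch below computes it over flattened pixel lists and adds an optional mask. "Seam PSNR" is interpreted here, as an assumption, as PSNR restricted to a band of pixels around tile boundaries; `psnr` and `seam_mask` are hypothetical helpers, not the contents of `eval/benchmark.py`.

```python
import math


def psnr(reference, restored, max_val: float = 255.0, mask=None) -> float:
    """PSNR in dB over flattened pixel lists; `mask` selects pixels to score."""
    if mask is None:
        mask = [True] * len(reference)
    sq_errors = [(r - x) ** 2 for r, x, m in zip(reference, restored, mask) if m]
    if not sq_errors:
        raise ValueError("mask selects no pixels")
    mse = sum(sq_errors) / len(sq_errors)
    return math.inf if mse == 0 else 10.0 * math.log10(max_val ** 2 / mse)


def seam_mask(length: int, seam_positions, half_width: int):
    """True for positions within `half_width` of any seam (1D illustration)."""
    return [any(abs(i - s) <= half_width for s in seam_positions) for i in range(length)]


# Toy example: a single error located exactly on a seam at index 4.
reference = [10.0] * 8
restored = [10.0, 10.0, 10.0, 10.0, 11.0, 10.0, 10.0, 10.0]
full = psnr(reference, restored)
seam = psnr(reference, restored, mask=seam_mask(8, [4], half_width=0))
```

Restricting the metric to the seam band keeps stitching artifacts from being averaged away by the much larger seam-free area.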
Project Files
happyin-research/
├── sana-fm/
│ ├── data/paired_dataset.py ← paired loader
│ ├── data/degradation.py ← degradation functions
│ ├── configs/img2img_denoise.yaml
│ └── train_flowmatching.py ← modified compute_loss
├── sana-denoiser/
│ ├── prepare_dataset.py ← DIV2K + Flickr2K + degradations
│ ├── train.py ← wrapper
│ ├── temporal_tiling.py ← tile-as-sequence inference
│ └── eval/benchmark.py ← vs SwinIR, NAFNet
Risk Assessment
| Risk | Likelihood | Mitigation |
|---|---|---|
| DC-AE 32x compression loses fine details | Medium | Compare DC-AE reconstruction vs 8x VAE on jewelry textures |
| Linear attention insufficient for restoration | Low | SANA matches quadratic models on generation; restoration is more tightly constrained by the input image, so it is expected to be the simpler task |
| Temporal tiling adds latency | High | Acceptable: quality > speed for product photography |
| 1.6B too small for complex degradations | Medium | Scale to 4.8B if needed; depth-pruning from 4.8B as fallback |