SANA-Denoiser - Our Architecture Design

Repurposing [[SANA]] 1.6B DiT as an image restoration model. Combines efficient linear attention with [[Paired Training for Restoration]] and [[Temporal Tiling]] via [[Block Causal Linear Attention]].

Why SANA for Restoration

| Property | SANA 1.6B | Step1X-Edit (RealRestorer) | FLUX-dev |
|---|---|---|---|
| Params | 1.6B | ~15B | 12B |
| Attention | Linear O(N) | Quadratic O(N^2) | Quadratic O(N^2) |
| VAE compression | 32x ([[DC-AE]]) | 8x | 8x |
| Tokens at 1024px | 1024 | 16384 | 4096 |
| Tokens at 4K | 16384 | 262144 (!) | 65536 |
| Speed (1024px) | 1.2s | ~15s | 23s |

SANA is roughly 10x smaller, uses 4x fewer tokens at 1024px, and its attention cost grows linearly with sequence length. For restoration, where high-resolution processing is the whole point, this is decisive.

Architecture Changes (Minimal)

1. Input Conditioning: Channel Concat

degraded → DC-AE.encode → condition_latents [B, 32, H, W]
target   → DC-AE.encode → latents           [B, 32, H, W]

x_t = (1-σ)*noise + σ*latents               [B, 32, H, W]
model_input = concat([x_t, condition_latents], dim=1)  [B, 64, H, W]

projection = Conv2d(64, 32, 1)               # 1x1 conv, ~2K params
# Identity init for noise channels, zero init for condition channels
# At step 0: model = pretrained T2I behavior
# Condition signal learned gradually during fine-tuning

model(projection(model_input), timestep, text_embeddings)

Total new parameters: 2,048 (64 x 32 x 1 x 1 conv kernel, bias-free). Compare: ControlNet adds ~800M.
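The identity/zero initialization above can be sketched in a few lines. This is a numpy stand-in for the `Conv2d(64, 32, 1)` layer (a 1x1 conv is just a per-pixel channel matmul); all names and shapes here are illustrative, not taken from the codebase:

```python
import numpy as np

B, C, H, W = 2, 32, 8, 8
rng = np.random.default_rng(0)
x_t = rng.standard_normal((B, C, H, W))        # noisy target latents
condition = rng.standard_normal((B, C, H, W))  # degraded-image latents

# Weight layout [out=32, in=64]: identity over the noise channels,
# zeros over the condition channels.
W_proj = np.concatenate([np.eye(C), np.zeros((C, C))], axis=1)

model_input = np.concatenate([x_t, condition], axis=1)       # [B, 64, H, W]
projected = np.einsum("oc,bchw->bohw", W_proj, model_input)  # [B, 32, H, W]

# At step 0 the projection passes x_t through untouched, so the pretrained
# T2I model sees exactly the input distribution it was trained on.
assert np.allclose(projected, x_t)
```

The condition half of the kernel then moves away from zero during fine-tuning, letting the degraded-image signal in gradually.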

2. Text Conditioning for Degradation Type

Prompt describes what to restore:

  • "Remove gaussian noise, restore sharp details"
  • "Remove JPEG compression artifacts"
  • "Enhance this low-light image"
  • "Clean and restore this image"

Leverages SANA's Gemma-2-2B text encoder for degradation-type understanding.

3. Temporal Tiling for High-Resolution

For images > training resolution (e.g., 4K product photos):

4096x4096 image
  ↓ split into overlapping 1024px tiles (raster scan)
  ↓ each tile: DC-AE encode → 32x32x32 latent
  ↓ denoise with Block Causal Linear Attention
  ↓    (running sum S, Z from previous tiles = global context)
  ↓ stitch latents with linear blending in overlap
  ↓ DC-AE decode full stitched latent
4096x4096 restored image

Memory: constant O(D^2) cache + one tile latent. Processes any resolution.
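The running-sum trick from the diagram (S, Z carried across tiles) can be verified in miniature. Below is a hedged numpy sketch, not the actual `temporal_tiling.py` code: single head, a simple positive feature map, and tiny shapes. It checks that the constant-memory streaming pass matches a materialized block-causal attention mask:

```python
import numpy as np

def phi(x):
    # Positive feature map (ELU+1 style), a common choice in linear attention.
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(1)
D, tile, n_tiles = 4, 8, 3
Q = rng.standard_normal((n_tiles * tile, D))
K = rng.standard_normal((n_tiles * tile, D))
V = rng.standard_normal((n_tiles * tile, D))

# Streaming pass: S = running sum of phi(k)^T v, Z = running sum of phi(k).
# Each tile attends to itself plus everything accumulated so far.
S, Z, outs = np.zeros((D, D)), np.zeros(D), []
for t in range(n_tiles):
    sl = slice(t * tile, (t + 1) * tile)
    qt, kt, vt = phi(Q[sl]), phi(K[sl]), V[sl]
    S += kt.T @ vt
    Z += kt.sum(axis=0)
    outs.append((qt @ S) / (qt @ Z)[:, None])
streamed = np.concatenate(outs)

# Reference: full block-causal mask (tile i sees tiles 0..i), materialized.
qf, kf = phi(Q), phi(K)
blocks = np.repeat(np.arange(n_tiles), tile)
mask = blocks[:, None] >= blocks[None, :]
attn = (qf @ kf.T) * mask
full = (attn @ V) / attn.sum(axis=1, keepdims=True)

assert np.allclose(streamed, full)  # O(D^2) state reproduces the full mask
```

This is why the cache stays O(D^2) regardless of how many tiles the image produces: only S and Z persist between tiles.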

Training Strategy

Phase 1: LoRA (fast iteration)

  • Rank 32, target: attn.to_q/k/v/out + input projection conv
  • 512px, 10K steps, DIV2K + Flickr2K synthetic degradation
  • Evaluate: does it learn to denoise at all?
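Why LoRA is safe for Phase 1: with the standard zero-initialized up-projection, the adapted model starts exactly at pretrained behavior, mirroring the identity/zero-init trick used for the input projection. A minimal numpy sketch of one rank-32 adapted linear layer (names and shapes illustrative, not from the training code):

```python
import numpy as np

rng = np.random.default_rng(2)
d_in, d_out, rank, alpha = 64, 64, 32, 32

W = rng.standard_normal((d_out, d_in))        # frozen pretrained weight
A = rng.standard_normal((rank, d_in)) * 0.01  # trainable down-projection
B = np.zeros((d_out, rank))                   # zero-init up-projection

def lora_forward(x):
    # W stays frozen; only the low-rank update (alpha/r) * B @ A trains.
    return x @ W.T + (alpha / rank) * (x @ A.T) @ B.T

x = rng.standard_normal((4, d_in))
# Zero-init B => adapted layer == pretrained layer at step 0.
assert np.allclose(lora_forward(x), x @ W.T)
```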

Phase 2: Full Fine-Tune (if LoRA insufficient)

  • Unfreeze all transformer params + projection
  • VAE stays frozen
  • Gradient checkpointing for memory
  • Curriculum: 512px → 1024px

Phase 3: Temporal Tiling (inference-only first)

  • No retraining needed: causal masking falls out of linear attention's running-state formulation
  • Just implement the tile loop + S, Z accumulation
  • If quality insufficient: fine-tune with multi-tile samples

Dataset

Source: DIV2K (800) + Flickr2K (2650) = 3450 clean images
Degradation: 5-8 variants per image = 17K-28K pairs

| Degradation | Params | Prompt |
|---|---|---|
| Gaussian noise | σ = 10, 15, 25, 35, 50 | "Remove gaussian noise sigma {σ}" |
| JPEG | q = 15, 25, 40 | "Remove JPEG artifacts quality {q}" |
| Blur | k = 3, 5, 7, 9 | "Remove blur, restore sharpness" |
| Downscale | 2x, 3x, 4x | "Upscale and restore details" |
| Combined | 2-3 random | "Restore this degraded image" |
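The Gaussian-noise row of the table maps to a simple pair generator. A hedged sketch; the function name and API are illustrative, not the actual prepare_dataset.py interface:

```python
import numpy as np

def degrade_gaussian(clean, sigma, rng):
    """Return (degraded image, prompt) for one clean image in [0, 1]."""
    noisy = clean + rng.standard_normal(clean.shape) * (sigma / 255.0)
    prompt = f"Remove gaussian noise sigma {sigma}"
    return np.clip(noisy, 0.0, 1.0), prompt

rng = np.random.default_rng(3)
clean = rng.uniform(size=(256, 256, 3))  # stand-in for a DIV2K crop
# One pair per sigma in the table => 5 variants from a single clean image.
pairs = [degrade_gaussian(clean, s, rng) for s in (10, 15, 25, 35, 50)]

degraded, prompt = pairs[2]
assert prompt == "Remove gaussian noise sigma 25"
```

The other rows (JPEG, blur, downscale, combined) follow the same pattern: degrade, attach the matching prompt, store the (degraded, clean, prompt) triple.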

Evaluation Targets

| Benchmark | Metric | Target | SOTA Reference |
|---|---|---|---|
| SIDD val | PSNR | > 38 dB | NAFNet: 40.3 |
| SIDD val | SSIM | > 0.95 | NAFNet: 0.96 |
| DIV2K (σ=25) | PSNR | > 30 dB | SwinIR: 30.9 |
| Urban100 (σ=25) | PSNR | > 29 dB | SwinIR: 29.5 |
| Temporal tiling | Seam PSNR | > 40 dB | MultiDiffusion baseline |
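The seam-PSNR target can be made concrete: stitch two overlapping tiles with the linear blending described earlier, then measure PSNR only on the overlap strip against ground truth. A numpy sketch with made-up sizes and a tiny simulated per-tile error (everything here is illustrative, not eval/benchmark.py):

```python
import numpy as np

def psnr(a, b, peak=1.0):
    mse = np.mean((a - b) ** 2)
    return 10.0 * np.log10(peak ** 2 / mse)

rng = np.random.default_rng(4)
gt = rng.uniform(size=(64, 96))                  # ground-truth "image"
left = gt[:, :64] + 1e-3 * rng.standard_normal((64, 64))   # tile restorations
right = gt[:, 32:] + 1e-3 * rng.standard_normal((64, 64))  # with small error

# Linear blend across the 32-px overlap: weight ramps 1 -> 0 for the left tile.
w = np.linspace(1.0, 0.0, 32)
stitched = np.empty_like(gt)
stitched[:, :32] = left[:, :32]
stitched[:, 64:] = right[:, 32:]
stitched[:, 32:64] = left[:, 32:] * w + right[:, :32] * (1.0 - w)

seam = psnr(stitched[:, 32:64], gt[:, 32:64])
```

With per-tile errors this small the seam PSNR sits far above the 40 dB target; the real benchmark runs the same measurement on actual tile outputs.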

Project Files

happyin-research/
├── sana-fm/
│   ├── data/paired_dataset.py      ← paired loader
│   ├── data/degradation.py         ← degradation functions
│   ├── configs/img2img_denoise.yaml
│   └── train_flowmatching.py       ← modified compute_loss
├── sana-denoiser/
│   ├── prepare_dataset.py          ← DIV2K + Flickr2K + degradations
│   ├── train.py                    ← wrapper
│   ├── temporal_tiling.py          ← tile-as-sequence inference
│   └── eval/benchmark.py           ← vs SwinIR, NAFNet

Risk Assessment

| Risk | Likelihood | Mitigation |
|---|---|---|
| DC-AE 32x compression loses fine details | Medium | Compare DC-AE reconstruction vs 8x VAE on jewelry textures |
| Linear attention insufficient for restoration | Low | SANA matches quadratic models on generation; restoration is simpler |
| Temporal tiling adds latency | High | Acceptable: quality > speed for product photography |
| 1.6B too small for complex degradations | Medium | Scale to 4.8B if needed; depth-pruning from 4.8B as fallback |