AuroraScope: Sparse Autoencoders for the Aurora Air Pollution Model
AuroraScope is a suite of 48 sparse autoencoders (SAEs), one trained on each transformer block of the Aurora Air Pollution foundation model's Swin3D backbone. The SAEs decompose the model's internal representations into sparse features intended to be interpretable and monosemantic, enabling mechanistic interpretability research on a state-of-the-art weather and air quality forecasting system.
Paper: Forthcoming
Code: github.com/jasonyhu/mirora
Authors: Jason Hu, Makoto Kelp, Iván Higuera-Mendieta, Obin Sturm, Marshall Burke
Model Overview
| Property | Value |
|---|---|
| Base model | microsoft/aurora (Aurora Air Pollution, 1.3B params) |
| Architecture | Sparse autoencoder (ReLU, L1 sparsity penalty) |
| Backbone | Swin3D with 48 transformer blocks |
| SAEs trained | 48 (8x expansion, one per backbone block) |
| Training data | CAMS global forecasts, Jun 2022 -- May 2023 (365 days) |
| Grid resolution | 0.4 deg (451 x 900), 13 pressure levels |
| Total parameters | ~1.46B across all 48 SAEs |
Aurora Backbone Architecture
The Aurora Air Pollution model uses a Swin3D vision transformer backbone organized as an encoder-decoder with skip connections:
          Encoder                          Decoder
┌─────────────────────────┐      ┌─────────────────────────┐
│ Stage 0: 6 blocks       │      │ Stage 0: 8 blocks       │
│ dim = 512               │      │ dim = 2048              │
├─────────────────────────┤      ├─────────────────────────┤
│ Stage 1: 10 blocks      │ ───> │ Stage 1: 10 blocks      │
│ dim = 1024              │ skip │ dim = 1024              │
├─────────────────────────┤      ├─────────────────────────┤
│ Stage 2: 8 blocks       │      │ Stage 2: 6 blocks       │
│ dim = 2048              │      │ dim = 512               │
└─────────────────────────┘      └─────────────────────────┘
Each stage processes the global atmospheric state at a different resolution. The model tokenizes the input into 180,000 spatial patches (150 x 300 patches across 4 latent levels) using a patch size of 3.
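The token count can be checked with a few lines of arithmetic. This is only a sketch: the variable names are illustrative, and it assumes the 451 latitude rows are cropped to 450 so that the grid divides evenly by the patch size.

```python
# Illustrative arithmetic only; names are not taken from the Aurora codebase.
lat_points, lon_points = 451, 900   # 0.4 degree global grid
patch_size = 3
latent_levels = 4

# Assumption: the trailing latitude row is dropped so 450 divides evenly by 3.
lat_patches = lat_points // patch_size   # 150
lon_patches = lon_points // patch_size   # 300

tokens = lat_patches * lon_patches * latent_levels
print(tokens)  # 180000
```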
SAE Architecture
Each SAE is a single-layer autoencoder with ReLU activation:
Input x (d_model) --> subtract b_dec --> W_enc --> ReLU --> h (d_hidden) --> W_dec --> add b_dec --> x_hat
- Encoder: h = ReLU((x - b_dec) @ W_enc + b_enc)
- Decoder: x_hat = h @ W_dec + b_dec
- Decoder columns are unit-normalized after each training step
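The unit-norm constraint on the decoder could be applied as in the sketch below. It assumes the `(d_hidden, d_model)` layout implied by `x_hat = h @ W_dec`, in which each feature's decoder direction is a row of `W_dec`; the helper name `normalize_decoder_` is hypothetical and not part of the released module.

```python
import torch
import torch.nn.functional as F

def normalize_decoder_(W_dec: torch.Tensor) -> torch.Tensor:
    """Rescale each feature's decoder direction to unit L2 norm, in place.

    Assumes W_dec has shape (d_hidden, d_model), so directions are rows.
    """
    with torch.no_grad():
        W_dec.copy_(F.normalize(W_dec, dim=1))
    return W_dec

# Toy decoder for a dim-512 layer at 8x expansion
W_dec = torch.randn(4096, 512)
normalize_decoder_(W_dec)
norms = W_dec.norm(dim=1)
print(norms.min().item(), norms.max().item())  # both close to 1
```

Fixing the decoder norm prevents the L1 penalty from being trivially reduced by shrinking activations while growing decoder weights.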
| Layer dimension | Expansion | Hidden dim | Params per SAE | File size |
|---|---|---|---|---|
| 512 | 8x | 4,096 | 4.2M | 16 MB |
| 1,024 | 8x | 8,192 | 16.8M | 64 MB |
| 2,048 | 8x | 16,384 | 67.1M | 256 MB |
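The per-SAE figures in this table follow from the four parameter tensors (`W_enc`, `b_enc`, `W_dec`, `b_dec`). The sketch below recomputes them and sums over all 48 blocks; the float32 storage used for the file-size estimate is an assumption, and `sae_param_count` is an illustrative helper.

```python
def sae_param_count(d_model: int, expansion: int = 8) -> int:
    """Parameters in one SAE: W_enc + W_dec + b_enc + b_dec."""
    d_hidden = d_model * expansion
    return 2 * d_model * d_hidden + d_hidden + d_model

# Blocks per width: encoder stages (6, 10, 8) plus decoder stages (8, 10, 6)
blocks_per_dim = {512: 6 + 6, 1024: 10 + 10, 2048: 8 + 8}

for d_model, n_blocks in blocks_per_dim.items():
    n = sae_param_count(d_model)
    # Assumption: weights are stored as 4-byte float32 values
    print(f"d_model={d_model}: {n / 1e6:.1f}M params, "
          f"{n * 4 / 2**20:.0f} MiB, {n_blocks} SAEs")

total = sum(sae_param_count(d) * n for d, n in blocks_per_dim.items())
print(f"total: {total / 1e9:.2f}B parameters")
```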
Reconstruction Quality (Fraction of Variance Explained)
Evaluated over 20 held-out CAMS initialization batches:
| Backbone layer | dim | Expansion | FVE (%) |
|---|---|---|---|
| Encoder stage 0 (6 blocks) | 512 | 8x | 99.6 -- 99.8 |
| Encoder stage 1 (10 blocks) | 1024 | 8x | 95.2 -- 99.9 |
| Encoder stage 2 (8 blocks) | 2048 | 8x | 75.8 -- 80.4 |
| Decoder stage 0 (8 blocks) | 2048 | 8x | 76.4 -- 82.3 |
| Decoder stage 1 (10 blocks) | 1024 | 8x | 91.2 -- 96.2 |
| Decoder stage 2 (6 blocks) | 512 | 8x | 99.6 -- 99.7 |
The dim-2048 layers (encoder stage 2, decoder stage 0) sit at the bottleneck of the Swin3D backbone. Their lower FVE at the same 8x expansion suggests that they carry the richest representations.
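The exact FVE formula is not given here. A common definition, assumed in the sketch below, is one minus the ratio of the residual sum of squares to the total sum of squares around the per-dimension mean, pooled over tokens and dimensions.

```python
import torch

def fraction_variance_explained(x: torch.Tensor, x_hat: torch.Tensor) -> float:
    """FVE = 1 - SS_residual / SS_total, pooled over tokens and dimensions.

    x, x_hat: (num_tokens, d_model). Variance is taken around the
    per-dimension mean of x (an assumed convention).
    """
    residual = (x - x_hat).pow(2).sum()
    total = (x - x.mean(dim=0, keepdim=True)).pow(2).sum()
    return (1.0 - residual / total).item()

x = torch.randn(1000, 512)
mean_only = x.mean(dim=0, keepdim=True).expand_as(x)

fve_perfect = fraction_variance_explained(x, x)        # perfect reconstruction
fve_mean = fraction_variance_explained(x, mean_only)   # predicting the mean only
print(fve_perfect, fve_mean)
```

With this convention, a perfect reconstruction scores 1 and a reconstruction that only predicts the mean scores 0. Multiply by 100 to compare with the percentages in the table.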
File Structure
aurorascope/
  README.md                  # This model card
  aurora_sae.py              # Standalone SAE module (no dependencies beyond PyTorch)
  sae_index.json             # Machine-readable index of all SAEs with metadata
  saes/
    sae_encoder_s0_b0.pt     # Encoder stage 0, block 0
    sae_encoder_s0_b1.pt
    ...
    sae_decoder_s0_b0.pt
    sae_decoder_s0_b1.pt
    ...
Naming Convention
sae_{encoder|decoder}_s{stage}_b{block}.pt
- encoder/decoder: which half of the backbone
- s{stage}: stage index (0, 1, or 2)
- b{block}: block index within the stage
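The convention makes it easy to enumerate or parse checkpoint paths. The helpers below are an illustrative sketch, not part of the repository. They take the per-stage block counts from the backbone diagram and the `saes/` prefix from the file structure.

```python
import re

# Blocks per stage for each half of the backbone
DEPTHS = {"encoder": (6, 10, 8), "decoder": (8, 10, 6)}
_PATTERN = re.compile(r"sae_(encoder|decoder)_s(\d+)_b(\d+)\.pt$")

def sae_filename(component: str, stage: int, block: int) -> str:
    """Build the repo-relative path for one SAE checkpoint."""
    if not 0 <= block < DEPTHS[component][stage]:
        raise ValueError(f"{component} stage {stage} has no block {block}")
    return f"saes/sae_{component}_s{stage}_b{block}.pt"

def parse_sae_filename(name: str) -> tuple:
    """Recover (component, stage, block) from a checkpoint path."""
    match = _PATTERN.search(name)
    if match is None:
        raise ValueError(f"not an SAE checkpoint name: {name}")
    return match.group(1), int(match.group(2)), int(match.group(3))

all_files = [
    sae_filename(component, stage, block)
    for component, depths in DEPTHS.items()
    for stage, depth in enumerate(depths)
    for block in range(depth)
]
print(len(all_files))  # 48
```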
Quick Start
Loading a single SAE
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

# The SAE module (also available as aurora_sae.py in this repo)
class AuroraSAE(nn.Module):
    def __init__(self, d_model: int = 2048, expansion_factor: int = 8, l1_coeff: float = 3e-4):
        super().__init__()
        self.d_hidden = d_model * expansion_factor
        self.l1_coeff = l1_coeff
        self.W_enc = nn.Parameter(torch.empty(d_model, self.d_hidden))
        self.b_enc = nn.Parameter(torch.zeros(self.d_hidden))
        self.W_dec = nn.Parameter(torch.empty(self.d_hidden, d_model))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        x_cent = x - self.b_dec
        h = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_hat = h @ self.W_dec + self.b_dec
        return x_hat, h

# Download and load an SAE for encoder stage 1, block 5 (dim=1024, 8x)
path = hf_hub_download("hujason/aurorascope", "saes/sae_encoder_s1_b5.pt")
sae = AuroraSAE(d_model=1024, expansion_factor=8)
sae.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
sae.eval()
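Because `d_model` differs by stage, it can be convenient to read the dimensions from the checkpoint instead of hard-coding them. The sketch below assumes the state dict uses the same parameter names as `AuroraSAE` (`W_enc`, `b_enc`, `W_dec`, `b_dec`); `infer_sae_dims` is a hypothetical helper, demonstrated on a stand-in state dict so that it runs without a download.

```python
import torch

def infer_sae_dims(state_dict: dict) -> tuple:
    """Return (d_model, expansion_factor) from the shape of W_enc."""
    d_model, d_hidden = state_dict["W_enc"].shape
    if d_hidden % d_model != 0:
        raise ValueError(f"d_hidden={d_hidden} is not a multiple of d_model={d_model}")
    return d_model, d_hidden // d_model

# Stand-in state dict with the parameter names assumed above
fake_state = {
    "W_enc": torch.zeros(512, 4096),
    "b_enc": torch.zeros(4096),
    "W_dec": torch.zeros(4096, 512),
    "b_dec": torch.zeros(512),
}
d_model, expansion = infer_sae_dims(fake_state)
print(d_model, expansion)  # 512 8
```

With a real checkpoint, the result would be passed on as `AuroraSAE(d_model=d_model, expansion_factor=expansion)` before calling `load_state_dict`.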
Hooking into Aurora and extracting features
import torch
from aurora import AuroraAirPollution

# Load Aurora
model = AuroraAirPollution()
model.load_checkpoint("microsoft/aurora", "aurora-0.4-air-pollution.ckpt")
model.eval()

# Register a forward hook on a specific backbone layer
activations = []

def hook_fn(module, input, output):
    # Handle both a plain tensor output and a tuple whose first element is the tensor
    x = output[0] if isinstance(output, tuple) else output  # (batch, tokens, d_model)
    activations.append(x.detach())

# Hook encoder stage 1, block 5
handle = model.backbone.encoder_layers[1].blocks[5].register_forward_hook(hook_fn)

# Run inference (see mirora package for batch construction from CAMS data)
with torch.no_grad():
    pred = model(batch)
handle.remove()

# Extract SAE features
x = activations[0].reshape(-1, 1024)  # (num_tokens, d_model)
x_hat, features = sae(x)              # features: (num_tokens, 8192)

# Find the top-activating features
top_features = features.sum(dim=0).topk(10)
print("Top 10 most active features:", top_features.indices.tolist())
Using the SAE index
import json
from huggingface_hub import hf_hub_download

# Load the index for programmatic access
index_path = hf_hub_download("hujason/aurorascope", "sae_index.json")
with open(index_path) as f:
    index = json.load(f)

# Find all SAEs for a specific stage
encoder_s1_saes = [s for s in index["saes"] if s["component"] == "encoder" and s["stage"] == 1]
for sae_info in encoder_s1_saes:
    print(f"Block {sae_info['block']}: dim={sae_info['d_model']}, "
          f"hidden={sae_info['d_hidden']}, file={sae_info['filename']}")
Training Details
Data
SAEs were trained on activations collected from Aurora Air Pollution inference on CAMS (Copernicus Atmosphere Monitoring Service) global forecast data spanning June 2022 through May 2023. Each initialization batch represents a single global atmospheric snapshot at 0.4 degree resolution with:
- 12 surface variables: 10m wind (u, v), 2m temperature, mean sea level pressure, PM1, PM2.5, PM10, total column CO/NO/NO2/O3/SO2
- 10 atmospheric variables at 13 pressure levels: temperature, wind (u, v), specific humidity, geopotential, CO, NO, NO2, O3, SO2
Training procedure
- Optimizer: Adam (lr = 4e-4)
- Loss: MSE reconstruction + L1 sparsity penalty (coefficient = 3e-4)
- Epochs: 3 passes over the full year of data
- Batch size: 1 (each batch is a full global snapshot = 180,000 tokens)
- Decoder normalization: Unit-norm columns after each gradient step
- Infrastructure: Trained as a SLURM array job (one GPU per SAE) on Stanford Sherlock HPC
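The listed settings can be combined into a single training step as sketched below. This is a reconstruction from the bullet points, not the authors' training script: the toy dimensions, the random initialization, the random stand-in data, and the way the L1 term is reduced (summed over features, averaged over tokens) are all assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_hidden = 64, 512  # toy sizes; the real SAEs use d_model 512-2048 at 8x

W_enc = torch.nn.Parameter(torch.randn(d_model, d_hidden) * 0.02)
b_enc = torch.nn.Parameter(torch.zeros(d_hidden))
W_dec = torch.nn.Parameter(F.normalize(torch.randn(d_hidden, d_model), dim=1))
b_dec = torch.nn.Parameter(torch.zeros(d_model))

optimizer = torch.optim.Adam([W_enc, b_enc, W_dec, b_dec], lr=4e-4)
l1_coeff = 3e-4

def train_step(x: torch.Tensor) -> float:
    """One Adam step on MSE + L1, followed by decoder re-normalization."""
    h = F.relu((x - b_dec) @ W_enc + b_enc)
    x_hat = h @ W_dec + b_dec
    mse = F.mse_loss(x_hat, x)
    l1 = h.abs().sum(dim=-1).mean()  # assumed reduction of the sparsity term
    loss = mse + l1_coeff * l1
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        # Keep each feature direction (row of W_dec) at unit norm
        W_dec.copy_(F.normalize(W_dec, dim=1))
    return loss.item()

x = torch.randn(1000, d_model)  # stand-in for the tokens of one snapshot
losses = [train_step(x) for _ in range(20)]
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

In the real setup, `x` would hold the activations of one backbone block for a full global snapshot, so each step sees one batch of tokens.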
Expansion factor
- 8x expansion was used for all 48 layers
Intended Use
AuroraScope is designed for mechanistic interpretability research on the Aurora weather/air quality model. Example applications:
- Feature discovery: Identify monosemantic features corresponding to meteorological phenomena (e.g., wildfire plumes, jet streams, pollution transport)
- Circuit analysis: Trace how atmospheric information flows through the Swin3D backbone
- Causal intervention: Steer model behavior by clamping or ablating specific SAE features during inference
- Perturbation studies: Understand model sensitivity to input changes (e.g., emission increases) through the lens of learned features
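As a rough sketch of the causal-intervention use case, a forward hook can subtract one feature's decoder contribution from a block's output while leaving the SAE reconstruction error untouched. The hook factory `make_ablation_hook` is hypothetical, and a small `nn.Linear` stands in for a backbone block so the example runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_ablation_hook(W_enc, b_enc, W_dec, b_dec, feature_idx: int):
    """Forward hook that removes one SAE feature's contribution from a module's output."""
    def hook(module, inputs, output):
        x = output[0] if isinstance(output, tuple) else output
        # Activation of the chosen feature only, shape (...,)
        h_f = F.relu((x - b_dec) @ W_enc[:, feature_idx] + b_enc[feature_idx])
        # Subtract that feature's decoder direction scaled by its activation
        x_new = x - h_f.unsqueeze(-1) * W_dec[feature_idx]
        if isinstance(output, tuple):
            return (x_new,) + tuple(output[1:])
        return x_new  # a non-None return value replaces the module output
    return hook

torch.manual_seed(0)
d_model, d_hidden = 16, 128
block = nn.Linear(d_model, d_model)  # stand-in for a backbone block
W_enc = torch.randn(d_model, d_hidden)
b_enc = torch.zeros(d_hidden)
W_dec = F.normalize(torch.randn(d_hidden, d_model), dim=1)
b_dec = torch.zeros(d_model)

tokens = torch.randn(2, 10, d_model)
with torch.no_grad():
    clean = block(tokens)
    handle = block.register_forward_hook(
        make_ablation_hook(W_enc, b_enc, W_dec, b_dec, feature_idx=3)
    )
    ablated = block(tokens)
    handle.remove()
    restored = block(tokens)  # hook removed, output is back to normal
```

On Aurora, the same hook would be registered on a block such as `model.backbone.encoder_layers[1].blocks[5]`, using the parameters of the matching SAE. Clamping instead of ablating would replace the subtraction with a shift to a fixed activation value.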
Limitations
- SAEs are trained on a single year of CAMS data (Jun 2022 -- May 2023). Features may not fully capture rare events outside this period.
- The dim-2048 layers at 8x expansion reach only about 76--82% FVE, leaving roughly 18--24% of activation variance unexplained.
- SAE features are not guaranteed to be monosemantic; some features may be polysemantic or encode artifacts of the training process.
- These SAEs are specific to the aurora-0.4-air-pollution checkpoint and cannot be used with other Aurora variants without retraining.