AuroraScope: Sparse Autoencoders for the Aurora Air Pollution Model

AuroraScope is a suite of 48 sparse autoencoders (SAEs), one for each transformer block of the Aurora Air Pollution foundation model's Swin3D backbone. The SAEs decompose the model's internal representations into sparse features that are intended to be interpretable and monosemantic, enabling mechanistic interpretability research on a state-of-the-art weather and air quality forecasting system.

Paper: Forthcoming
Code: github.com/jasonyhu/mirora
Authors: Jason Hu, Makoto Kelp, Iván Higuera-Mendieta, Obin Sturm, Marshall Burke

Model Overview

| Property | Value |
| --- | --- |
| Base model | microsoft/aurora (Aurora Air Pollution, 1.3B params) |
| Architecture | Sparse autoencoder (ReLU encoder, L1 sparsity penalty) |
| Backbone | Swin3D with 48 transformer blocks |
| SAEs trained | 48 (8x expansion, one per backbone block) |
| Training data | CAMS global forecasts, Jun 2022 -- May 2023 (365 days) |
| Grid resolution | 0.4 deg (451 x 900), 13 pressure levels |
| Total parameters | ~1.5B across all 48 SAEs (12 x 4.2M + 20 x 16.8M + 16 x 67.1M) |

Aurora Backbone Architecture

The Aurora Air Pollution model uses a Swin3D vision transformer backbone organized as an encoder-decoder with skip connections:

             Encoder                           Decoder
  ┌─────────────────────────┐       ┌─────────────────────────┐
  │ Stage 0: 6 blocks       │       │ Stage 0: 8 blocks       │
  │   dim = 512             │       │   dim = 2048            │
  ├─────────────────────────┤       ├─────────────────────────┤
  │ Stage 1: 10 blocks      │  ───> │ Stage 1: 10 blocks      │
  │   dim = 1024            │ skip  │   dim = 1024            │
  ├─────────────────────────┤       ├─────────────────────────┤
  │ Stage 2: 8 blocks       │       │ Stage 2: 6 blocks       │
  │   dim = 2048            │       │   dim = 512             │
  └─────────────────────────┘       └─────────────────────────┘

Each stage processes the global atmospheric state at a different resolution. Using a patch size of 3, the model tokenizes the input into 180,000 tokens (150 x 300 spatial patches across 4 latent levels).
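
The token count follows from the grid size and patch size. The short check below reproduces the arithmetic. It assumes the 451 latitude rows are cropped to 450 so that they divide evenly by the patch size; that cropping is an assumption here, not something this card states.

```python
# Grid and patching values taken from this model card.
n_lat, n_lon = 451, 900   # 0.4 degree global grid
patch_size = 3
latent_levels = 4

# Assumption: integer division drops the extra latitude row (451 -> 450).
patches_lat = n_lat // patch_size
patches_lon = n_lon // patch_size
n_tokens = patches_lat * patches_lon * latent_levels

print(patches_lat, patches_lon, n_tokens)  # 150 300 180000
```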

SAE Architecture

Each SAE is a single-layer autoencoder with ReLU activation:

Input x (d_model) --> subtract b_dec --> W_enc --> ReLU --> h (d_hidden) --> W_dec --> add b_dec --> x_hat
  • Encoder: h = ReLU((x - b_dec) @ W_enc + b_enc)
  • Decoder: x_hat = h @ W_dec + b_dec
  • Decoder columns are unit-normalized after each training step

| Layer dimension | Expansion | Hidden dim | Params per SAE | File size |
| --- | --- | --- | --- | --- |
| 512 | 8x | 4,096 | 4.2M | 16 MB |
| 1,024 | 8x | 8,192 | 16.8M | 64 MB |
| 2,048 | 8x | 16,384 | 67.1M | 256 MB |
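
The per-SAE sizes in the table can be derived from the four parameter tensors (W_enc, b_enc, W_dec, b_dec). The sketch below is a quick sanity check; the helper name `sae_param_count` is hypothetical, and the file-size estimate assumes weights are stored as 32-bit floats, which the card does not state explicitly.

```python
def sae_param_count(d_model: int, expansion: int = 8) -> int:
    """Count parameters in W_enc, W_dec, b_enc, and b_dec for one SAE."""
    d_hidden = d_model * expansion
    return 2 * d_model * d_hidden + d_hidden + d_model

for d_model in (512, 1024, 2048):
    n_params = sae_param_count(d_model)
    # Assumption: float32 storage, i.e. 4 bytes per parameter.
    size_mib = n_params * 4 / 2**20
    print(f"dim={d_model}: {n_params / 1e6:.1f}M params, ~{size_mib:.0f} MiB")
```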

Reconstruction Quality (Fraction of Variance Explained)

Evaluated over 20 held-out CAMS initialization batches:

| Backbone layer | Dim | Expansion | FVE (%) |
| --- | --- | --- | --- |
| Encoder stage 0 (6 blocks) | 512 | 8x | 99.6 -- 99.8 |
| Encoder stage 1 (10 blocks) | 1024 | 8x | 95.2 -- 99.9 |
| Encoder stage 2 (8 blocks) | 2048 | 8x | 75.8 -- 80.4 |
| Decoder stage 0 (8 blocks) | 2048 | 8x | 76.4 -- 82.3 |
| Decoder stage 1 (10 blocks) | 1024 | 8x | 91.2 -- 96.2 |
| Decoder stage 2 (6 blocks) | 512 | 8x | 99.6 -- 99.7 |

The dim-2048 layers (encoder stage 2, decoder stage 0) form the bottleneck of the Swin3D backbone. They show the lowest FVE at 8x expansion, which is consistent with these layers carrying the richest representations.
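
For reference, a common way to compute fraction of variance explained is one minus the ratio of residual to total sum of squares. The sketch below uses that definition with the total measured around the per-dimension mean over tokens. The exact definition used for the numbers in this card is not stated, so treat this as an assumption; the function name is hypothetical.

```python
import torch

def fraction_of_variance_explained(x: torch.Tensor, x_hat: torch.Tensor) -> float:
    """FVE = 1 - SS_residual / SS_total for activations of shape (tokens, d_model)."""
    residual = (x - x_hat).pow(2).sum()
    total = (x - x.mean(dim=0, keepdim=True)).pow(2).sum()
    return (1.0 - residual / total).item()

torch.manual_seed(0)
x = torch.randn(1000, 512)                      # stand-in for real activations
mean_only = x.mean(dim=0, keepdim=True).expand_as(x)

fve_perfect = fraction_of_variance_explained(x, x)          # perfect reconstruction
fve_mean = fraction_of_variance_explained(x, mean_only)     # predicting only the mean
print(fve_perfect, fve_mean)
```

With real data, `x` would be the hooked backbone activations and `x_hat` the SAE reconstruction; multiply by 100 to compare against the percentages in the table.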

File Structure

aurorascope/
  README.md               # This model card
  aurora_sae.py            # Standalone SAE module (no dependencies beyond PyTorch)
  sae_index.json           # Machine-readable index of all SAEs with metadata
  saes/
    sae_encoder_s0_b0.pt   # Encoder stage 0, block 0
    sae_encoder_s0_b1.pt
    ...
    sae_decoder_s0_b0.pt
    sae_decoder_s0_b1.pt
    ...

Naming Convention

sae_{encoder|decoder}_s{stage}_b{block}.pt
  • encoder / decoder: which half of the backbone
  • s{stage}: stage index (0, 1, or 2)
  • b{block}: block index within the stage
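
As an illustration of the convention, the sketch below enumerates all expected filenames together with the layer dimension of each SAE. The block counts and dimensions are copied from the backbone diagram in this card; the assumption that block indices are zero-based and contiguous is inferred from the file listing, and `sae_index.json` remains the authoritative source.

```python
# Blocks per stage and layer dims, copied from the backbone diagram.
BLOCKS = {"encoder": (6, 10, 8), "decoder": (8, 10, 6)}
DIMS = {"encoder": (512, 1024, 2048), "decoder": (2048, 1024, 512)}

def sae_files():
    """Yield (repo path, d_model) for every SAE, assuming zero-based block indices."""
    for component in ("encoder", "decoder"):
        for stage, n_blocks in enumerate(BLOCKS[component]):
            for block in range(n_blocks):
                path = f"saes/sae_{component}_s{stage}_b{block}.pt"
                yield path, DIMS[component][stage]

files = list(sae_files())
print(len(files))   # 48
print(files[0])     # ('saes/sae_encoder_s0_b0.pt', 512)
```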

Quick Start

Loading a single SAE

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

# The SAE module (also available as aurora_sae.py in this repo)
class AuroraSAE(nn.Module):
    def __init__(self, d_model: int = 2048, expansion_factor: int = 8, l1_coeff: float = 3e-4):
        super().__init__()
        self.d_hidden = d_model * expansion_factor
        self.l1_coeff = l1_coeff
        self.W_enc = nn.Parameter(torch.empty(d_model, self.d_hidden))
        self.b_enc = nn.Parameter(torch.zeros(self.d_hidden))
        self.W_dec = nn.Parameter(torch.empty(self.d_hidden, d_model))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        x_cent = x - self.b_dec
        h = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_hat = h @ self.W_dec + self.b_dec
        return x_hat, h

# Download and load an SAE for encoder stage 1, block 5 (dim=1024, 8x)
path = hf_hub_download("hujason/aurorascope", "saes/sae_encoder_s1_b5.pt")
sae = AuroraSAE(d_model=1024, expansion_factor=8)
sae.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
sae.eval()

Hooking into Aurora and extracting features

from aurora import AuroraAirPollution

# Load Aurora
model = AuroraAirPollution()
model.load_checkpoint("microsoft/aurora", "aurora-0.4-air-pollution.ckpt")
model.eval()

# Register a forward hook on a specific backbone layer
activations = []
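def hook_fn(module, input, output):
    # Some modules return a tuple; keep only the activation tensor in that case.
    x = output[0] if isinstance(output, tuple) else output  # (batch, tokens, d_model)
    activations.append(x.detach())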

# Hook encoder stage 1, block 5
handle = model.backbone.encoder_layers[1].blocks[5].register_forward_hook(hook_fn)

# Run inference (see mirora package for batch construction from CAMS data)
with torch.no_grad():
    pred = model(batch)

handle.remove()

# Extract SAE features
x = activations[0].reshape(-1, 1024)  # (num_tokens, d_model)
x_hat, features = sae(x)              # features: (num_tokens, 8192)

# Find the top-activating features
top_features = features.sum(dim=0).topk(10)
print("Top 10 most active features:", top_features.indices.tolist())

Using the SAE index

import json
from huggingface_hub import hf_hub_download

# Load the index for programmatic access
index_path = hf_hub_download("hujason/aurorascope", "sae_index.json")
with open(index_path) as f:
    index = json.load(f)

# Find all SAEs for a specific stage
encoder_s1_saes = [s for s in index["saes"] if s["component"] == "encoder" and s["stage"] == 1]
for sae_info in encoder_s1_saes:
    print(f"Block {sae_info['block']}: dim={sae_info['d_model']}, "
          f"hidden={sae_info['d_hidden']}, file={sae_info['filename']}")

Training Details

Data

SAEs were trained on activations collected from Aurora Air Pollution inference on CAMS (Copernicus Atmosphere Monitoring Service) global forecast data spanning June 2022 through May 2023. Each initialization batch represents a single global atmospheric snapshot at 0.4 degree resolution with:

  • 12 surface variables: 10m wind (u, v), 2m temperature, mean sea level pressure, PM1, PM2.5, PM10, total column CO/NO/NO2/O3/SO2
  • 10 atmospheric variables at 13 pressure levels: temperature, wind (u, v), specific humidity, geopotential, CO, NO, NO2, O3, SO2

Training procedure

  • Optimizer: Adam (lr = 4e-4)
  • Loss: MSE reconstruction + L1 sparsity penalty (coefficient = 3e-4)
  • Epochs: 3 passes over the full year of data
  • Batch size: 1 (each batch is a full global snapshot = 180,000 tokens)
  • Decoder normalization: Unit-norm columns after each gradient step
  • Infrastructure: Trained as a SLURM array job (one GPU per SAE) on Stanford Sherlock HPC
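
The procedure above can be sketched as a single training step on toy dimensions. This is not the authors' training code: the initialization, the exact form of the L1 term (here, the per-token sum of activations averaged over tokens), and the choice to normalize each feature's decoder vector (a row of W_dec in the (d_hidden, d_model) layout used by this card's module) are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_hidden = 16, 128   # toy sizes; the released SAEs use dims 512-2048 at 8x

# Assumed initialization; the card does not describe the one actually used.
W_enc = torch.nn.Parameter(torch.randn(d_model, d_hidden) * 0.1)
b_enc = torch.nn.Parameter(torch.zeros(d_hidden))
W_dec = torch.nn.Parameter(torch.randn(d_hidden, d_model) * 0.1)
b_dec = torch.nn.Parameter(torch.zeros(d_model))

opt = torch.optim.Adam([W_enc, b_enc, W_dec, b_dec], lr=4e-4)
l1_coeff = 3e-4

def train_step(x: torch.Tensor) -> float:
    h = F.relu((x - b_dec) @ W_enc + b_enc)
    x_hat = h @ W_dec + b_dec
    mse = F.mse_loss(x_hat, x)
    l1 = h.abs().sum(dim=-1).mean()          # assumed form of the sparsity penalty
    loss = mse + l1_coeff * l1
    opt.zero_grad()
    loss.backward()
    opt.step()
    # Renormalize each feature's decoder vector to unit norm after the update.
    with torch.no_grad():
        W_dec.div_(W_dec.norm(dim=1, keepdim=True).clamp_min(1e-8))
    return loss.item()

x = torch.randn(256, d_model)                # stand-in for one batch of tokens
losses = [train_step(x) for _ in range(5)]
print(losses)
```

Renormalizing the decoder after each step keeps the L1 penalty from being trivially reduced by shrinking activations while growing decoder weights.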

Expansion factor

  • 8x expansion was used for all 48 layers

Intended Use

AuroraScope is designed for mechanistic interpretability research on the Aurora weather/air quality model. Example applications:

  • Feature discovery: Identify monosemantic features corresponding to meteorological phenomena (e.g., wildfire plumes, jet streams, pollution transport)
  • Circuit analysis: Trace how atmospheric information flows through the Swin3D backbone
  • Causal intervention: Steer model behavior by clamping or ablating specific SAE features during inference
  • Perturbation studies: Understand model sensitivity to input changes (e.g., emission increases) through the lens of learned features
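
A causal intervention of the kind listed above can be sketched as follows. The function zeroes selected SAE features and adds the resulting change in the reconstruction back onto the original activation, so the part of the activation the SAE does not reconstruct is left untouched. The function name and the toy weights are hypothetical; with a real SAE, the returned tensor could be returned from a PyTorch forward hook to replace the block's output.

```python
import torch
import torch.nn.functional as F

def ablate_features(x, W_enc, b_enc, W_dec, b_dec, feature_ids):
    """Return x with the decoder contribution of the given SAE features removed."""
    h = F.relu((x - b_dec) @ W_enc + b_enc)
    x_hat = h @ W_dec + b_dec
    idx = torch.as_tensor(feature_ids, dtype=torch.long)
    h_ablated = h.clone()
    h_ablated[..., idx] = 0.0
    x_hat_ablated = h_ablated @ W_dec + b_dec
    # Apply only the change in reconstruction, preserving the SAE's residual error.
    return x + (x_hat_ablated - x_hat)

torch.manual_seed(0)
d_model, d_hidden = 8, 64                    # toy sizes for illustration
W_enc = torch.randn(d_model, d_hidden)
b_enc = torch.zeros(d_hidden)
W_dec = torch.randn(d_hidden, d_model)
b_dec = torch.zeros(d_model)

x = torch.randn(10, d_model)                 # stand-in for hooked activations
x_patched = ablate_features(x, W_enc, b_enc, W_dec, b_dec, feature_ids=[0, 3])
print((x_patched - x).abs().max())
```

Clamping a feature to a fixed value instead of zero follows the same pattern: assign the chosen value to `h_ablated[..., idx]` before decoding.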

Limitations

  • SAEs are trained on a single year of CAMS data (Jun 2022 -- May 2023). Features may not fully capture rare events outside this period.
  • The dim-2048 layers at 8x expansion reach only 75.8--82.3% FVE, so roughly 18--24% of activation variance is not captured.
  • SAE features are not guaranteed to be monosemantic; some features may be polysemantic or encode artifacts of the training process.
  • These SAEs are specific to the aurora-0.4-air-pollution checkpoint and cannot be used with other Aurora variants without retraining.