# AReUReDi / peptide / train.py
# Tong Chen — commit d2693e0 ("add files")
# train.py
# Description: A complete script to train a ReDi model with a continuous time variable t in [0, 1].
import argparse
import math
import os
from functools import partial
from collections import Counter
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_from_disk
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
# --- Model Architecture ---
# Based on the DiT (Diffusion Transformer) architecture, adapted for discrete data (MDLM).
def modulate(x, shift, scale):
    """
    Apply adaLN-style affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` carry one vector per sample; a singleton axis is inserted
    at position 1 so they broadcast across the sequence dimension of ``x``. This is
    how the DiT blocks condition on the timestep embedding.
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b
class TimestepEmbedder(nn.Module):
    """
    Map one continuous timestep per sample to a ``hidden_size`` feature vector.

    The scalar is fed through Linear(1 -> hidden) -> SiLU -> Linear(hidden -> hidden).
    """
    def __init__(self, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(1, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        # (batch,) -> (batch, 1): the first Linear expects a trailing feature axis of size 1.
        t_column = t.unsqueeze(-1)
        return self.mlp(t_column)
class DiTBlock(nn.Module):
    """
    One pre-norm Diffusion Transformer block with adaLN conditioning.

    The conditioning vector ``c`` (one per sample) is projected to six vectors:
    shift/scale/gate for the self-attention sublayer and shift/scale/gate for the
    MLP sublayer. Each sublayer's output is gated before being added to the residual.

    NOTE(review): self-attention runs without a key_padding_mask, so any padded
    positions would be attended to — confirm inputs are unpadded / fixed length.
    """
    def __init__(self, hidden_size, n_heads):
        super().__init__()
        # Affine-free norms: scale/shift come from the conditioning instead.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_width = 4 * hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, hidden_size),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """x: (batch, seq, hidden); c: (batch, hidden). Returns a tensor shaped like x."""
        modulation = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation

        # Attention sublayer.
        h = modulate(self.norm1(x), shift_msa, scale_msa)
        attn_out = self.attn(h, h, h)[0]
        x = x + gate_msa.unsqueeze(1) * attn_out

        # Feed-forward sublayer.
        h = modulate(self.norm2(x), shift_mlp, scale_mlp)
        return x + gate_mlp.unsqueeze(1) * self.mlp(h)
class MDLM(nn.Module):
    """
    Discrete-token denoiser with a DiT backbone (named after Masked Diffusion
    Language Models).

    The token embedding table has ``vocab_size + 1`` rows; id ``vocab_size`` is
    reserved as ``mask_token_id``. The output head predicts only the ``vocab_size``
    real tokens.

    Args:
        vocab_size: number of real token ids.
        seq_len: number of learned absolute positions (maximum input length).
        model_dim: hidden width.
        n_heads: attention heads per block.
        n_layers: number of DiT blocks.
    """
    def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.model_dim = model_dim
        self.mask_token_id = vocab_size  # first id past the real vocabulary
        self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)
        # Learned positions drawn from N(0, 1); _init_weights leaves them alone because
        # it only touches Linear / Embedding / LayerNorm submodules.
        self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
        self.time_embedder = TimestepEmbedder(model_dim)
        blocks = [DiTBlock(model_dim, n_heads) for _ in range(n_layers)]
        self.transformer_blocks = nn.ModuleList(blocks)
        self.final_norm = nn.LayerNorm(model_dim)
        self.lm_head = nn.Linear(model_dim, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """N(0, 0.02) weights and zero biases for Linear/Embedding; identity for LayerNorm."""
        if isinstance(module, nn.LayerNorm):
            # Affine-free LayerNorms have weight/bias set to None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x, t):
        """
        Args:
            x: token ids indexed as (batch, seq); seq must not exceed ``self.seq_len``.
            t: one timestep per sample.
        Returns:
            Logits over the ``vocab_size`` real tokens for every position.
        """
        n_tokens = x.shape[1]
        hidden = self.token_embedder(x) + self.pos_embedder[:, :n_tokens, :]
        cond = self.time_embedder(t)
        for blk in self.transformer_blocks:
            hidden = blk(hidden, cond)
        return self.lm_head(self.final_norm(hidden))
# --- Learning Rate Scheduler ---
def get_lr_scheduler(optimizer, warmup_steps, total_steps, lr_min, lr_max):
    """
    Create a step-based LambdaLR schedule.

    The learning rate warms up linearly from ``lr_min`` to ``lr_max`` over
    ``warmup_steps`` steps. It then follows a cosine from ``lr_max`` down to
    ``lr_min``, reaching ``lr_min`` at ``total_steps`` and holding it afterwards.

    LambdaLR multiplies each param group's base lr by the returned factor, so the
    optimizer is expected to be built with ``lr == lr_max``.

    Args:
        optimizer: optimizer to schedule.
        warmup_steps: number of linear-warmup steps (may be 0).
        total_steps: step at which the cosine reaches ``lr_min``.
        lr_min: floor learning rate (start of warmup, end of decay).
        lr_max: peak learning rate; must be non-zero.

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    lr_range = lr_max - lr_min
    # At least one decay step: LambdaLR evaluates step 0 at construction, and
    # total_steps <= warmup_steps (e.g. an empty dataloader) used to divide by zero.
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # Linear warmup phase.
            lr = lr_min + lr_range * (current_step / warmup_steps)
        else:
            # Cosine annealing phase; clamp so the rate stays at lr_min past
            # total_steps instead of climbing back up the cosine.
            progress = min(1.0, (current_step - warmup_steps) / decay_steps)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            lr = lr_min + lr_range * cosine_decay
        return lr / lr_max

    return LambdaLR(optimizer, lr_lambda)
# --- Training and Validation Functions ---
def train_one_epoch(model, dataloader, optimizer, scheduler, device, epoch, args):
    """
    Run one training epoch and return the mean per-batch loss.

    Each batch is corrupted as follows. A per-sample time t ~ U(0, 1) is drawn, and
    every clean token is kept with probability t; otherwise it is replaced by a token
    drawn uniformly from the vocabulary. The model is trained to recover the clean
    tokens with (optionally label-smoothed) cross-entropy. The optimizer and LR
    scheduler are both stepped once per batch.

    Args:
        model: module exposing ``vocab_size``, called as ``model(x_t, t)`` and
            returning logits whose last dimension is ``vocab_size``.
        dataloader: yields dicts whose ``'input_ids'`` entry is a 2-D
            (batch, seq_len) array of token ids.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped after every optimizer step.
        device: device the batch tensors are moved to.
        epoch: zero-based epoch index (shown as ``epoch+1`` in the progress bar).
        args: namespace providing ``label_smoothing``.

    Returns:
        float: average training loss over the batches processed this epoch.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Train]")
    for batch in progress_bar:
        optimizer.zero_grad()
        # as_tensor avoids a redundant copy (and torch's copy warning) when the
        # dataset already yields tensors; cross_entropy targets must be Long.
        x_1 = torch.as_tensor(batch['input_ids'], dtype=torch.long).to(device)
        batch_size, _ = x_1.shape
        x_0 = torch.randint(0, model.vocab_size, x_1.shape, device=device)
        t_continuous = torch.rand(batch_size, device=device)
        # True -> keep the clean token; larger t keeps more of the sequence.
        keep_clean = torch.rand(x_1.shape, device=device) < t_continuous.view(-1, 1)
        x_t = torch.where(keep_clean, x_1, x_0)
        logits = model(x_t, t_continuous)
        loss = F.cross_entropy(
            logits.reshape(-1, model.vocab_size),
            x_1.reshape(-1),
            label_smoothing=args.label_smoothing,
        )
        loss.backward()
        optimizer.step()
        scheduler.step()
        loss_value = loss.item()  # single device sync per step
        total_loss += loss_value
        num_batches += 1
        progress_bar.set_postfix(loss=loss_value, lr=scheduler.get_last_lr()[0])
        # wandb.log({"train_loss_step": loss_value, "learning_rate": scheduler.get_last_lr()[0]})
    if num_batches == 0:
        raise ValueError("train dataloader yielded no batches")
    return total_loss / num_batches
def _estimate_batch_tc(logits, num_samples):
    """
    Monte-Carlo "TC error" estimate for one batch of logits.

    For every batch item, ``num_samples`` sequences are drawn position by position
    from q = softmax(logits), the product of per-position marginals. The empirical
    distribution p_hat of those draws is compared against q:

        KL(p_hat || q) = sum over distinct drawn x of p_hat(x) * (log p_hat(x) - log q(x))

    NOTE(review): the draws come from q itself, so this reflects finite-sample
    mismatch between p_hat and q (it grows with q's entropy), not dependence between
    positions under a model joint. It is kept as-is for continuity; confirm this is
    the intended metric.

    Args:
        logits: tensor indexed as (batch, seq_len, vocab).
        num_samples: number of sequences drawn per batch item.

    Returns:
        float: mean KL over the batch items (0.0 for an empty batch).
    """
    batch_size, seq_len, vocab_size = logits.shape
    if batch_size == 0:
        return 0.0
    p_marginal = F.softmax(logits, dim=-1)
    samples = torch.multinomial(
        p_marginal.reshape(-1, vocab_size), num_samples, replacement=True
    ).view(batch_size, seq_len, num_samples)
    # log q(x) = sum_pos log p(x_pos) for every draw in one vectorised gather,
    # instead of per-token Python indexing into a (possibly GPU) tensor.
    log_q = torch.log(torch.gather(p_marginal, 2, samples) + 1e-9).sum(dim=1)  # (batch, num_samples)
    # One host transfer for the whole batch.
    draws_per_item = samples.permute(0, 2, 1).tolist()  # [batch][num_samples][seq_len]
    log_q_per_item = log_q.tolist()                     # [batch][num_samples]

    kl_divs = []
    for draws, log_qs in zip(draws_per_item, log_q_per_item):
        draw_tuples = [tuple(d) for d in draws]
        counts = Counter(draw_tuples)
        # log q depends only on the drawn sequence, so any occurrence supplies it.
        log_q_of = dict(zip(draw_tuples, log_qs))
        kl = 0.0
        for seq, count in counts.items():
            p_hat = count / num_samples
            kl += p_hat * (math.log(p_hat + 1e-9) - log_q_of[seq])
        kl_divs.append(kl)
    return sum(kl_divs) / len(kl_divs)


def validate(model, val_dataloader, device, epoch, args):
    """
    Evaluate the model on the validation set: NLL, perplexity and TC error.

    Batches are corrupted exactly as in training (per-sample t ~ U(0, 1), uniform
    random replacement tokens). Fresh randomness is drawn on each call, so results
    vary between calls. The first ``args.tc_batches`` batches also contribute to the
    TC estimate, using ``args.tc_k_samples`` draws per sequence.

    Args:
        model: module exposing ``vocab_size`` and called as ``model(x_t, t)``.
        val_dataloader: yields dicts whose ``'input_ids'`` is a 2-D (batch, seq_len)
            array of token ids.
        device: device the batch tensors are moved to.
        epoch: zero-based epoch index (progress-bar label only).
        args: namespace providing ``tc_batches`` and ``tc_k_samples``.

    Returns:
        (avg_val_nll, perplexity, avg_tc) as floats:
        - avg_val_nll: mean per-batch cross-entropy (no label smoothing).
        - perplexity: exp(avg_val_nll).
        - avg_tc: mean TC estimate, or 0.0 if no TC batches were evaluated.

    Raises:
        ValueError: if the validation dataloader yields no batches.
    """
    model.eval()
    total_val_nll = 0.0
    num_batches = 0
    total_tc = 0.0
    tc_batches = 0
    progress_bar = tqdm(val_dataloader, desc=f"Epoch {epoch+1} [Val]")
    with torch.no_grad():
        for i, batch in enumerate(progress_bar):
            x_1 = torch.as_tensor(batch['input_ids'], dtype=torch.long).to(device)
            batch_size, _ = x_1.shape
            x_0 = torch.randint(0, model.vocab_size, x_1.shape, device=device)
            t_continuous = torch.rand(batch_size, device=device)
            keep_clean = torch.rand(x_1.shape, device=device) < t_continuous.view(-1, 1)
            x_t = torch.where(keep_clean, x_1, x_0)
            logits = model(x_t, t_continuous)
            val_nll = F.cross_entropy(logits.reshape(-1, model.vocab_size), x_1.reshape(-1))
            total_val_nll += val_nll.item()
            num_batches += 1
            if i < args.tc_batches:
                total_tc += _estimate_batch_tc(logits, args.tc_k_samples)
                tc_batches += 1
    if num_batches == 0:
        raise ValueError("validation dataloader yielded no batches")
    avg_val_nll = total_val_nll / num_batches
    perplexity = math.exp(avg_val_nll)
    avg_tc = total_tc / tc_batches if tc_batches > 0 else 0.0
    return avg_val_nll, perplexity, avg_tc
# --- Main Execution ---
def main(args):
    """
    Train an MDLM model and keep the checkpoint with the lowest validation NLL.

    Side effects:
    - Rewrites ``args.checkpoint_dir`` to a per-run subdirectory named from the
      hyperparameters, and creates that directory.
    - Saves ``best_checkpoint.pt`` into it each time validation NLL improves.

    Args:
        args: parsed command-line namespace (dataset paths, model sizes,
            optimisation and evaluation settings).
    """
    # Weights & Biases logging is currently disabled. To re-enable it, uncomment the
    # wandb calls and authenticate outside the code (`wandb login` or the
    # WANDB_API_KEY environment variable). Never hard-code an API key here; a key
    # was previously committed in this file and should be revoked.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    run_name = f"lr{args.learning_rate}_wd{args.weight_decay}_layer{args.n_layers}_head{args.n_heads}_labelsmoothing{args.label_smoothing}"
    # wandb.init(project=args.wandb_project, name=run_name, entity="programmablebio", config=args)
    # Join as a path component: plain string concatenation turned the default
    # "./checkpoints" into a sibling "./checkpointslr..." directory instead of a subdirectory.
    args.checkpoint_dir = os.path.join(args.checkpoint_dir, run_name)
    print(f"Saving to {args.checkpoint_dir}")
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    best_checkpoint_path = os.path.join(args.checkpoint_dir, "best_checkpoint.pt")

    print("Loading datasets...")
    train_dataset = load_from_disk(args.train_dataset_path)
    val_dataset = load_from_disk(args.val_dataset_path)
    # batch_size=None disables automatic batching: each dataset row is passed through
    # unchanged, and the train/val loops treat its 'input_ids' as a 2-D batch.
    # NOTE(review): shuffle=False keeps the training order fixed across epochs;
    # presumably the dataset is pre-batched and pre-shuffled — confirm.
    train_dataloader = DataLoader(train_dataset, batch_size=None, shuffle=False)
    val_dataloader = DataLoader(val_dataset, batch_size=None, shuffle=False)

    print("Initializing model...")
    model = MDLM(args.vocab_size, args.seq_len, args.model_dim, args.n_heads, args.n_layers).to(device)
    print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters.")
    optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    num_training_steps = args.epochs * len(train_dataloader)
    warmup_steps = int(num_training_steps * 0.1)  # 10% linear warmup
    # The optimizer's lr equals lr_max, as the LambdaLR factor schedule requires.
    scheduler = get_lr_scheduler(optimizer, warmup_steps, num_training_steps, args.learning_rate * 0.1, args.learning_rate)

    best_val_nll = float('inf')
    print("Starting training...")
    for epoch in range(args.epochs):
        train_loss = train_one_epoch(model, train_dataloader, optimizer, scheduler, device, epoch, args)
        val_nll, perplexity, tc_error = validate(model, val_dataloader, device, epoch, args)
        print(f"Epoch {epoch+1}/{args.epochs} -> Train Loss: {train_loss:.4f}, Val NLL: {val_nll:.4f}, Val PPL: {perplexity:.2f}, TC: {tc_error:.4f}")
        # wandb.log({
        #     "epoch": epoch + 1,
        #     "train_loss_epoch": train_loss,
        #     "val_nll_epoch": val_nll,
        #     "val_perplexity": perplexity,
        #     "conditional_total_correlation": tc_error,
        # })
        if val_nll < best_val_nll:
            best_val_nll = val_nll
            torch.save({
                'epoch': epoch + 1, 'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(),
                'val_nll': val_nll, 'tc_error': tc_error, 'args': args
            }, best_checkpoint_path)
            print(f"New best checkpoint saved to {best_checkpoint_path} (Val NLL: {val_nll:.4f})")
    # wandb.finish()
    print("Training complete.")
if __name__ == "__main__":
    # Command-line entry point: collect hyperparameters, then hand off to main().
    cli = argparse.ArgumentParser(description="Train a ReDi (MDLM) model with self-contained evaluation.")

    # Dataset locations (both mandatory).
    for path_flag in ("--train_dataset_path", "--val_dataset_path"):
        cli.add_argument(path_flag, type=str, required=True)

    # Integer sizes: architecture, vocabulary/sequence and epoch count.
    int_options = (
        ("--model_dim", 1024),
        ("--n_heads", 8),
        ("--n_layers", 6),
        ("--vocab_size", 24),
        ("--seq_len", 100),
        ("--epochs", 50),
    )
    for flag, default in int_options:
        cli.add_argument(flag, type=int, default=default)

    # Optimisation. label_smoothing keeps the int default 0 (it appears in the run name).
    float_options = (
        ("--learning_rate", 1e-4),
        ("--weight_decay", 1e-5),
        ("--label_smoothing", 0),
    )
    for flag, default in float_options:
        cli.add_argument(flag, type=float, default=default)

    # TC evaluation budget.
    cli.add_argument("--tc_batches", type=int, default=20, help="Number of validation batches to use for TC calculation.")
    cli.add_argument("--tc_k_samples", type=int, default=50, help="Number of samples (k) per data point for TC approximation.")

    # Logging and output locations.
    cli.add_argument("--wandb_project", type=str, default="redi-training")
    cli.add_argument("--checkpoint_dir", type=str, default="./checkpoints")

    args = cli.parse_args()
    main(args)