AReUReDi / peptide /new_coupling.py
Tong Chen
add files
d2693e0
import argparse
import math
import os
from collections import defaultdict
import torch
import torch.nn as nn
from tqdm import tqdm
from datasets import Dataset, DatasetDict
# --- Model Architecture (Must match the trained model) ---
def modulate(x, shift, scale):
    """Apply adaLN-style affine conditioning: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a new axis inserted at dim 1 before the
    arithmetic, so a per-sample vector (e.g. ``(batch, dim)``) broadcasts
    over every position of ``x`` (e.g. ``(batch, seq, dim)``).
    """
    per_pos_scale = scale.unsqueeze(1)
    per_pos_shift = shift.unsqueeze(1)
    return x * (1 + per_pos_scale) + per_pos_shift
class TimestepEmbedder(nn.Module):
    """Map scalar timesteps to ``hidden_size``-dimensional conditioning vectors.

    The raw timestep value goes through a two-layer MLP
    (Linear(1, hidden) -> SiLU -> Linear(hidden, hidden)); no sinusoidal
    features are used.
    """

    def __init__(self, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(1, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        # Kept as an nn.Sequential so state_dict keys remain mlp.0.* / mlp.2.*,
        # which loading a saved model_state_dict depends on.
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        """Embed ``t`` after appending a trailing feature axis of size 1."""
        return self.mlp(t[..., None])
class DiTBlock(nn.Module):
    """Transformer block with adaLN conditioning.

    A conditioning vector ``c`` is projected to six chunks: shift/scale/gate
    for the self-attention sub-layer and shift/scale/gate for the MLP
    sub-layer. Each sub-layer receives a (non-affine) LayerNorm'ed input
    modulated by its shift/scale. Its output is multiplied by its gate and
    added back to the residual stream.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        # Submodules are created in the original order so parameter names
        # (and seeded initialisation) match the trained model.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_width = 4 * hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, hidden_size),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Run attention then MLP, both conditioned on ``c``.

        Assumes ``x`` is ``(batch, seq, hidden)`` (attention is batch_first)
        and ``c`` is ``(batch, hidden)`` so it splits into six chunks along
        dim 1 — TODO confirm for other callers.
        """
        chunks = self.adaLN_modulation(c).chunk(6, dim=1)
        attn_shift, attn_scale, attn_gate, mlp_shift, mlp_scale, mlp_gate = chunks

        h = modulate(self.norm1(x), attn_shift, attn_scale)
        attn_out = self.attn(h, h, h)[0]
        x = x + attn_gate.unsqueeze(1) * attn_out

        h = modulate(self.norm2(x), mlp_shift, mlp_scale)
        x = x + mlp_gate.unsqueeze(1) * self.mlp(h)
        return x
class MDLM(nn.Module):
    """Masked-diffusion LM backbone: token/position/time embeddings, DiT blocks, LM head.

    Args:
        vocab_size: number of real tokens. Index ``vocab_size`` is reserved as
            ``mask_token_id``, so the embedding table has ``vocab_size + 1``
            rows while the LM head predicts only the ``vocab_size`` real tokens.
        seq_len: number of positions covered by the learned positional embedding.
        model_dim: hidden width of all layers.
        n_heads: attention heads per block.
        n_layers: number of DiT blocks.
    """

    def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.model_dim = model_dim
        self.mask_token_id = vocab_size
        # Registration order below matches the trained model's state_dict.
        self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)
        self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
        self.time_embedder = TimestepEmbedder(model_dim)
        blocks = [DiTBlock(model_dim, n_heads) for _ in range(n_layers)]
        self.transformer_blocks = nn.ModuleList(blocks)
        self.final_norm = nn.LayerNorm(model_dim)
        self.lm_head = nn.Linear(model_dim, vocab_size)

    def forward(self, x, t):
        """Return logits whose last dimension is ``vocab_size``.

        ``x`` is presumably ``(batch, length)`` token ids. Only the first
        ``length`` positional embeddings are used, so ``length`` must not
        exceed ``self.seq_len``. ``t`` is passed to the time embedder and
        used to condition every block.
        """
        n_positions = x.shape[1]
        h = self.token_embedder(x) + self.pos_embedder[:, :n_positions, :]
        cond = self.time_embedder(t)
        for blk in self.transformer_blocks:
            h = blk(h, cond)
        return self.lm_head(self.final_norm(h))
# --- Generation & Utility Functions ---
def generate_x1_from_x0(model, device, x0_batch, steps, temperature):
    """Iteratively refine the token sequences in ``x0_batch`` with ``model``.

    At step ``i`` (of ``steps``):
      1. The model runs at time ``1 - i/steps``.
      2. One token per position is sampled from the temperature-scaled softmax.
      3. The ``keep[i]`` positions whose sampled token had the highest
         probability overwrite the current sequence; all other positions keep
         their previous value.

    ``keep`` rises from 0 to ``seq_len`` along a cosine curve. On the final
    step every position is replaced by its sample.

    Args:
        model: callable ``model(x, t) -> logits`` exposing ``vocab_size`` and
            ``eval()``; logits presumably ``(batch, seq_len, vocab_size)``.
        device: device on which the schedule and timestep tensors are built.
        x0_batch: ``(batch, seq_len)`` integer token tensor; it is not modified.
        steps: number of refinement steps (``0`` returns a copy of ``x0_batch``).
        temperature: divisor applied to the logits before softmax (expected > 0).

    Returns:
        Tensor shaped like ``x0_batch`` holding the generated tokens.
    """
    model.eval()
    current = x0_batch.clone()
    batch, length = current.shape
    # Cosine ramp of how many positions to overwrite at each step.
    # NOTE(review): the first entry rounds to 0, so with steps > 1 the first
    # step's samples are discarded (only its RNG draw matters) — confirm intended.
    ramp = torch.cos(torch.linspace(math.pi / 2, 0, steps, device=device)) * length
    keep_counts = torch.round(ramp).long().tolist()

    with torch.no_grad():
        for step in range(steps):
            t = torch.full((batch,), 1.0 - (step / steps), device=device)
            probs = torch.nn.functional.softmax(model(current, t) / temperature, dim=-1)
            flat = probs.view(-1, model.vocab_size)
            sampled = torch.multinomial(flat, 1).view(current.shape)
            if step == steps - 1:
                return sampled
            conf = probs.gather(2, sampled.unsqueeze(-1)).squeeze(-1)
            top_idx = torch.topk(conf, keep_counts[step], largest=True, dim=-1).indices
            mask = torch.zeros_like(current, dtype=torch.bool).scatter_(1, top_idx, True)
            current = torch.where(mask, sampled, current)
    return current
def is_sample_valid(sample_x1, invalid_tokens=(0, 1, 2, 3)):
    """
    Return True if no token from ``invalid_tokens`` occurs inside the sequence.

    The first and last positions are excluded from the check, so special
    tokens may appear at the boundaries but not in the middle. Sequences of
    length <= 2 are therefore always valid.

    Args:
        sample_x1: sequence of token ids. This can be a list, or any object
            with a ``tolist()`` method such as a torch tensor or numpy array.
        invalid_tokens: token ids that must not appear in the middle;
            defaults to the special tokens 0-3.

    Returns:
        bool
    """
    # Tensor elements hash by identity, so `tensor(0) in {0, ...}` is always
    # False; convert to plain Python ints before testing membership.
    if hasattr(sample_x1, "tolist"):
        sample_x1 = sample_x1.tolist()
    forbidden = frozenset(invalid_tokens)
    return not any(token in forbidden for token in sample_x1[1:-1])
def create_prebatched_dataset(dataset, max_tokens_per_batch=500):
    """
    Pack same-length samples into batches and return one batch per row.

    Samples are bucketed by ``len(sample['input_ids_x1'])``. Buckets keep
    first-seen order, and samples keep their original order within a bucket.
    A bucket of length ``L`` is cut into consecutive chunks of
    ``max(1, max_tokens_per_batch // L)`` samples. A batch therefore holds at
    most ``max_tokens_per_batch`` x1 tokens, unless a single sample is
    already longer than that. A zero-length sample raises ZeroDivisionError.

    Args:
        dataset: iterable of dict-like samples with 'input_ids_x0' and
            'input_ids_x1' entries.
        max_tokens_per_batch: token budget per batch.

    Returns:
        A ``Dataset`` with columns 'input_ids_x0' and 'input_ids_x1'. Each row
        holds a list of same-length sequences.
    """
    buckets = defaultdict(list)
    for row in dataset:
        buckets[len(row['input_ids_x1'])].append(row)

    x0_batches, x1_batches = [], []
    for seq_length, rows in buckets.items():
        per_batch = max(1, max_tokens_per_batch // seq_length)
        for start in range(0, len(rows), per_batch):
            chunk = rows[start:start + per_batch]
            x0_batches.append([r['input_ids_x0'] for r in chunk])
            x1_batches.append([r['input_ids_x1'] for r in chunk])

    return Dataset.from_dict({'input_ids_x0': x0_batches, 'input_ids_x1': x1_batches})
# --- Main Execution ---
def _load_model(checkpoint_path, device):
    """Rebuild the MDLM described by a checkpoint and load its weights.

    The checkpoint must hold an 'args' object with attributes vocab_size,
    seq_len, model_dim, n_heads and n_layers (presumably an argparse
    Namespace), plus a 'model_state_dict'. Returns the model on ``device``,
    or None if the checkpoint could not be read (the error is printed).
    """
    print(f"Loading checkpoint from {checkpoint_path}...")
    try:
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects
        # (needed because 'args' is a non-tensor object). Only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model_args = checkpoint['args']
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None

    print("Initializing model...")
    model = MDLM(
        vocab_size=model_args.vocab_size,
        seq_len=model_args.seq_len,
        model_dim=model_args.model_dim,
        n_heads=model_args.n_heads,
        n_layers=model_args.n_layers,
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Model loaded successfully.")
    return model


def _generate_valid_pairs(model, device, length, args):
    """Sample (x0, x1) pairs of one length until ``args.samples_per_len`` are valid.

    x0 is drawn uniformly from ``[0, model.vocab_size)``. x1 comes from
    ``generate_x1_from_x0``, and only pairs whose x1 passes
    ``is_sample_valid`` are kept. Returns two lists of Python int lists.
    NOTE(review): loops forever if the model never produces valid samples.
    """
    pairs_x0, pairs_x1 = [], []
    with tqdm(total=args.samples_per_len) as pbar:
        while len(pairs_x1) < args.samples_per_len:
            remaining = args.samples_per_len - len(pairs_x1)
            # Never request more than still needed, so the count cannot overshoot.
            batch_size = min(args.batch_size, remaining)
            x0_batch = torch.randint(0, model.vocab_size, (batch_size, length),
                                     dtype=torch.long, device=device)
            x1_batch = generate_x1_from_x0(model, device, x0_batch, args.gen_steps, args.temperature)
            # One device->host transfer per batch instead of two per row.
            for x0, x1 in zip(x0_batch.cpu().tolist(), x1_batch.cpu().tolist()):
                if is_sample_valid(x1):
                    pairs_x0.append(x0)
                    pairs_x1.append(x1)
                    pbar.update(1)
    return pairs_x0, pairs_x1


def _split_dataset(all_x0, all_x1):
    """Shuffle-split the pairs 80/10/10 into train/validation/test (seed 42)."""
    dataset = Dataset.from_dict({'input_ids_x0': all_x0, 'input_ids_x1': all_x1})
    train_test_split = dataset.train_test_split(test_size=0.2, seed=42)
    valid_test_split = train_test_split['test'].train_test_split(test_size=0.5, seed=42)
    return DatasetDict({
        'train': train_test_split['train'],
        'validation': valid_test_split['train'],
        'test': valid_test_split['test'],
    })


def main(args):
    """Generate a rectified (x0, x1) dataset from a trained MDLM and save it.

    Steps:
      1. Load the checkpoint.
      2. Sample ``args.samples_per_len`` valid pairs for every length in
         ``[args.min_len, args.max_len]``.
      3. Split 80/10/10 into train/validation/test.
      4. Pre-batch each split by length.
      5. Save the DatasetDict under ``<args.output_path>/v<args.version>``.

    Returns None. On a checkpoint or argument error it prints a message and
    returns early.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = _load_model(args.checkpoint, device)
    if model is None:
        return

    # Fail fast, before any (possibly hours-long) generation:
    # - the learned positional embedding covers only model.seq_len positions;
    # - zero-length or empty ranges break batching and the dataset split;
    # - a non-positive batch size would make the sampling loop spin forever.
    if not 1 <= args.min_len <= args.max_len <= model.seq_len:
        print(f"Error: need 1 <= min_len ({args.min_len}) <= max_len ({args.max_len}) "
              f"<= model seq_len ({model.seq_len}).")
        return
    if args.batch_size < 1 or args.samples_per_len < 1:
        print(f"Error: batch_size ({args.batch_size}) and samples_per_len "
              f"({args.samples_per_len}) must both be at least 1.")
        return

    # 1. Generate samples for each length (2. each is sanity-checked).
    all_x0 = []
    all_x1 = []
    for length in range(args.min_len, args.max_len + 1):
        print(f"Generating {args.samples_per_len} valid samples for length {length}...")
        x0s, x1s = _generate_valid_pairs(model, device, length, args)
        all_x0.extend(x0s)
        all_x1.extend(x1s)

    # 3. Create dataset and split
    print("Splitting dataset...")
    final_dataset_dict = _split_dataset(all_x0, all_x1)

    # 4. Pre-batch each split
    print("Pre-batching splits...")
    batched_dataset_dict = DatasetDict()
    for split_name, split_dataset in final_dataset_dict.items():
        print(f"Processing {split_name} split...")
        batched_dataset_dict[split_name] = create_prebatched_dataset(split_dataset)

    # 5. Save the final dataset
    output_path = os.path.join(args.output_path, f"v{args.version}")
    print(f"Saving new batched dataset to {output_path}...")
    batched_dataset_dict.save_to_disk(output_path)
    print("Rectification complete.")
    print(f"Train on this by updating your training script's dataset path to '{output_path}'.")
if __name__ == "__main__":
    # CLI entry point; ArgumentDefaultsHelpFormatter appends each default to --help.
    parser = argparse.ArgumentParser(
        description="Generate a rectified dataset with variable lengths and pre-batching.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to a training checkpoint containing 'args' and 'model_state_dict'.")
    parser.add_argument("--output_path", type=str, default="./rectified_datasets",
                        help="Parent directory; the dataset is saved to <output_path>/v<version>.")
    parser.add_argument("--version", type=str, default='1',
                        help="Version tag used in the output directory name.")
    parser.add_argument("--samples_per_len", type=int, default=10000,
                        help="Number of valid samples to generate for each sequence length.")
    parser.add_argument("--min_len", type=int, default=6,
                        help="Shortest sequence length to generate (inclusive).")
    parser.add_argument("--max_len", type=int, default=49,
                        help="Longest sequence length to generate (inclusive).")
    parser.add_argument("--gen_steps", type=int, default=16,
                        help="Number of iterative sampling steps per generation.")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Softmax temperature used when sampling tokens (expected > 0).")
    parser.add_argument("--batch_size", type=int, default=128,
                        help="Maximum number of sequences sampled per model call.")
    args = parser.parse_args()
    main(args)