Upload 2 files
- prepare_gd_data.py +283 -0
- train_gd.py +623 -0

prepare_gd_data.py
ADDED
@@ -0,0 +1,283 @@
#!/usr/bin/env python3
"""
Prepare GD level data for modded-nanogpt training.
Converts tokenized levels to .bin format compatible with the data loader.
Uses multiprocessing for fast tokenization.

Usage (local files):
    python prepare_gd_data.py --input data/gd_raw --output data/gd_levels --tokenizer tokenizer.model

Usage (HuggingFace dataset):
    python prepare_gd_data.py --hf-repo tldne/gd-levels --hf-file levels_deduped.jsonl --output data/gd_levels --tokenizer tokenizer.model
"""

import argparse
import json
import numpy as np
from pathlib import Path
import sentencepiece as spm
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from multiprocessing import Pool, cpu_count
from functools import partial
import os

# Only include these fields in training (must match train_tokenizer.py!)
INCLUDE_FIELDS = [
    'level_id',  # Derived from source_file
    'level_name',
    'level_type',
    'binary_version',
    'description_decoded',
    'song_id',
    # level_string handled separately at the end
]


def format_training_sample(record: dict) -> str:
    """
    Format a level record as a training string.
    Only includes specific metadata fields + level_string.
    Must match train_tokenizer.py exactly!
    """
    parts = []

    # Extract level_id from source_file: "Level_12345.gmd2" -> "12345"
    # (removeprefix/removesuffix, not lstrip/rstrip: the latter strip character
    # *sets*, so an ID ending in "2", "d", etc. would lose trailing characters)
    source_file = record.get('source_file', '')
    if source_file:
        level_id = source_file.removeprefix("Level_").removesuffix(".gmd2")
        parts.append(f"<level_id>{level_id}")

    # Add only the included metadata fields
    for key in INCLUDE_FIELDS:
        if key == 'level_id':
            continue  # Already handled above
        value = record.get(key)
        if value is not None and value != "":
            parts.append(f"<{key}>{value}")

    # Add level string at the end (the main content)
    if record.get('level_string'):
        parts.append(f"<level_string>{record['level_string']}")

    return "".join(parts)


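# Illustrative input/output for format_training_sample (hypothetical record):
#   {"source_file": "Level_42.gmd2", "level_name": "demo", "level_string": "1,1;"}
# formats to:
#   "<level_id>42<level_name>demo<level_string>1,1;"
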
def write_bin_file(tokens: np.ndarray, output_path: Path):
    """Write tokens to .bin file with modded-nanogpt header format."""
    header = np.zeros(256, dtype=np.int32)
    header[0] = 20240520  # Magic number
    header[1] = 1  # Version
    header[2] = len(tokens)

    with output_path.open("wb") as f:
        f.write(header.tobytes())
        f.write(tokens.astype(np.uint16).tobytes())

    return len(tokens)


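# For reference, a minimal reader sketch for the layout written above (256
# int32 header words, then uint16 tokens). Illustrative helper, unused here:
def read_bin_file(path: Path) -> np.ndarray:
    with path.open("rb") as f:
        header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
        assert header[0] == 20240520 and header[1] == 1
        return np.frombuffer(f.read(int(header[2]) * 2), dtype=np.uint16)
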
# Worker function for multiprocessing
def tokenize_record(record_json: str, tokenizer_path: str) -> np.ndarray | None:
    """Tokenize a single record. Returns None if it should be skipped."""
    try:
        # Load tokenizer in worker (each worker needs its own instance)
        if not hasattr(tokenize_record, '_sp'):
            tokenize_record._sp = spm.SentencePieceProcessor()
            tokenize_record._sp.load(tokenizer_path)
        sp = tokenize_record._sp

        record = json.loads(record_json)
        if not record.get('level_string'):
            return None

        text = format_training_sample(record)
        if not text or len(text) > 5_000_000:
            return None

        tokens = [sp.bos_id()] + sp.encode(text) + [sp.eos_id()]
        if len(tokens) > 10_000_000:
            return None

        return np.array(tokens, dtype=np.uint16)
    except Exception:
        return None


def process_levels_parallel(
    input_dir: Path,
    output_dir: Path,
    tokenizer_path: Path,
    shard_size: int = 100_000_000,
    val_ratio: float = 0.01,
    num_workers: int | None = None,
):
    """Process all levels using multiprocessing and create train/val shards."""

    if num_workers is None:
        num_workers = max(1, cpu_count() - 1)

    print(f"Using {num_workers} workers for tokenization")

    # Load tokenizer to get token IDs
    print(f"Loading tokenizer from {tokenizer_path}")
    sp = spm.SentencePieceProcessor()
    sp.load(str(tokenizer_path))
    print(f"BOS ID: {sp.bos_id()}, EOS ID: {sp.eos_id()}")

    # Find all input files
    jsonl_files = list(input_dir.glob("*.jsonl"))
    json_files = list(input_dir.glob("*.json"))
    level_files = jsonl_files + json_files

    if not level_files:
        jsonl_files = list(input_dir.glob("**/*.jsonl"))
        json_files = list(input_dir.glob("**/*.json"))
        level_files = jsonl_files + json_files

    print(f"Found {len(level_files)} input files")

    # Collect all records as JSON strings
    print("Loading records...")
    all_records = []
    for lf in tqdm(level_files, desc="Reading files"):
        try:
            if lf.suffix == ".jsonl":
                with open(lf, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if line:
                            all_records.append(line)
            else:
                with open(lf, "r", encoding="utf-8") as f:
                    data = json.load(f)
                records = data if isinstance(data, list) else [data]
                for record in records:
                    all_records.append(json.dumps(record))
        except Exception as e:
            print(f"Error reading {lf}: {e}")

    print(f"Loaded {len(all_records):,} records")

    # Tokenize in parallel
    print(f"Tokenizing with {num_workers} workers...")
    tokenize_fn = partial(tokenize_record, tokenizer_path=str(tokenizer_path))

    all_levels = []
    total_tokens = 0

    with Pool(num_workers) as pool:
        results = list(tqdm(
            pool.imap(tokenize_fn, all_records, chunksize=100),
            total=len(all_records),
            desc="Tokenizing"
        ))

    for tokens in results:
        if tokens is not None:
            all_levels.append(tokens)
            total_tokens += len(tokens)

    print(f"Tokenized {len(all_levels):,} levels with {total_tokens:,} total tokens")

    # Shuffle and split
    np.random.seed(42)
    indices = np.random.permutation(len(all_levels))

    val_count = max(1, int(len(all_levels) * val_ratio))
    val_indices = indices[:val_count]
    train_indices = indices[val_count:]

    print(f"Train: {len(train_indices):,} levels, Val: {len(val_indices):,} levels")

    # Create output directory
    output_dir.mkdir(parents=True, exist_ok=True)

    # Write train shards
    train_tokens = np.concatenate([all_levels[i] for i in tqdm(train_indices, desc="Concatenating train")])
    num_train_shards = max(1, len(train_tokens) // shard_size)

    print(f"Writing {num_train_shards} train shards...")
    shard_size_actual = len(train_tokens) // num_train_shards
    for i in tqdm(range(num_train_shards), desc="Writing train shards"):
        start = i * shard_size_actual
        end = start + shard_size_actual if i < num_train_shards - 1 else len(train_tokens)
        shard = train_tokens[start:end]

        output_path = output_dir / f"train_{i:04d}.bin"
        write_bin_file(shard, output_path)

    # Write val shards
    val_tokens = np.concatenate([all_levels[i] for i in tqdm(val_indices, desc="Concatenating val")])
    num_val_shards = max(1, len(val_tokens) // shard_size)

    print(f"Writing {num_val_shards} val shards...")
    shard_size_actual = len(val_tokens) // num_val_shards
    for i in range(num_val_shards):
        start = i * shard_size_actual
        end = start + shard_size_actual if i < num_val_shards - 1 else len(val_tokens)
        shard = val_tokens[start:end]

        output_path = output_dir / f"val_{i:04d}.bin"
        write_bin_file(shard, output_path)

    # Summary
    print("\n=== Summary ===")
    print(f"Train tokens: {len(train_tokens):,}")
    print(f"Val tokens: {len(val_tokens):,}")
    print(f"Total tokens: {len(train_tokens) + len(val_tokens):,}")
    print(f"Output dir: {output_dir}")


def main():
    parser = argparse.ArgumentParser(description="Prepare GD levels for modded-nanogpt")

    # Input options (either local or HF)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--input", "-i", type=Path, help="Input directory/file with level JSONs/JSONL")
    input_group.add_argument("--hf-repo", type=str, help="HuggingFace repo ID (e.g., tldne/gd-levels)")

    parser.add_argument("--hf-file", type=str, default="levels_deduped.jsonl", help="File to download from HF repo")
    parser.add_argument("--output", "-o", type=Path, required=True, help="Output directory for .bin files")
    parser.add_argument("--tokenizer", "-t", type=Path, required=True, help="Path to tokenizer.model")
    parser.add_argument("--shard-size", type=int, default=100_000_000, help="Tokens per shard")
    parser.add_argument("--val-ratio", type=float, default=0.01, help="Validation split ratio")
    parser.add_argument("--workers", "-w", type=int, default=None, help="Number of workers (default: cpu_count - 1)")

    args = parser.parse_args()

    # Handle HuggingFace download
    if args.hf_repo:
        print(f"Downloading {args.hf_file} from {args.hf_repo}...")
        downloaded_path = hf_hub_download(
            repo_id=args.hf_repo,
            filename=args.hf_file,
            repo_type="dataset",
        )
        print(f"Downloaded to: {downloaded_path}")
        import tempfile
        import shutil
        temp_dir = Path(tempfile.mkdtemp())
        shutil.copy(downloaded_path, temp_dir / args.hf_file)
        input_dir = temp_dir
    else:
        input_dir = args.input
        temp_dir = None

    process_levels_parallel(
        input_dir=input_dir,
        output_dir=args.output,
        tokenizer_path=args.tokenizer,
        shard_size=args.shard_size,
        val_ratio=args.val_ratio,
        num_workers=args.workers,
    )

    # Cleanup temp dir if created (shutil was imported in the HF branch above)
    if temp_dir:
        import shutil
        shutil.rmtree(temp_dir)


if __name__ == "__main__":
    main()
train_gd.py
ADDED
@@ -0,0 +1,623 @@
# GD Level Training Script
# Based on modded-nanogpt train_gpt_medium.py

import os
import sys
with open(sys.argv[0]) as f:
    code = f.read()  # read this script's own source, logged and checkpointed below
import uuid
import time
import copy
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
import numpy as np
import wandb

os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.empty(1, device="cuda", requires_grad=True).backward()  # warms up CUDA autograd early
from torch import Tensor, nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.attention.flex_attention import BlockMask, flex_attention
torch._inductor.config.coordinate_descent_tuning = True

# -----------------------------------------------------------------------------
# Muon optimizer

def zeropower_via_newtonschulz5(G: Tensor) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.
    """
    assert G.ndim >= 2
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for a, b, c in [
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ]:
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

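# Illustrative sanity check (assumes a CUDA device): the iteration drives the
# singular values of G toward 1, so the result is approximately semi-orthogonal:
#   G = torch.randn(256, 128, device="cuda")
#   X = zeropower_via_newtonschulz5(G).float()
#   # X.mT @ X should be close to the 128x128 identity (not exact: bf16, 5 steps)
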
@torch.compile
def update(acc_bf16_view_u16: Tensor, mantissa: Tensor, momentum_buffer: Tensor, grad: Tensor, momentum: Tensor, eff_lr: Tensor, eff_weight_decay: Tensor):
    assert acc_bf16_view_u16.dtype == mantissa.dtype == torch.uint16
    grad = grad.float()
    momentum_buffer.copy_(momentum * momentum_buffer + (1 - momentum) * grad)
    v = zeropower_via_newtonschulz5(momentum * momentum_buffer + (1 - momentum) * grad)  # Nesterov-style lookahead

    acc_m_u32 = (acc_bf16_view_u16.to(torch.uint32) << 16) | mantissa.to(torch.uint32)
    acc_m_u32.view(torch.float32).mul_(1 - eff_weight_decay)
    acc_m_u32.view(torch.float32).add_(other=v, alpha=-eff_lr)
    acc_bf16_view_u16.copy_((acc_m_u32 >> 16).to(torch.uint16))
    mantissa.copy_(acc_m_u32.to(torch.uint16))

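# Note on the uint16 gymnastics above: the fp32 master weight is stored split
# across two tensors, with the bf16 parameter holding the high 16 bits and
# `mantissa` the low 16 bits. Reassembling them as uint32 and viewing as
# float32 applies decay and update in full precision before splitting back,
# avoiding a separate fp32 copy of the weights.
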
class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz"""
    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, rank=0, world_size=1):
        self.rank = rank
        self.world_size = world_size
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        super().__init__(params, defaults)
        assert all(p.dtype == torch.bfloat16 for group in self.param_groups for p in group["params"])

    @torch.no_grad()
    def step(self):
        futures: list[torch.Future] = []
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            params_pad = params + [torch.empty_like(params[-1])] * self.world_size
            momentum = torch._as_tensor_fullprec(group["momentum"])
            for base_i in range(len(params))[::self.world_size]:
                if base_i + self.rank < len(params):
                    p = params[base_i + self.rank]
                    state = self.state[p]
                    if len(state) == 0:
                        state["mantissa"] = torch.zeros_like(p, dtype=torch.uint16)
                        state["momentum_buffer"] = torch.zeros_like(p, dtype=torch.float32)
                    update(
                        p.view(torch.uint16), state["mantissa"], state["momentum_buffer"],
                        p.grad, momentum,
                        eff_lr=torch._as_tensor_fullprec(group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5),
                        eff_weight_decay=torch._as_tensor_fullprec(group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0)),
                    )
                futures.append(dist.all_gather(params_pad[base_i:base_i + self.world_size], params_pad[base_i + self.rank], async_op=True).get_future())
        torch.futures.collect_all(futures).wait()

# -----------------------------------------------------------------------------
# Model components

def norm(x: Tensor):
    return F.rms_norm(x, (x.size(-1),))

@torch.no_grad()
def init_linear(w: Tensor):
    std = 0.5 * (w.size(-1) ** -0.5)
    bound = (3 ** 0.5) * std
    return w.uniform_(-bound, bound)

class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.cos = nn.Buffer(theta.cos(), persistent=False)
        self.sin = nn.Buffer(theta.sin(), persistent=False)

    def forward(self, x_BTHD: Tensor):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)

class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        hdim = num_heads * head_dim
        self.qkvo_w = nn.Parameter(init_linear(torch.empty(4, hdim, dim)).bfloat16())
        self.qkvo_w.detach()[3].zero_()  # zero-init the output projection
        self.rotary = Rotary(head_dim, max_seq_len)
        self.attn_scale = 0.12

    def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask, lambdas: Tensor):
        B, T = x.size(0), x.size(1)
        assert B == 1, "Must use batch size = 1 for FlexAttention"
        q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
        q, k = norm(q), norm(k)
        q, k = self.rotary(q), self.rotary(k)
        v = norm(v)
        if ve is not None:
            v = lambdas[0] * v + lambdas[1] * ve.view_as(v)
        else:
            v = lambdas[0] * v
        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale).transpose(1, 2)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim)
        y = F.linear(y, self.qkvo_w[3])
        return y

class MLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        hdim = 4 * dim
        self.fc_w = nn.Parameter(init_linear(torch.empty(hdim, dim)).bfloat16())
        self.proj_w = nn.Parameter(torch.zeros(dim, hdim).bfloat16())
        self.fc_w.wd_mul = 2.0
        self.proj_w.wd_mul = 2.0

    def forward(self, x: Tensor):
        x = F.linear(x, self.fc_w)
        x = F.relu(x).square()  # ReLU^2 activation
        x = F.linear(x, self.proj_w)
        return x

class Block(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int):
        super().__init__()
        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len)
        self.mlp = MLP(dim)

    def forward(self, x: Tensor, ve: Tensor | None, x00: Tensor, x01: Tensor, block_mask: BlockMask, lambdas: Tensor, sa_lambdas: Tensor):
        x = lambdas[0] * x + lambdas[1] * x00 + lambdas[2] * x01
        x = x + self.attn(x, ve, block_mask, sa_lambdas)
        x = x + self.mlp(norm(x))
        return x

# -----------------------------------------------------------------------------
# Main model

def next_multiple_of_n(v: float | int, *, n: int):
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)

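# Example: next_multiple_of_n(3456 * 0.5, n=128) == 1792. With vocab_size=32000
# the padded lm_head size stays 32000, since 32000 == 250 * 128.
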
class GPT(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int, eos_token_id: int = 3):
        super().__init__()
        self.eos_token_id = eos_token_id
        self.embed1 = nn.Embedding(vocab_size, model_dim)
        self.embed2 = nn.Embedding(vocab_size, model_dim)
        # 5 value embeddings (proven to help convergence)
        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(5)])
        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len) for _ in range(num_layers)])
        self.lm_head_w = nn.Parameter(torch.zeros(next_multiple_of_n(vocab_size, n=128), model_dim))
        assert num_layers % 2 == 0
        self.scalars = nn.Parameter(torch.cat([
            torch.ones(num_layers),
            *[torch.tensor([1.0, 0.0, 0.0]) for _ in range(num_layers)],
            *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)],
        ]))

    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
        BLOCK_SIZE = 128
        docs = (input_seq == self.eos_token_id).cumsum(0)

        def document_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask

        def dense_to_ordered(dense_blockmask: Tensor):
            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

        assert len(input_seq) % BLOCK_SIZE == 0
        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
        causal_blockmask_any = block_idx[:, None] >= block_idx
        causal_blockmask_all = block_idx[:, None] > block_idx
        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
        blockmask_any = causal_blockmask_any & document_blockmask_any
        blockmask_all = causal_blockmask_all & document_blockmask_all
        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
        def build_bm(window_size_blocks: Tensor) -> BlockMask:
            return BlockMask.from_kv_blocks(
                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
                partial_kv_indices,
                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
                full_kv_indices,
                BLOCK_SIZE=BLOCK_SIZE,
                mask_mod=document_causal,
            )
        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)

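    # Note on the two masks returned above: blocks where every (q, kv) pair is
    # valid ("all") skip the per-element mask_mod re-check, while partially
    # valid blocks ("any" but not "all") keep it. The second mask halves the
    # sliding window for the short-attention layers.
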
    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
        assert input_seq.ndim == 1
        L = len(self.blocks)

        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
        # U-net pattern for 24 layers: value embeddings feed layers 0-4 and 19-23
        ve_layers = [ve[0], ve[1], ve[2], ve[3], ve[4]] + [None] * (L - 10) + [ve[0], ve[1], ve[2], ve[3], ve[4]]
        assert len(ve_layers) == L

        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
        # Distribute long/short attention: every 4th layer gets long
        block_masks = [long_bm if i % 4 == 0 else short_bm for i in range(L)]

        x = x00 = norm(self.embed1(input_seq)[None])
        x01 = norm(self.embed2(input_seq)[None])

        # Skip connections - Option B: +4 gap ladder, later injection, avoids long-attn destinations
        # Gaps: 7, 11, 15 (+4 each). Source layer 8 is long-attn, giving later layers wider receptive field.
        skip_connections = []
        skip_map = {
            15: 8,   # gap 7
            17: 6,   # gap 11
            19: 4,   # gap 15
        }
        skip_weights = self.scalars[:L]
        lambdas = self.scalars[1 * L: 4 * L].view(-1, 3)
        sa_lambdas = self.scalars[4 * L: 6 * L].view(-1, 2)

        for i in range(L):
            if i in skip_map:
                x = x + skip_weights[skip_map[i]] * skip_connections[skip_map[i]]
            x = self.blocks[i](x, ve_layers[i], x00, x01, block_masks[i], lambdas[i], sa_lambdas[i])
            skip_connections.append(x)

        x = norm(x)
        if self.training:
            logits: Tensor = F.linear(x.flatten(end_dim=1), self.lm_head_w.bfloat16()).float()
            # 15 * z / sqrt(z^2 + 225) is a tanh-like soft cap bounding logits to (-15, 15)
            loss = F.cross_entropy(15 * logits * torch.rsqrt(logits.square() + 225), target_seq)
            return loss

        # Validation: chunk the sequence to bound peak logit memory
        loss = 0
        for i in range(4):
            logits: Tensor = F.linear(x.flatten(end_dim=1).chunk(4)[i], self.lm_head_w.bfloat16()).float()
            loss += F.cross_entropy(15 * logits * torch.rsqrt(logits.square() + 225), target_seq.chunk(4)[i]) / 4
        return loss

# -----------------------------------------------------------------------------
# Data loading

def _load_data_shard(file: Path):
    header = torch.from_file(str(file), False, 256, dtype=torch.int32)
    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
    assert header[1] == 1, "unsupported version"
    num_tokens = int(header[2])
    with file.open("rb", buffering=0) as f:
        tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)
        f.seek(256 * 4)
        nbytes = f.readinto(tokens.numpy())
        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
    return tokens

def distributed_data_generator(filename_pattern: str, batch_size: int, rank: int, world_size: int):
    files = sorted(Path.cwd().glob(filename_pattern))
    assert batch_size % world_size == 0
    local_batch_size = batch_size // world_size

    epoch = 0
    while True:
        # Shuffle files each epoch (deterministic per epoch for reproducibility)
        rng = np.random.default_rng(seed=42 + epoch)
        shuffled_files = rng.permutation(files).tolist()

        for file in shuffled_files:
            tokens = _load_data_shard(file)
            pos = 0
            while pos + batch_size + 1 < len(tokens):
                buf = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
                inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True)
                targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True)
                pos += batch_size
                yield inputs, targets

        epoch += 1
        if rank == 0:
            print(f"Completed epoch {epoch}, shuffling for next epoch...")

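# Layout per step (illustrative): with batch_size = world_size * seq_len, rank r
# reads tokens[pos + r*seq_len : pos + r*seq_len + seq_len + 1]. The extra token
# provides the shifted targets, so consecutive ranks' windows overlap by exactly
# one token and together cover one batch_size-sized stripe of the shard.
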
# -----------------------------------------------------------------------------
# Hyperparameters

@dataclass
class Hyperparameters:
    # Data paths
    train_files = "data/gd_levels/train_*.bin"
    val_files = "data/gd_levels/val_*.bin"
    val_tokens = 10420224  # Must be divisible by num_gpus * val_seq_len = 6 * 16k = 98,304

    # Sequence lengths (reduced for 6-GPU setup)
    train_seq_len = 16 * 1024  # 16k context
    val_seq_len = 16 * 1024  # 16k for validation too

    # Training (6 GPUs * 16k = 98,304 tokens/step)
    num_iterations = 109063  # 10.72B tokens / 98,304 tokens per step (exact)
    cooldown_frac = 0.7  # Matching Medium - 70% of training in LR decay

    # Architecture
    vocab_size = 32000
    num_layers = 24
    num_heads = 10  # 1280 / 128
    model_dim = 1280
    eos_token_id = 3  # This tokenizer's EOS ID

    # Logging and checkpointing
    val_loss_every = 5000  # Calculate val loss every 5000 steps
    wandb_log_every = 1  # Log training metrics to wandb every step
    save_every = 10000  # Save checkpoint every 10k steps (~11 checkpoints)
    save_checkpoint = True
    resume_from = None  # Set to checkpoint path or use RESUME_FROM env var

args = Hyperparameters()
# Allow env var override for resume
if os.environ.get("RESUME_FROM"):
    args.resume_from = os.environ["RESUME_FROM"]

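# Sanity arithmetic for these settings: each step consumes
# world_size * train_seq_len = 6 * 16384 = 98,304 tokens, so
# 109,063 steps * 98,304 tokens/step = 10,721,329,152 (about 10.72B) tokens,
# and val_tokens = 10,420,224 = 106 * 98,304 divides evenly into val batches.
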
# -----------------------------------------------------------------------------
# Training setup

run_id = int(os.environ.get("RUN_ID", 0))
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
assert torch.cuda.is_available()
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
master_process = (rank == 0)

if master_process:
    run_id_full = f"{run_id:03d}_{uuid.uuid4()}"
    os.makedirs("logs", exist_ok=True)
    logfile = f"logs/{run_id_full}.txt"
    print(logfile)
    # Initialize wandb
    wandb.init(
        project="gd-level-generation",
        name=run_id_full,
        config={
            "vocab_size": args.vocab_size,
            "num_layers": args.num_layers,
            "model_dim": args.model_dim,
            "num_heads": args.num_heads,
            "train_seq_len": args.train_seq_len,
            "num_iterations": args.num_iterations,
            "cooldown_frac": args.cooldown_frac,
        },
    )

def print0(s, console=False):
    if master_process:
        with open(logfile, "a") as f:
            if console:
                print(s)
            print(s, file=f)

print0(code)
print0("=" * 100)
print0(f"Running Python {sys.version}")
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")

def nvidia_smi():
    import subprocess
    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
print0(nvidia_smi())
print0("=" * 100)

# -----------------------------------------------------------------------------
# Model and optimizer

model: nn.Module = GPT(
    vocab_size=args.vocab_size,
    num_layers=args.num_layers,
    num_heads=args.num_heads,
    model_dim=args.model_dim,
    max_seq_len=max(args.train_seq_len, args.val_seq_len),
    eos_token_id=args.eos_token_id,
).cuda()

for m in model.modules():
    if isinstance(m, nn.Embedding):
        m.bfloat16()
for param in model.parameters():
    dist.broadcast(param.detach(), 0)

# Print param count
if master_process:
    total_params = sum(p.numel() for p in model.parameters())
    print0(f"Total parameters: {total_params:,} ({total_params/1e6:.1f}M)", console=True)

# Collect parameters
hidden_matrix_params = sorted((p for p in model.blocks.parameters() if p.ndim >= 2), key=lambda x: x.size(), reverse=True)
embed_params = [*model.embed1.parameters(), *model.embed2.parameters(), *model.value_embeds.parameters()]
scalar_params = [model.scalars]
head_params: list[nn.Parameter] = [model.lm_head_w]

params_collections = [hidden_matrix_params, embed_params, scalar_params, head_params]
optimized_parameters_set = {p for params in params_collections for p in params}
assert optimized_parameters_set == {*model.parameters()}
assert len(optimized_parameters_set) == sum(len(lst) for lst in params_collections)

# Optimizers
adam_param_groups = [
    dict(params=head_params, lr=1/320),
    dict(params=embed_params, lr=0.3),
    dict(params=scalar_params, lr=0.015),
]
optimizer1 = torch.optim.AdamW(adam_param_groups, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0, fused=True)
optimizer2 = Muon(hidden_matrix_params, lr=0.025, momentum=0.95, rank=rank, world_size=world_size)
optimizers: list[torch.optim.Optimizer] = [optimizer1, optimizer2]

def opt_params(opt: torch.optim.Optimizer) -> list[nn.Parameter]:
    return [p for group in opt.param_groups for p in group["params"]]
opt2params = {opt: opt_params(opt) for opt in optimizers}
for opt in optimizers:
    for group in opt.param_groups:
        group["initial_lr"] = group["lr"]

# Resume from checkpoint if specified
start_step = 0
if args.resume_from:
    print0(f"Resuming from checkpoint: {args.resume_from}", console=True)
    checkpoint = torch.load(args.resume_from, map_location=device)
    # Load model state (handle torch.compile prefix)
    model_state = checkpoint["model"]
    if any(k.startswith("_orig_mod.") for k in model_state.keys()):
        model_state = {k.replace("_orig_mod.", ""): v for k, v in model_state.items()}
    model.load_state_dict(model_state)
    # Load optimizer states
    for opt, opt_state in zip(optimizers, checkpoint["optimizers"]):
        opt.load_state_dict(opt_state)
    start_step = checkpoint["step"] + 1
    print0(f"Resumed from step {checkpoint['step']}, continuing from step {start_step}", console=True)
    del checkpoint

# LR schedule: constant, then linear cooldown
def get_lr(step: int):
    x = step / args.num_iterations
    assert 0 <= x < 1
    if x < 1 - args.cooldown_frac:
        return 1.0
    else:
        return (1 - x) / args.cooldown_frac

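# Shape of the schedule: flat at 1.0 for the first 30% of training, then linear
# decay to 0 over the final 70% (cooldown_frac). With num_iterations = 109063:
# get_lr(0) == 1.0, get_lr(32718) == 1.0 (just before the knee), and
# get_lr(98157) is roughly 0.143 (x = 0.9, so (1 - 0.9) / 0.7).
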
# Window size schedule
@lru_cache(1)
def get_window_size_blocks_helper(window_size: int):
    return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

def get_window_size_blocks(step: int):
    x = step / args.num_iterations
    assert 0 <= x <= 1
    # Cubic schedule: 0 -> 3456 (matching Medium)
    factor = 4 * x ** 3 - 6 * x ** 2 + 3 * x
    window_size = next_multiple_of_n(3456 * factor, n=128)
    return get_window_size_blocks_helper(window_size)

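# The smoothstep-like cubic 4x^3 - 6x^2 + 3x rises from 0 at x=0 to 1 at x=1
# (it equals 0.5 at x=0.5), so the attention window grows from 128 tokens at
# step 0 to 3456 tokens by the final step, rounded up to a multiple of 128.
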
model: nn.Module = torch.compile(model, dynamic=False)

# -----------------------------------------------------------------------------
# Warmup kernels

warmup_steps = 10
initial_state = copy.deepcopy(dict(model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]))
for warmup_step in range(warmup_steps):
    print0(f"Warmup step {warmup_step+1}/{warmup_steps}")
    inputs = targets = torch.randint(0, args.vocab_size, size=(args.train_seq_len,), device="cuda")
    model(inputs.to(torch.int32), targets, get_window_size_blocks(0)).backward()
    for param in model.parameters():
        dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
model.load_state_dict(initial_state["model"])
for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
    opt.load_state_dict(opt_state)
del initial_state

# -----------------------------------------------------------------------------
# Training loop

torch.cuda.reset_peak_memory_stats()
train_loader = distributed_data_generator(args.train_files, world_size * args.train_seq_len, rank, world_size)
training_time_ms = 0
dist.barrier()
t0 = time.perf_counter()

train_steps = args.num_iterations
for step in range(start_step, train_steps + 1):
    last_step = (step == train_steps)

    # Validation
    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
        dist.barrier()
        training_time_ms += 1000 * (time.perf_counter() - t0)
        model.eval()
        val_batch_size = world_size * args.val_seq_len
        assert args.val_tokens % val_batch_size == 0
        val_steps = args.val_tokens // val_batch_size
        val_loader = distributed_data_generator(args.val_files, val_batch_size, rank, world_size)
        val_loss = 0
        with torch.no_grad():
            for _ in range(val_steps):
                inputs, targets = next(val_loader)
                val_loss += model(inputs, targets, get_window_size_blocks(step))
        val_loss /= val_steps
        del val_loader
        dist.reduce(val_loss, 0, op=dist.ReduceOp.AVG)
        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.6f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)

        # Log to wandb
        if master_process:
            wandb.log({
                "val_loss": val_loss.item() if hasattr(val_loss, 'item') else val_loss,
                "step": step,
                "train_time_ms": training_time_ms,
                "step_avg_ms": training_time_ms / max(step, 1),
                "lr_mult": get_lr(step) if step < train_steps else 0,
            })

        # Save checkpoint during training (for spot instance resilience)
        if master_process and args.save_checkpoint and step > 0 and step % args.save_every == 0:
            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
            os.makedirs(f"logs/{run_id_full}", exist_ok=True)
            torch.save(log, f"logs/{run_id_full}/state_step{step:06d}.pt")
            print0(f"Saved checkpoint at step {step}", console=True)

        model.train()
        dist.barrier()
        t0 = time.perf_counter()

    if last_step:
        if master_process and args.save_checkpoint:
            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
            os.makedirs(f"logs/{run_id_full}", exist_ok=True)
            torch.save(log, f"logs/{run_id_full}/state_step{step:06d}.pt")
        break

    # Training step
    inputs, targets = next(train_loader)
    train_loss = model(inputs, targets, get_window_size_blocks(step))
    train_loss.backward()
    opt2futures = {
        opt: [dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, async_op=True).get_future() for p in params]
        for opt, params in opt2params.items()
    }
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * get_lr(step)
    # Muon momentum warmup: 0.85 -> 0.95 over the first 300 steps
    for group in optimizer2.param_groups:
        frac = min(step / 300, 1)
        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
    for opt in optimizers:
        torch.futures.collect_all(opt2futures[opt]).wait()
        opt.step()
    model.zero_grad(set_to_none=True)

    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)

    # Log to wandb every wandb_log_every steps (lightweight, no val loss calc)
    if master_process and step % args.wandb_log_every == 0:
        wandb.log({
            "train_loss": train_loss.item(),
            "step": step,
            "train_time_ms": approx_training_time_ms,
            "step_avg_ms": approx_training_time_ms / (step + 1),
            "lr_mult": get_lr(step),
        }, step=step)

print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
dist.destroy_process_group()