# stuff / prepare_gd_data.py
# NobodyExistsOnTheInternet's picture
# Upload 2 files
# c9a06e3 verified
#!/usr/bin/env python3
"""
Prepare GD level data for modded-nanogpt training.
Tokenizes raw level records and writes them as .bin shards compatible with the modded-nanogpt data loader.
Uses multiprocessing for fast tokenization.
Usage (local files):
python prepare_gd_data.py --input data/gd_raw --output data/gd_levels --tokenizer tokenizer.model
Usage (HuggingFace dataset):
python prepare_gd_data.py --hf-repo tldne/gd-levels --hf-file levels_deduped.jsonl --output data/gd_levels --tokenizer tokenizer.model
"""
import argparse
import json
import numpy as np
from pathlib import Path
import sentencepiece as spm
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from multiprocessing import Pool, cpu_count
from functools import partial
import os
# Only include these fields in training (must match train_tokenizer.py!)
INCLUDE_FIELDS = [
    'level_id',  # Derived from source_file
    'level_name',
    'level_type',
    'binary_version',
    'description_decoded',
    'song_id',
    # level_string handled separately at the end
]


def format_training_sample(record: dict) -> str:
    """
    Format a level record as a single training string.

    The output concatenates ``<field>value`` segments in this order:
    ``<level_id>`` (derived from ``source_file``), then every other name in
    INCLUDE_FIELDS whose value is neither ``None`` nor ``""``, and finally
    ``<level_string>`` when the record has a non-empty level string.
    Must match the formatting used in train_tokenizer.py.

    Args:
        record: One parsed level record.

    Returns:
        The formatted string; ``""`` if the record contributes no fields.
    """
    parts = []
    # Extract level_id from source_file: "Level_12345.gmd2" -> "12345".
    # removeprefix/removesuffix drop the literal affix. The former
    # lstrip("Level_")/rstrip(".gmd2") stripped *character sets*, so ids ending
    # in '2' lost digits ("Level_12.gmd2" -> "1", "Level_22.gmd2" -> "").
    # NOTE(review): if train_tokenizer.py uses the same lstrip/rstrip idiom it
    # has the same bug and should be updated to stay consistent.
    source_file = record.get('source_file', '')
    if source_file:
        level_id = source_file.removeprefix("Level_").removesuffix(".gmd2")
        parts.append(f"<level_id>{level_id}")

    # Add only the included metadata fields, in INCLUDE_FIELDS order.
    for key in INCLUDE_FIELDS:
        if key == 'level_id':
            continue  # Already handled above
        value = record.get(key)
        # Falsy-but-meaningful values such as 0 are kept; only None/"" are dropped.
        if value is not None and value != "":
            parts.append(f"<{key}>{value}")

    # Add level string at the end (the main content)
    if record.get('level_string'):
        parts.append(f"<level_string>{record['level_string']}")
    return "".join(parts)
def write_bin_file(tokens: np.ndarray, output_path: Path):
    """Write tokens to a .bin file in the modded-nanogpt shard format.

    Layout: a 256-entry int32 header (magic 20240520, version 1, token count)
    followed by the tokens as uint16.

    Args:
        tokens: 1-D sequence of token ids.
        output_path: Destination file (overwritten if it exists).

    Returns:
        The number of tokens written.

    Raises:
        ValueError: if the token count does not fit the int32 header, or any id
            lies outside the uint16 range (``astype`` would silently wrap it).
    """
    tokens = np.asarray(tokens)
    n_tokens = len(tokens)
    if n_tokens > np.iinfo(np.int32).max:
        raise ValueError(f"{n_tokens} tokens do not fit in the int32 header; use a smaller shard size")
    # A uint16 array is valid by construction; anything else must be range-checked.
    if tokens.dtype != np.uint16 and tokens.size and (
        tokens.min() < 0 or tokens.max() > np.iinfo(np.uint16).max
    ):
        raise ValueError(f"token ids must be in [0, 65535] to be stored as uint16 ({output_path})")

    header = np.zeros(256, dtype=np.int32)
    header[0] = 20240520  # Magic number
    header[1] = 1  # Version
    header[2] = n_tokens
    with output_path.open("wb") as f:
        f.write(header.tobytes())
        f.write(tokens.astype(np.uint16, copy=False).tobytes())
    return n_tokens
# Worker function for multiprocessing
def tokenize_record(record_json: str, tokenizer_path: str) -> np.ndarray | None:
    """Tokenize a single JSON-encoded level record for training.

    Intended to run inside Pool workers. The SentencePiece model is loaded once
    per process and cached on the function object.

    Args:
        record_json: One record serialized as JSON.
        tokenizer_path: Path to the SentencePiece model file.

    Returns:
        ``[bos] + encode(text) + [eos]`` as a uint16 array (BOS/EOS omitted if
        the tokenizer reports them as -1). Returns ``None`` to skip the record
        when any of these hold:
        - the JSON is invalid;
        - the record is not an object;
        - ``level_string`` is missing or empty;
        - the formatted text is empty or longer than 5,000,000 characters;
        - the result exceeds 10,000,000 tokens;
        - any other per-record processing error occurs.

    Raises:
        ValueError: if the tokenizer vocabulary does not fit in uint16.
        Errors raised while loading the tokenizer also propagate, so a bad
        tokenizer path fails the run instead of silently skipping every record.
    """
    sp = getattr(tokenize_record, '_sp', None)
    if sp is None:
        # Deliberately outside the per-record try below (see Raises).
        sp = spm.SentencePieceProcessor()
        sp.load(tokenizer_path)
        if sp.get_piece_size() > np.iinfo(np.uint16).max + 1:
            raise ValueError(
                f"Tokenizer vocab size {sp.get_piece_size()} does not fit in uint16 .bin shards"
            )
        tokenize_record._sp = sp

    try:
        record = json.loads(record_json)
        if not isinstance(record, dict) or not record.get('level_string'):
            return None
        text = format_training_sample(record)
        if not text or len(text) > 5_000_000:
            return None
        ids = sp.encode(text)
    except Exception:
        # Best-effort per record: a malformed record is skipped, not fatal.
        # (Unlike a bare `except:`, this lets KeyboardInterrupt/SystemExit through.)
        return None

    bos, eos = sp.bos_id(), sp.eos_id()
    # SentencePiece uses -1 for a disabled special token; -1 would wrap to 65535 as uint16.
    tokens = ([bos] if bos >= 0 else []) + ids + ([eos] if eos >= 0 else [])
    if len(tokens) > 10_000_000:
        return None
    return np.array(tokens, dtype=np.uint16)
def _find_input_files(input_dir: Path) -> list[Path]:
    """Return the JSON/JSONL files to read from ``input_dir``, in sorted order.

    A path that is itself a file is returned as-is. Otherwise the top level is
    scanned (``*.jsonl`` then ``*.json``), falling back to a recursive scan only
    when the top level has no matches. Sorting makes the record order, and hence
    the seeded train/val split, independent of directory-listing order.
    """
    if input_dir.is_file():
        return [input_dir]
    level_files = sorted(input_dir.glob("*.jsonl")) + sorted(input_dir.glob("*.json"))
    if not level_files:
        level_files = sorted(input_dir.glob("**/*.jsonl")) + sorted(input_dir.glob("**/*.json"))
    return level_files


def _load_records(level_files: list[Path]) -> list[str]:
    """Read every record from ``level_files`` as a JSON string.

    ``.jsonl`` files contribute one record per non-blank line (kept verbatim).
    Any other file is parsed as JSON, and a top-level list (or single object) is
    re-serialized record by record. A file that fails to read or parse is
    reported and the run continues.
    """
    all_records: list[str] = []
    for lf in tqdm(level_files, desc="Reading files"):
        try:
            with open(lf, "r", encoding="utf-8") as f:
                if lf.suffix == ".jsonl":
                    all_records.extend(s for line in f if (s := line.strip()))
                else:
                    data = json.load(f)
                    records = data if isinstance(data, list) else [data]
                    all_records.extend(json.dumps(record) for record in records)
        except Exception as e:  # best-effort: one bad file should not abort the run
            print(f"Error reading {lf}: {e}")
    return all_records


def _write_split(all_levels, indices, prefix: str, output_dir: Path, shard_size: int) -> int:
    """Concatenate the selected levels and write them as ``{prefix}_NNNN.bin`` shards.

    Tokens are divided into ceil(total / shard_size) nearly equal shards, so no
    shard exceeds ``shard_size`` tokens. With floor division, a shard could grow
    to almost 2x ``shard_size``. Returns the number of tokens written.
    """
    tokens = np.concatenate([all_levels[i] for i in tqdm(indices, desc=f"Concatenating {prefix}")])
    num_shards = max(1, -(-len(tokens) // shard_size))  # ceiling division
    print(f"Writing {num_shards} {prefix} shards...")
    shards = np.array_split(tokens, num_shards)
    for i, shard in enumerate(tqdm(shards, desc=f"Writing {prefix} shards")):
        write_bin_file(shard, output_dir / f"{prefix}_{i:04d}.bin")
    return len(tokens)


def process_levels_parallel(
    input_dir: Path,
    output_dir: Path,
    tokenizer_path: Path,
    shard_size: int = 100_000_000,
    val_ratio: float = 0.01,
    num_workers: int = None,
):
    """Tokenize all levels in parallel and write train/val ``.bin`` shards.

    Args:
        input_dir: Directory of ``.json``/``.jsonl`` files, or a single such file.
        output_dir: Where ``train_NNNN.bin`` / ``val_NNNN.bin`` are written (created if needed).
        tokenizer_path: SentencePiece model file.
        shard_size: Maximum tokens per shard.
        val_ratio: Fraction of levels held out for validation. At least one level
            goes to val, and at least one always stays in train.
        num_workers: Tokenizer processes (default: ``cpu_count() - 1``, minimum 1).

    Raises:
        ValueError: if ``shard_size`` is not positive, or fewer than two levels
            survive tokenization (both splits need data).
    """
    if shard_size <= 0:
        raise ValueError(f"shard_size must be positive, got {shard_size}")
    if num_workers is None:
        num_workers = max(1, cpu_count() - 1)
    print(f"Using {num_workers} workers for tokenization")

    # Loaded here only to report the special-token ids; workers load their own copy.
    print(f"Loading tokenizer from {tokenizer_path}")
    sp = spm.SentencePieceProcessor()
    sp.load(str(tokenizer_path))
    print(f"BOS ID: {sp.bos_id()}, EOS ID: {sp.eos_id()}")

    level_files = _find_input_files(input_dir)
    print(f"Found {len(level_files)} input files")

    print("Loading records...")
    all_records = _load_records(level_files)
    print(f"Loaded {len(all_records):,} records")

    print(f"Tokenizing with {num_workers} workers...")
    tokenize_fn = partial(tokenize_record, tokenizer_path=str(tokenizer_path))
    all_levels = []
    total_tokens = 0
    with Pool(num_workers) as pool:
        # imap preserves input order, keeping the seeded shuffle below reproducible.
        # Results must be consumed inside the `with`: Pool.__exit__ terminates workers.
        for tokens in tqdm(
            pool.imap(tokenize_fn, all_records, chunksize=100),
            total=len(all_records),
            desc="Tokenizing",
        ):
            if tokens is not None:
                all_levels.append(tokens)
                total_tokens += len(tokens)
    del all_records  # release raw JSON before the large concatenations
    print(f"Tokenized {len(all_levels):,} levels with {total_tokens:,} total tokens")

    if len(all_levels) < 2:
        raise ValueError(
            f"Need at least 2 tokenized levels to build train and val splits, got {len(all_levels)}"
        )

    # Same permutation as np.random.seed(42) + np.random.permutation, without
    # mutating NumPy's global RNG state.
    indices = np.random.RandomState(42).permutation(len(all_levels))
    val_count = min(max(1, int(len(all_levels) * val_ratio)), len(all_levels) - 1)
    val_indices = indices[:val_count]
    train_indices = indices[val_count:]
    print(f"Train: {len(train_indices):,} levels, Val: {len(val_indices):,} levels")

    output_dir.mkdir(parents=True, exist_ok=True)
    n_train = _write_split(all_levels, train_indices, "train", output_dir, shard_size)
    n_val = _write_split(all_levels, val_indices, "val", output_dir, shard_size)

    print("\n=== Summary ===")
    print(f"Train tokens: {n_train:,}")
    print(f"Val tokens: {n_val:,}")
    print(f"Total tokens: {n_train + n_val:,}")
    print(f"Output dir: {output_dir}")
def main():
    """Command-line entry point: parse arguments, resolve the input, build the shards.

    With ``--hf-repo``, the requested file is downloaded from the Hugging Face
    Hub and copied into a private temporary directory, so only that file is
    picked up by the input scan. The directory is removed afterwards, even if
    processing fails.
    """
    import contextlib
    import shutil
    import tempfile

    parser = argparse.ArgumentParser(description="Prepare GD levels for modded-nanogpt")
    # Input options (either local or HF)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--input", "-i", type=Path, help="Input directory/file with level JSONs/JSONL")
    input_group.add_argument("--hf-repo", type=str, help="HuggingFace repo ID (e.g., tldne/gd-levels)")
    parser.add_argument("--hf-file", type=str, default="levels_deduped.jsonl", help="File to download from HF repo")
    parser.add_argument("--output", "-o", type=Path, required=True, help="Output directory for .bin files")
    parser.add_argument("--tokenizer", "-t", type=Path, required=True, help="Path to tokenizer.model")
    parser.add_argument("--shard-size", type=int, default=100_000_000, help="Tokens per shard")
    parser.add_argument("--val-ratio", type=float, default=0.01, help="Validation split ratio")
    parser.add_argument("--workers", "-w", type=int, default=None, help="Number of workers (default: cpu_count - 1)")
    args = parser.parse_args()

    # TemporaryDirectory cleans up on success *and* on error (the old manual
    # rmtree was skipped if processing raised); nullcontext keeps the local path uniform.
    temp_ctx = tempfile.TemporaryDirectory() if args.hf_repo else contextlib.nullcontext()
    with temp_ctx as temp_name:
        if args.hf_repo:
            print(f"Downloading {args.hf_file} from {args.hf_repo}...")
            downloaded_path = hf_hub_download(
                repo_id=args.hf_repo,
                filename=args.hf_file,
                repo_type="dataset",
            )
            print(f"Downloaded to: {downloaded_path}")
            input_dir = Path(temp_name)
            # Use only the basename: hf_file may name a repo subdirectory that
            # does not exist inside the temp dir.
            shutil.copy(downloaded_path, input_dir / Path(args.hf_file).name)
        else:
            input_dir = args.input

        process_levels_parallel(
            input_dir=input_dir,
            output_dir=args.output,
            tokenizer_path=args.tokenizer,
            shard_size=args.shard_size,
            val_ratio=args.val_ratio,
            num_workers=args.workers,
        )
# Run the CLI only when executed as a script. The guard also keeps
# multiprocessing workers (which may re-import this module under the "spawn"
# start method) from re-running main().
if __name__ == "__main__":
    main()