#!/usr/bin/env python3
"""
Prepare GD level data for modded-nanogpt training.

Tokenizes level records and writes them as .bin shards compatible with the
modded-nanogpt data loader. Uses multiprocessing for fast tokenization.

Usage (local files):
    python prepare_gd_data.py --input data/gd_raw --output data/gd_levels --tokenizer tokenizer.model

Usage (HuggingFace dataset):
    python prepare_gd_data.py --hf-repo tldne/gd-levels --hf-file levels_deduped.jsonl --output data/gd_levels --tokenizer tokenizer.model
"""
import argparse
import json
import shutil
import tempfile
from functools import partial
from multiprocessing import Pool, cpu_count
from pathlib import Path

import numpy as np
import sentencepiece as spm
from huggingface_hub import hf_hub_download
from tqdm import tqdm

# Only include these fields in training (must match train_tokenizer.py!)
INCLUDE_FIELDS = [
    'level_id',             # Derived from source_file
    'level_name',
    'level_type',
    'binary_version',
    'description_decoded',
    'song_id',
    # level_string handled separately at the end
]


def format_training_sample(record: dict) -> str:
    """
    Format a level record as a training string.

    Only includes specific metadata fields + level_string.
    Must match train_tokenizer.py exactly!
    """
    parts = []

    # Extract level_id from source_file: "Level_12345.gmd2" -> "12345"
    # (removeprefix/removesuffix rather than lstrip/rstrip, so ID digits that
    # happen to be in the stripped character set are not eaten)
    source_file = record.get('source_file', '')
    if source_file:
        level_id = source_file.removeprefix("Level_").removesuffix(".gmd2")
        parts.append(level_id)

    # Add only the included metadata fields
    for key in INCLUDE_FIELDS:
        if key == 'level_id':
            continue  # Already handled above
        value = record.get(key)
        if value is not None and value != "":
            parts.append(f"<{key}>{value}")

    # Add level string at the end (the main content)
    if record.get('level_string'):
        parts.append(record['level_string'])

    return "".join(parts)
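
# Illustrative example of the string produced by format_training_sample (the field
# values below are made up, not taken from the dataset). A record such as
#   {"source_file": "Level_12345.gmd2", "level_name": "Demo", "song_id": 42,
#    "level_string": "1,1,2,3,..."}
# becomes the bare level_id, then one "<field>value" tag per present metadata field
# in INCLUDE_FIELDS order, then the raw level_string:
#   "12345<level_name>Demo<song_id>421,1,2,3,..."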

def write_bin_file(tokens: np.ndarray, output_path: Path) -> int:
    """Write tokens to a .bin file with the modded-nanogpt header format."""
    header = np.zeros(256, dtype=np.int32)
    header[0] = 20240520  # Magic number
    header[1] = 1         # Version
    header[2] = len(tokens)
    with output_path.open("wb") as f:
        f.write(header.tobytes())
        f.write(tokens.astype(np.uint16).tobytes())
    return len(tokens)


# Worker function for multiprocessing
def tokenize_record(record_json: str, tokenizer_path: str) -> np.ndarray | None:
    """Tokenize a single record. Returns None if the record should be skipped."""
    try:
        # Load the tokenizer lazily and cache it on the function
        # (each worker process needs its own instance)
        if not hasattr(tokenize_record, '_sp'):
            tokenize_record._sp = spm.SentencePieceProcessor()
            tokenize_record._sp.load(tokenizer_path)
        sp = tokenize_record._sp

        record = json.loads(record_json)
        if not record.get('level_string'):
            return None

        text = format_training_sample(record)
        if not text or len(text) > 5_000_000:
            return None

        tokens = [sp.bos_id()] + sp.encode(text) + [sp.eos_id()]
        if len(tokens) > 10_000_000:
            return None

        return np.array(tokens, dtype=np.uint16)
    except Exception:
        return None


def process_levels_parallel(
    input_dir: Path,
    output_dir: Path,
    tokenizer_path: Path,
    shard_size: int = 100_000_000,
    val_ratio: float = 0.01,
    num_workers: int | None = None,
):
    """Process all levels using multiprocessing and create train/val shards."""
    if num_workers is None:
        num_workers = max(1, cpu_count() - 1)
    print(f"Using {num_workers} workers for tokenization")

    # Load tokenizer to get token IDs
    print(f"Loading tokenizer from {tokenizer_path}")
    sp = spm.SentencePieceProcessor()
    sp.load(str(tokenizer_path))
    print(f"BOS ID: {sp.bos_id()}, EOS ID: {sp.eos_id()}")

    # Find all input files; fall back to a recursive search if the top level is empty
    jsonl_files = list(input_dir.glob("*.jsonl"))
    json_files = list(input_dir.glob("*.json"))
    level_files = jsonl_files + json_files
    if not level_files:
        jsonl_files = list(input_dir.glob("**/*.jsonl"))
        json_files = list(input_dir.glob("**/*.json"))
        level_files = jsonl_files + json_files
    print(f"Found {len(level_files)} input files")

    # Collect all records as JSON strings
    print("Loading records...")
    all_records = []
    for lf in tqdm(level_files, desc="Reading files"):
        try:
            if lf.suffix == ".jsonl":
                with open(lf, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if line:
                            all_records.append(line)
            else:
                with open(lf, "r", encoding="utf-8") as f:
                    data = json.load(f)
                records = data if isinstance(data, list) else [data]
                for record in records:
                    all_records.append(json.dumps(record))
        except Exception as e:
            print(f"Error reading {lf}: {e}")
    print(f"Loaded {len(all_records):,} records")

    # Tokenize in parallel
    print(f"Tokenizing with {num_workers} workers...")
    tokenize_fn = partial(tokenize_record, tokenizer_path=str(tokenizer_path))
    all_levels = []
    total_tokens = 0
    with Pool(num_workers) as pool:
        results = list(tqdm(
            pool.imap(tokenize_fn, all_records, chunksize=100),
            total=len(all_records),
            desc="Tokenizing",
        ))
    for tokens in results:
        if tokens is not None:
            all_levels.append(tokens)
            total_tokens += len(tokens)
    print(f"Tokenized {len(all_levels):,} levels with {total_tokens:,} total tokens")

    # Shuffle and split
    np.random.seed(42)
    indices = np.random.permutation(len(all_levels))
    val_count = max(1, int(len(all_levels) * val_ratio))
    val_indices = indices[:val_count]
    train_indices = indices[val_count:]
    print(f"Train: {len(train_indices):,} levels, Val: {len(val_indices):,} levels")

    # Create output directory
    output_dir.mkdir(parents=True, exist_ok=True)

    # Write train shards
    train_tokens = np.concatenate([all_levels[i] for i in tqdm(train_indices, desc="Concatenating train")])
    num_train_shards = max(1, len(train_tokens) // shard_size)
    print(f"Writing {num_train_shards} train shards...")
    shard_size_actual = len(train_tokens) // num_train_shards
    for i in tqdm(range(num_train_shards), desc="Writing train shards"):
        start = i * shard_size_actual
        end = start + shard_size_actual if i < num_train_shards - 1 else len(train_tokens)
        shard = train_tokens[start:end]
        output_path = output_dir / f"train_{i:04d}.bin"
        write_bin_file(shard, output_path)

    # Write val shards
    val_tokens = np.concatenate([all_levels[i] for i in tqdm(val_indices, desc="Concatenating val")])
    num_val_shards = max(1, len(val_tokens) // shard_size)
    print(f"Writing {num_val_shards} val shards...")
    shard_size_actual = len(val_tokens) // num_val_shards
    for i in range(num_val_shards):
        start = i * shard_size_actual
        end = start + shard_size_actual if i < num_val_shards - 1 else len(val_tokens)
        shard = val_tokens[start:end]
        output_path = output_dir / f"val_{i:04d}.bin"
        write_bin_file(shard, output_path)

    # Summary
    print("\n=== Summary ===")
    print(f"Train tokens: {len(train_tokens):,}")
    print(f"Val tokens: {len(val_tokens):,}")
    print(f"Total tokens: {len(train_tokens) + len(val_tokens):,}")
    print(f"Output dir: {output_dir}")
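
# A minimal sketch for sanity-checking the shards produced above, assuming only the
# layout that write_bin_file emits (256 * int32 header, then uint16 token IDs). The
# name read_bin_file is illustrative; it is not part of modded-nanogpt and is not
# called anywhere in this pipeline.
def read_bin_file(path: Path) -> np.ndarray:
    """Read a .bin shard back, verify its header, and return the token array."""
    with Path(path).open("rb") as f:
        header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
        assert header[0] == 20240520, "unexpected magic number"
        assert header[1] == 1, "unexpected version"
        tokens = np.frombuffer(f.read(), dtype=np.uint16)
    assert len(tokens) == header[2], "token count does not match header"
    return tokens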

def main():
    parser = argparse.ArgumentParser(description="Prepare GD levels for modded-nanogpt")

    # Input options (either local or HF)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--input", "-i", type=Path,
                             help="Input directory/file with level JSONs/JSONL")
    input_group.add_argument("--hf-repo", type=str,
                             help="HuggingFace repo ID (e.g., tldne/gd-levels)")
    parser.add_argument("--hf-file", type=str, default="levels_deduped.jsonl",
                        help="File to download from HF repo")
    parser.add_argument("--output", "-o", type=Path, required=True,
                        help="Output directory for .bin files")
    parser.add_argument("--tokenizer", "-t", type=Path, required=True,
                        help="Path to tokenizer.model")
    parser.add_argument("--shard-size", type=int, default=100_000_000,
                        help="Tokens per shard")
    parser.add_argument("--val-ratio", type=float, default=0.01,
                        help="Validation split ratio")
    parser.add_argument("--workers", "-w", type=int, default=None,
                        help="Number of workers (default: cpu_count - 1)")
    args = parser.parse_args()

    # Handle HuggingFace download: stage the file in a temp dir so the rest of the
    # pipeline can treat it like a local input directory
    if args.hf_repo:
        print(f"Downloading {args.hf_file} from {args.hf_repo}...")
        downloaded_path = hf_hub_download(
            repo_id=args.hf_repo,
            filename=args.hf_file,
            repo_type="dataset",
        )
        print(f"Downloaded to: {downloaded_path}")
        temp_dir = Path(tempfile.mkdtemp())
        shutil.copy(downloaded_path, temp_dir / args.hf_file)
        input_dir = temp_dir
    else:
        input_dir = args.input
        temp_dir = None

    process_levels_parallel(
        input_dir=input_dir,
        output_dir=args.output,
        tokenizer_path=args.tokenizer,
        shard_size=args.shard_size,
        val_ratio=args.val_ratio,
        num_workers=args.workers,
    )

    # Clean up the temp dir if one was created
    if temp_dir:
        shutil.rmtree(temp_dir)


if __name__ == "__main__":
    main()
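
# Example output layout after a run (illustrative; shard counts depend on corpus size
# and --shard-size):
#   data/gd_levels/
#       train_0000.bin
#       train_0001.bin
#       ...
#       val_0000.bin
# Each shard can be spot-checked with the illustrative read_bin_file helper above, e.g.
#   python -c "from prepare_gd_data import read_bin_file; print(len(read_bin_file('data/gd_levels/train_0000.bin')))"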