# Uploaded by AbstractPhil — commit "Create trainer.py" (65782fe, verified).
#################################################################################
## penta-classifier-prototype
#################################################################################
## Author: AbstractPhil
## Assistant: Claude Opus 4.1
#################################################################################
## License Apache - cite with care and share with passionate individuals.
##
## This tiny model somehow defeated all my larger variants.
## The first model showing direct evidence of potential pentachora scaling.
## No pretraining, pure noise. Nothing bulky or extra, just run it.
##
## Somehow, this model contains 60+ classifiers in 3 pentachora.
## I'm still uncertain as to why, as it defeated the projections.
## I need additional research, additional time. But here's the model.
##
## It is based on one of my earlier prototypes and is labeled accordingly.
## Somewhere along the way it fell apart; today I put it back together.
##
#################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from huggingface_hub import HfApi, create_repo, upload_folder
from safetensors.torch import save_file, load_file
import os
import json
import hashlib
from datetime import datetime
from google.colab import userdata
# ============== SETUP HF AND PATHS ==============
# Hub credentials come from the Colab secret store; REPO_ID receives weights and logs.
HF_TOKEN = userdata.get('HF_TOKEN')
REPO_ID = "AbstractPhil/penta-classifier-prototype"

# Local output directories (no-op when they already exist).
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("tensorboard_logs", exist_ok=True)

# Run identity: first 8 hex chars of the MD5 of a config tag plus launch time,
# so each launch gets its own checkpoint / log namespace.
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
config_str = "emnist_byclass_b1024_lr1e-3_" + run_timestamp
run_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]

# One TensorBoard event directory per run.
writer = SummaryWriter(f'tensorboard_logs/{run_hash}')

# Hub client and best-effort repo creation: failures are reported, not raised.
api = HfApi()
try:
    create_repo(REPO_ID, repo_type="model", token=HF_TOKEN, exist_ok=True)
except Exception as e:
    print(f"Repo setup: {e}")
else:
    print(f"Using HuggingFace repo: {REPO_ID}")
# ============== CONFIGURATION ==============
# Use CUDA when a GPU is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    # Report the selected GPU and memory already allocated on it.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    # Enable cuDNN together with its autotuner.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
# Hyperparameters (insertion order is preserved in the JSON written below).
config = dict(
    input_dim=28 * 28,
    base_dim=64,
    batch_size=1024,
    epochs=5,
    initial_lr=1e-3,
    temp_contrastive=0.1,
    lambda_contrastive=0.5,
    lambda_cayley=0.01,
    dataset="EMNIST_byclass",
    run_hash=run_hash,
    timestamp=run_timestamp,
)

# Serialise once; the same text goes to disk and to TensorBoard.
_config_text = json.dumps(config, indent=2)
config_path = f"checkpoints/config_{run_hash}.json"
with open(config_path, 'w') as f:
    f.write(_config_text)
writer.add_text('Config', _config_text, 0)
# ============== DATASET ==============
# ToTensor yields a float image tensor in [0, 1]; Flatten(start_dim=0) turns it
# into a 1-D feature vector for the MLP encoder. An nn.Module is used instead of
# a lambda so the transform can be pickled when DataLoader workers are started
# with "spawn" (the default on macOS and Windows).
transform = transforms.Compose([
    transforms.ToTensor(),
    nn.Flatten(start_dim=0),
])
train_dataset = datasets.EMNIST(root="./data", split='byclass', train=True, transform=transform, download=True)
test_dataset = datasets.EMNIST(root="./data", split='byclass', train=False, transform=transform, download=True)

num_classes = len(train_dataset.classes)
config["num_classes"] = num_classes
# The config JSON was first written before the dataset was loaded; rewrite it so
# the on-disk copy also records num_classes.
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)

# persistent_workers keeps worker processes alive between epochs instead of
# re-spawning them for every pass (the train loader still reshuffles each epoch).
loader_kwargs = dict(
    batch_size=config["batch_size"],
    pin_memory=True,
    num_workers=4,
    prefetch_factor=8,
    persistent_workers=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
print(f"Train: {len(train_dataset)} samples, Test: {len(test_dataset)} samples")
print(f"Classes: {num_classes}")
# ============== MODEL DEFINITIONS ==============
class AdaptiveEncoder(nn.Module):
    """MLP encoder that emits LayerNorm'd embeddings at three granularities.

    Trunk: ``input_dim -> 512 -> 256 -> 128``, each stage Linear + BatchNorm + ReLU,
    with dropout (p=0.2) after the first two stages.

    Heads:
      * coarse: ``base_dim // 4`` features, read from the 256-wide stage;
      * medium: ``base_dim // 2`` features, read from the 128-wide stage;
      * fine:   ``base_dim`` features, read from the 128-wide stage.
    """

    def __init__(self, input_dim, base_dim=128):
        super().__init__()
        coarse_dim = base_dim // 4
        medium_dim = base_dim // 2
        # Creation order is significant: it fixes the state_dict key order and
        # the order in which parameter initialisation draws from the RNG.
        self.fc1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc_coarse = nn.Linear(256, coarse_dim)
        self.fc_medium = nn.Linear(128, medium_dim)
        self.fc_fine = nn.Linear(128, base_dim)
        self.norm_coarse = nn.LayerNorm(coarse_dim)
        self.norm_medium = nn.LayerNorm(medium_dim)
        self.norm_fine = nn.LayerNorm(base_dim)

    def forward(self, x):
        """Return ``(coarse, medium, fine)`` embeddings for the batch ``x``."""
        stage1 = self.dropout1(torch.relu(self.bn1(self.fc1(x))))
        stage2 = self.dropout2(torch.relu(self.bn2(self.fc2(stage1))))
        stage3 = torch.relu(self.bn3(self.fc3(stage2)))
        return (
            self.norm_coarse(self.fc_coarse(stage2)),
            self.norm_medium(self.fc_medium(stage3)),
            self.norm_fine(self.fc_fine(stage3)),
        )
def init_perfect_pentachora(num_classes, latent_dim, device='cuda'):
    """Create one regular 4-simplex (pentachoron) per class, scaled by 2.

    The template is the regular 4-simplex of unit circumradius centred at the
    origin: every vertex has norm 1 and every pair of vertices has dot product
    -1/4, so all ten edges have length sqrt(5/2) (sqrt(10) after the x2 scale).

    Placement per class:
      * if ``k = latent_dim // num_classes >= 4``, class ``c`` gets the template
        in coordinates ``[c*k, c*k + 4)`` (mutually orthogonal subspaces);
      * otherwise, when ``latent_dim >= 4``, the template is mapped through a
        random frame with orthonormal rows (via QR), which preserves every edge
        length, so the simplex stays regular;
      * when ``latent_dim < 4`` no orthonormal 4-frame exists, so a random
        row-normalised projection is used and the shape is only approximate.

    Args:
        num_classes: number of pentachora to create.
        latent_dim: ambient dimension of each vertex.
        device: device for the returned tensor.

    Returns:
        ``nn.Parameter`` of shape ``(num_classes, 5, latent_dim)``.
    """
    pentachora = torch.zeros(num_classes, 5, latent_dim, device=device)
    sqrt15 = np.sqrt(15)
    sqrt30 = np.sqrt(30)
    sqrt10 = np.sqrt(10)
    # Exact unit-circumradius coordinates. The previous rows 2-4 (sqrt10/3,
    # -sqrt10/6, +-sqrt5/2) were not unit length, and normalising them did not
    # restore equal edge lengths.
    simplex = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],
        [-0.25, sqrt15 / 4, 0.0, 0.0],
        [-0.25, -sqrt15 / 12, sqrt30 / 6, 0.0],
        [-0.25, -sqrt15 / 12, -sqrt30 / 12, sqrt10 / 4],
        [-0.25, -sqrt15 / 12, -sqrt30 / 12, -sqrt10 / 4],
    ], dtype=torch.float32, device=device)
    # Rows are already unit norm; normalising only absorbs float rounding.
    simplex = F.normalize(simplex, dim=1)

    dims_per_class = latent_dim // num_classes if num_classes > 0 else 0
    for c in range(num_classes):
        if dims_per_class >= 4:
            start = c * dims_per_class
            pentachora[c, :, start:start + 4] = simplex
        elif latent_dim >= 4:
            # Q has orthonormal columns, so Q.T embeds R^4 isometrically into
            # R^latent_dim and the simplex keeps all its edge lengths.
            q, _ = torch.linalg.qr(torch.randn(latent_dim, 4, device=device))
            pentachora[c] = simplex @ q.T
        else:
            rotation = F.normalize(torch.randn(4, latent_dim, device=device), dim=1)
            pentachora[c] = simplex @ rotation
    return nn.Parameter(pentachora * 2.0)
class PerfectPentachoron(nn.Module):
    """Pentachoron-distance classifier over three embedding scales.

    Each class owns one 5-vertex pentachoron per scale (coarse / medium / fine,
    with ``base_dim // 4``, ``base_dim // 2`` and ``base_dim`` features). A
    sample's score for a class is the negated, softmax-vertex-weighted sum of its
    distances to that class's vertices, measured after a learnable linear map per
    scale; the three per-scale scores are blended with softmax(``scale_weights``).
    """

    def __init__(self, num_classes, base_dim, device='cuda'):
        super().__init__()
        # Kept for backward compatibility; computations use the parameters' own
        # device so the module keeps working after .to()/.cpu().
        self.device = device
        self.num_classes = num_classes
        self.base_dim = base_dim
        self.penta_coarse = init_perfect_pentachora(num_classes, base_dim // 4, device)
        self.penta_medium = init_perfect_pentachora(num_classes, base_dim // 2, device)
        self.penta_fine = init_perfect_pentachora(num_classes, base_dim, device)
        # Per-class vertex logits (softmaxed over the 5 vertices in forward).
        self.vertex_weights = nn.Parameter(torch.ones(num_classes, 5, device=device) / 5)
        # Learnable linear maps, initialised to identity (plain Euclidean distance).
        self.metric_coarse = nn.Parameter(torch.eye(base_dim // 4, device=device))
        self.metric_medium = nn.Parameter(torch.eye(base_dim // 2, device=device))
        self.metric_fine = nn.Parameter(torch.eye(base_dim, device=device))
        # Scale-mixing logits (softmaxed in forward).
        self.scale_weights = nn.Parameter(torch.tensor([0.2, 0.3, 0.5], device=device))

    def mahalanobis_distance(self, x, pentachora, metric):
        """Distances from each sample to each vertex after mapping both by ``metric``.

        ``||x M - p M||`` is the Mahalanobis-form distance with matrix ``M M^T``.

        Args:
            x: ``(B, D)`` embeddings.
            pentachora: ``(C, V, D)`` vertices.
            metric: ``(D, D)`` matrix ``M``.

        Returns:
            ``(B, C, V)`` distances.
        """
        x_trans = torch.matmul(x, metric)
        p_trans = torch.einsum('cpd,de->cpe', pentachora, metric)
        diffs = p_trans.unsqueeze(0) - x_trans.unsqueeze(1).unsqueeze(2)
        return torch.norm(diffs, dim=-1)

    def forward(self, x_coarse, x_medium, x_fine):
        """Score every class for every sample.

        Returns:
            ``(scores, (dists_c, dists_m, dists_f))`` where ``scores`` is
            ``(B, C)`` and each ``dists_*`` is the vertex-weighted ``(B, C, 5)``
            distance tensor for that scale.
        """
        dists_c = self.mahalanobis_distance(x_coarse, self.penta_coarse, self.metric_coarse)
        dists_m = self.mahalanobis_distance(x_medium, self.penta_medium, self.metric_medium)
        dists_f = self.mahalanobis_distance(x_fine, self.penta_fine, self.metric_fine)
        weights = F.softmax(self.vertex_weights, dim=1).unsqueeze(0)  # (1, C, 5)
        dists_c = dists_c * weights
        dists_m = dists_m * weights
        dists_f = dists_f * weights
        scores_c = -dists_c.sum(dim=-1)
        scores_m = -dists_m.sum(dim=-1)
        scores_f = -dists_f.sum(dim=-1)
        w = F.softmax(self.scale_weights, dim=0)
        scores = w[0] * scores_c + w[1] * scores_m + w[2] * scores_f
        return scores, (dists_c, dists_m, dists_f)

    @staticmethod
    def _edge_lengths(pentachora):
        """Return the 10 upper-triangle edge lengths per pentachoron: ``(C, 5, D) -> (C, 10)``."""
        dists = torch.norm(pentachora.unsqueeze(2) - pentachora.unsqueeze(1), dim=-1)
        # Mask on the tensor's own device (not the constructor-time self.device).
        mask = torch.triu(torch.ones(5, 5, device=pentachora.device), diagonal=1).bool()
        return dists[:, mask]

    def regularization_loss(self):
        """Penalise irregular or collapsed pentachora.

        Returns ``edge_var + collapse_penalty``: the mean (over scales and
        classes) variance of each pentachoron's 10 edge lengths, plus the mean
        hinge ``relu(0.5 - shortest_edge)``.
        """
        all_edges = torch.stack([
            self._edge_lengths(self.penta_coarse),
            self._edge_lengths(self.penta_medium),
            self._edge_lengths(self.penta_fine),
        ], dim=0)  # (3, C, 10)
        edge_var = torch.var(all_edges, dim=2).mean()
        min_edges = torch.min(all_edges, dim=2)[0]
        collapse_penalty = torch.relu(0.5 - min_edges).mean()
        return edge_var + collapse_penalty
def contrastive_pentachoron_loss_batched(latents, targets, pentachora, temp=0.1):
    """Softmax contrastive loss pulling each latent toward its class's pentachoron.

    The similarity of a sample to class ``c`` is
    ``-min_v ||latent - pentachora[c, v]|| / temp`` (distance to the nearest
    vertex). The loss is the mean cross-entropy of those similarities against
    ``targets``.

    Args:
        latents: ``(B, D)`` embeddings.
        targets: ``(B,)`` integer labels in ``[0, C)``.
        pentachora: ``(C, V, D)`` class vertices.
        temp: softmax temperature; smaller values sharpen the distribution.

    Returns:
        Scalar loss tensor.
    """
    diffs = latents.unsqueeze(1).unsqueeze(2) - pentachora.unsqueeze(0)  # (B, C, V, D)
    min_dists = torch.min(torch.norm(diffs, dim=-1), dim=2).values  # (B, C)
    sims = -min_dists / temp
    # cross_entropy uses a fused, stable log-softmax. The previous hand-rolled
    # exp/sum/log underflowed to log(0) = -inf (loss = inf) whenever the target
    # class trailed the best class by more than ~100 logits, which is easy at temp=0.1.
    return F.cross_entropy(sims, targets)
# ============== TRAINING SETUP ==============
encoder = AdaptiveEncoder(config["input_dim"], config["base_dim"]).to(device)
classifier = PerfectPentachoron(num_classes, config["base_dim"], device).to(device)

# Try to compile if available. torch.compile is lazy: this only wraps the modules,
# and real compilation (and any compiler error) happens on the first forward call.
# Both wrappers are built before either is adopted, so a failure never leaves the
# models half-compiled. `except Exception` (not a bare except) lets
# KeyboardInterrupt/SystemExit through.
try:
    compiled_encoder = torch.compile(encoder)
    compiled_classifier = torch.compile(classifier)
except Exception:
    print("Torch compile not available, using eager mode")
else:
    encoder, classifier = compiled_encoder, compiled_classifier
    print("Models compiled successfully")

# The classifier's geometry trains at half the encoder's learning rate.
optimizer = torch.optim.AdamW([
    {'params': encoder.parameters(), 'lr': config["initial_lr"]},
    {'params': classifier.parameters(), 'lr': config["initial_lr"] * 0.5}
], weight_decay=1e-5)
# Stepped once per epoch in the main loop.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["epochs"])
# ============== CHECKPOINT FUNCTIONS ==============
def save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics, is_best=False):
    """Save weights (safetensors) plus training state (torch.save), then sync to the Hub.

    Files written under ``checkpoints/``:
      * ``checkpoint_{run_hash}_epoch_{epoch:03d}.safetensors``, or
        ``best_{run_hash}.safetensors`` when ``is_best`` (overwritten per new best);
      * the same name with ``_state.pt`` in place of ``.safetensors``, holding the
        epoch, optimizer/scheduler state dicts, ``metrics`` and the global ``config``.

    Weight keys are ``encoder.<name>`` / ``classifier.<name>``. Modules wrapped by
    ``torch.compile`` are unwrapped first, so keys carry no ``_orig_mod.`` prefix
    and load directly into plain, uncompiled modules.

    Hub upload errors are printed and swallowed so a network failure never aborts
    training.

    Args:
        epoch: zero-based epoch index.
        encoder, classifier: modules whose weights are saved (compiled or not).
        optimizer, scheduler: objects whose ``state_dict()`` is saved.
        metrics: dict; must contain ``'test_acc'`` (used in the commit message).
        is_best: write to the rolling ``best_*`` file instead of a per-epoch file.
    """
    def _prefixed_weights(prefix, module):
        # torch.compile returns an OptimizedModule whose state_dict keys start with
        # "_orig_mod."; save the wrapped module so names match the plain class.
        base = getattr(module, "_orig_mod", module)
        return {f"{prefix}.{k}": v.detach().cpu().contiguous() for k, v in base.state_dict().items()}

    model_state = {
        **_prefixed_weights("encoder", encoder),
        **_prefixed_weights("classifier", classifier),
    }
    if is_best:
        checkpoint_name = f"best_{run_hash}.safetensors"
    else:
        checkpoint_name = f"checkpoint_{run_hash}_epoch_{epoch:03d}.safetensors"
    checkpoint_path = os.path.join("checkpoints", checkpoint_name)
    save_file(model_state, checkpoint_path)

    # Optimizer/scheduler/metrics are not plain tensors, so they go in a .pt file.
    training_state = {
        'epoch': epoch,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'metrics': metrics,
        'config': config
    }
    state_path = checkpoint_path.replace('.safetensors', '_state.pt')
    torch.save(training_state, state_path)
    print(f"Saved checkpoint: {checkpoint_name}")

    try:
        # checkpoints/ is shared by every run; upload only this run's files so
        # earlier runs' artifacts are not re-uploaded under weights/{run_hash}.
        upload_folder(
            folder_path="checkpoints",
            repo_id=REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
            path_in_repo=f"weights/{run_hash}",
            allow_patterns=[f"*{run_hash}*"],
            commit_message=f"Epoch {epoch} - Test Acc: {metrics['test_acc']:.4f}"
        )
        upload_folder(
            folder_path=f"tensorboard_logs/{run_hash}",
            repo_id=REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
            path_in_repo=f"runs/{run_hash}",
            commit_message=f"TensorBoard logs - Epoch {epoch}"
        )
    except Exception as e:
        print(f"HF upload error: {e}")
# ============== TRAINING FUNCTIONS ==============
def train_epoch(epoch):
    """Run one training epoch over ``train_loader``.

    Loss per batch = CE(scores) + lambda_contrastive * mean of the three per-scale
    contrastive losses + lambda_cayley * ``classifier.regularization_loss()``.
    Gradients are clipped to norm 1.0 separately for encoder and classifier.

    Every 50 batches, TensorBoard receives the batch loss (``Train/BatchLoss``),
    the accuracy of that batch (``Train/BatchAcc``) and the running epoch
    accuracy (``Train/RunningAcc``).

    Args:
        epoch: zero-based epoch index (display and TensorBoard step offset).

    Returns:
        ``(loss, ce, contrastive, reg, accuracy)``, each averaged per sample.
    """
    encoder.train()
    classifier.train()
    total_loss = 0.0
    total_ce = 0.0
    total_contr = 0.0
    total_reg = 0.0
    correct = 0
    total = 0
    temp = config["temp_contrastive"]
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        x_coarse, x_medium, x_fine = encoder(inputs)
        scores, _ = classifier(x_coarse, x_medium, x_fine)
        ce_loss = F.cross_entropy(scores, targets)
        contr_loss = (
            contrastive_pentachoron_loss_batched(x_coarse, targets, classifier.penta_coarse, temp)
            + contrastive_pentachoron_loss_batched(x_medium, targets, classifier.penta_medium, temp)
            + contrastive_pentachoron_loss_batched(x_fine, targets, classifier.penta_fine, temp)
        ) / 3
        reg_loss = classifier.regularization_loss()
        loss = (ce_loss
                + config["lambda_contrastive"] * contr_loss
                + config["lambda_cayley"] * reg_loss)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0)
        optimizer.step()

        # Read each scalar back once per batch (every .item() is a device sync).
        batch_size = inputs.size(0)
        loss_val = loss.item()
        batch_correct = (scores.argmax(dim=1) == targets).sum().item()
        total_loss += loss_val * batch_size
        total_ce += ce_loss.item() * batch_size
        total_contr += contr_loss.item() * batch_size
        total_reg += reg_loss.item() * batch_size
        correct += batch_correct
        total += batch_size

        if batch_idx % 50 == 0:
            global_step = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Train/BatchLoss', loss_val, global_step)
            # BatchAcc previously logged the running epoch accuracy; it now logs
            # this batch's accuracy, and the running value has its own tag.
            writer.add_scalar('Train/BatchAcc', batch_correct / batch_size, global_step)
            writer.add_scalar('Train/RunningAcc', correct / total, global_step)

        pbar.set_postfix({
            'loss': f"{loss_val:.4f}",
            'acc': f"{correct/total:.4f}",
            'lr': f"{optimizer.param_groups[0]['lr']:.1e}"
        })
    return (total_loss/total, total_ce/total, total_contr/total,
            total_reg/total, correct/total)
@torch.no_grad()
def evaluate():
    """Compute overall and per-class accuracy on ``test_loader``.

    Returns:
        ``(accuracy, class_accs)`` where ``class_accs[i]`` is the accuracy on
        class ``i`` (0.0 for classes absent from the test set). ``accuracy`` is
        0.0 if the loader yields no samples.
    """
    encoder.eval()
    classifier.eval()
    correct = 0
    total = 0
    # Per-class tallies stay on the device and are updated with bincount; the old
    # per-sample Python loop forced several GPU->CPU syncs for every test example.
    class_correct = torch.zeros(num_classes, dtype=torch.long, device=device)
    class_total = torch.zeros(num_classes, dtype=torch.long, device=device)
    pbar = tqdm(test_loader, desc="Evaluating")
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        x_coarse, x_medium, x_fine = encoder(inputs)
        scores, _ = classifier(x_coarse, x_medium, x_fine)
        hits = scores.argmax(dim=1) == targets
        class_total += torch.bincount(targets, minlength=num_classes)
        class_correct += torch.bincount(targets[hits], minlength=num_classes)
        correct += hits.sum().item()
        total += inputs.size(0)
        pbar.set_postfix({'acc': f"{correct/total:.4f}"})
    class_accs = [c / max(1, t) for c, t in zip(class_correct.tolist(), class_total.tolist())]
    return (correct / total if total else 0.0), class_accs
# ============== MAIN TRAINING LOOP ==============
print("\n" + "=" * 60)
print(f"PERFECT PENTACHORON TRAINING - Run {run_hash}")
print("=" * 60 + "\n")

best_acc = 0.0
train_history, test_history = [], []
patience, no_improve = 7, 0

for epoch in range(config["epochs"]):
    # One pass over the training set, then a full test pass.
    train_loss, train_ce, train_contr, train_reg, train_acc = train_epoch(epoch)
    test_acc, class_accs = evaluate()
    train_history.append(train_acc)
    test_history.append(test_acc)

    # Epoch-level scalars; the LR is read before the scheduler steps below.
    epoch_scalars = {
        'Loss/Total': train_loss,
        'Loss/CE': train_ce,
        'Loss/Contrastive': train_contr,
        'Loss/Regularization': train_reg,
        'Accuracy/Train': train_acc,
        'Accuracy/Test': test_acc,
        'Learning/LR': optimizer.param_groups[0]['lr'],
        'Learning/Generalization_Gap': train_acc - test_acc,
    }
    for tag, value in epoch_scalars.items():
        writer.add_scalar(tag, value, epoch)
    # Only the first 10 classes get a per-class curve.
    for i, acc in enumerate(class_accs[:10]):
        writer.add_scalar(f'ClassAcc/Class_{i}', acc, epoch)
    # Current softmax mix of the coarse / medium / fine scales.
    scale_weights = F.softmax(classifier.scale_weights, dim=0)
    for tag, weight in zip(('Scales/Coarse', 'Scales/Medium', 'Scales/Fine'), scale_weights):
        writer.add_scalar(tag, weight, epoch)

    scheduler.step()

    # "Best" here is the best accuracy before this epoch.
    print(f"\n[Epoch {epoch+1}/{config['epochs']}]")
    print(f"Train | Loss: {train_loss:.4f} | CE: {train_ce:.4f} | "
          f"Contr: {train_contr:.4f} | Reg: {train_reg:.4f} | Acc: {train_acc:.4f}")
    print(f"Test | Acc: {test_acc:.4f} | Best: {best_acc:.4f}")

    metrics = {
        'train_acc': train_acc,
        'test_acc': test_acc,
        'train_loss': train_loss,
        'class_accs': class_accs
    }

    if test_acc > best_acc:
        best_acc, no_improve = test_acc, 0
        print("NEW BEST! Saving checkpoint...")
        save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics, is_best=True)
    else:
        no_improve += 1

    # NOTE(review): indentation was lost in the recovered source; this periodic
    # save is assumed to sit at loop level (it also runs on a new-best epoch).
    if (epoch + 1) % 5 == 0:
        save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics)

    if no_improve >= patience:
        print(f"Early stopping triggered (no improvement for {patience} epochs)")
        break
# ============== FINAL RESULTS ==============
rule = "=" * 60
final_train_acc = train_history[-1]
print("\n" + rule)
print("FINAL RESULTS")
print(rule)
print(f"Best Test Accuracy: {best_acc:.4f}")
print(f"Final Train Accuracy: {final_train_acc:.4f}")
print(f"Generalization Gap: {final_train_acc - test_history[-1]:.4f}")

# Save final model (per-epoch file for the last completed epoch).
save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics, is_best=False)

# Summarise the learned geometry: scale mix and, for the first 10 classes,
# the vertex with the largest softmax weight.
with torch.no_grad():
    scale_weights = F.softmax(classifier.scale_weights, dim=0).cpu().numpy()
    vertex_importance = F.softmax(classifier.vertex_weights, dim=1).cpu().numpy()
    geometry_info = {
        'scale_importance': {
            name: float(value)
            for name, value in zip(('coarse', 'medium', 'fine'), scale_weights)
        },
        'dominant_vertices': {
            f'class_{c}': {'vertex': int(np.argmax(row)), 'weight': float(row.max())}
            for c, row in enumerate(vertex_importance[:10])
        },
    }
writer.add_text('Final_Geometry', json.dumps(geometry_info, indent=2), epoch)
writer.close()

print(f"\n✨ Training Complete! Run hash: {run_hash}")
print(f"Results uploaded to: https://huggingface.co/{REPO_ID}")
print(f"TensorBoard: tensorboard --logdir tensorboard_logs/{run_hash}")