| """ |
| Helion-OSC Training Script |
| Fine-tuning and training utilities for Helion-OSC model |
| """ |

import logging
from typing import Optional, Dict, List
from dataclasses import dataclass, field

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback
)
from datasets import load_dataset, load_from_disk, Dataset, DatasetDict
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class ModelArguments:
    """Arguments for model configuration"""
    model_name_or_path: str = field(
        default="DeepXR/Helion-OSC",
        metadata={"help": "Path to pretrained model or model identifier"}
    )
    use_lora: bool = field(
        default=True,
        metadata={"help": "Whether to use LoRA for efficient fine-tuning"}
    )
    lora_r: int = field(
        default=16,
        metadata={"help": "LoRA attention dimension"}
    )
    lora_alpha: int = field(
        default=32,
        metadata={"help": "LoRA alpha parameter"}
    )
    lora_dropout: float = field(
        default=0.05,
        metadata={"help": "LoRA dropout probability"}
    )
    load_in_8bit: bool = field(
        default=False,
        metadata={"help": "Load model in 8-bit precision"}
    )
    load_in_4bit: bool = field(
        default=False,
        metadata={"help": "Load model in 4-bit precision"}
    )

@dataclass
class DataArguments:
    """Arguments for data processing"""
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "Name of a dataset on the Hugging Face Hub"}
    )
    dataset_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a local dataset saved with save_to_disk"}
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a training data file (JSON/JSONL)"}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a validation data file (JSON/JSONL)"}
    )
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length"}
    )
    preprocessing_num_workers: int = field(
        default=4,
        metadata={"help": "Number of workers for preprocessing"}
    )

class HelionOSCTrainer:
    """Trainer class for Helion-OSC model"""

    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DataArguments,
        training_args: TrainingArguments
    ):
        self.model_args = model_args
        self.data_args = data_args
        self.training_args = training_args

        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()
        self.datasets = self._load_datasets()

        logger.info("Trainer initialized successfully")

    def _load_tokenizer(self):
        """Load and configure tokenizer"""
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_args.model_name_or_path,
            trust_remote_code=True,
            padding_side="right"
        )

        # Fall back to EOS as the pad token; the LM data collator will then
        # mask pad (== EOS) positions out of the loss.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return tokenizer

    def _load_model(self):
        """Load and configure model"""
        logger.info("Loading model...")

        model_kwargs = {
            "trust_remote_code": True,
            "low_cpu_mem_usage": True
        }

        # Quantized loading is configured via BitsAndBytesConfig; the bnb_*
        # options are not valid as bare from_pretrained kwargs.
        if self.model_args.load_in_8bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True
            )
        elif self.model_args.load_in_4bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16

        model = AutoModelForCausalLM.from_pretrained(
            self.model_args.model_name_or_path,
            **model_kwargs
        )

        if self.model_args.use_lora:
            logger.info("Applying LoRA configuration...")

            if self.model_args.load_in_8bit or self.model_args.load_in_4bit:
                model = prepare_model_for_kbit_training(model)

            lora_config = LoraConfig(
                r=self.model_args.lora_r,
                lora_alpha=self.model_args.lora_alpha,
                # Attention and MLP projections of Llama-style architectures;
                # adjust for other model families.
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "up_proj",
                    "down_proj"
                ],
                lora_dropout=self.model_args.lora_dropout,
                bias="none",
                task_type=TaskType.CAUSAL_LM
            )

            model = get_peft_model(model, lora_config)
            model.print_trainable_parameters()

        return model

    def _load_datasets(self) -> DatasetDict:
        """Load and preprocess datasets"""
        logger.info("Loading datasets...")

        if self.data_args.dataset_name:
            # Load from the Hugging Face Hub
            datasets = load_dataset(self.data_args.dataset_name)
        elif self.data_args.dataset_path:
            # Load a dataset previously saved with save_to_disk
            datasets = load_from_disk(self.data_args.dataset_path)
        elif self.data_args.train_file:
            # Load from local JSON/JSONL files
            data_files = {"train": self.data_args.train_file}
            if self.data_args.validation_file:
                data_files["validation"] = self.data_args.validation_file
            datasets = load_dataset("json", data_files=data_files)
        else:
            raise ValueError(
                "Must provide dataset_name, dataset_path, or train_file"
            )

        logger.info("Preprocessing datasets...")
        datasets = datasets.map(
            self._preprocess_function,
            batched=True,
            num_proc=self.data_args.preprocessing_num_workers,
            remove_columns=datasets["train"].column_names,
            desc="Preprocessing datasets"
        )

        return datasets

    def _preprocess_function(self, examples):
        """Preprocess examples for training"""
        # Two record shapes are supported:
        #   {"prompt": ..., "completion": ...}  or  {"text": ...}
        if "prompt" in examples and "completion" in examples:
            texts = [
                f"{prompt}\n{completion}"
                for prompt, completion in zip(examples["prompt"], examples["completion"])
            ]
        elif "text" in examples:
            texts = examples["text"]
        else:
            raise ValueError("Dataset must contain 'text' or 'prompt'/'completion' columns")

        tokenized = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.data_args.max_seq_length,
            padding="max_length",
            return_tensors=None
        )

        # For causal LM training, labels mirror input_ids; the collator
        # (mlm=False) re-derives them and masks pad positions with -100.
        tokenized["labels"] = tokenized["input_ids"].copy()

        return tokenized

    def train(self):
        """Train the model"""
        logger.info("Starting training...")

        # mlm=False selects the causal (next-token prediction) objective
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        # Early stopping requires an eval set and load_best_model_at_end=True
        callbacks = []
        if (
            self.datasets.get("validation") is not None
            and self.training_args.load_best_model_at_end
        ):
            callbacks.append(EarlyStoppingCallback(early_stopping_patience=3))

        trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.datasets["train"],
            eval_dataset=self.datasets.get("validation"),
            tokenizer=self.tokenizer,
            data_collator=data_collator,
            callbacks=callbacks
        )

        train_result = trainer.train()

        trainer.save_model()

        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

        logger.info("Training completed successfully!")

        return trainer, metrics

    def evaluate(self, trainer: Optional[Trainer] = None):
        """Evaluate the model"""
        if self.datasets.get("validation") is None:
            raise ValueError("No validation split available for evaluation")

        if trainer is None:
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer,
                mlm=False
            )

            trainer = Trainer(
                model=self.model,
                args=self.training_args,
                eval_dataset=self.datasets["validation"],
                tokenizer=self.tokenizer,
                data_collator=data_collator
            )

        logger.info("Evaluating model...")
        metrics = trainer.evaluate()

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

        return metrics

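
# Programmatic usage sketch (objects as constructed in main() below;
# illustrative):
#   helion_trainer = HelionOSCTrainer(model_args, data_args, training_args)
#   trainer, metrics = helion_trainer.train()
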
def create_code_dataset(examples: List[Dict[str, str]]) -> Dataset:
    """
    Create a dataset from code examples

    Args:
        examples: List of dictionaries with 'prompt' and 'completion' keys

    Returns:
        Dataset object
    """
    return Dataset.from_dict({
        "prompt": [ex["prompt"] for ex in examples],
        "completion": [ex["completion"] for ex in examples]
    })

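# Example (illustrative data):
#   code_ds = create_code_dataset([
#       {"prompt": "Write a function that reverses a string.",
#        "completion": "def reverse(s):\n    return s[::-1]"},
#   ])
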
def create_math_dataset(examples: List[Dict[str, str]]) -> Dataset:
    """
    Create a dataset from math examples

    Args:
        examples: List of dictionaries with 'problem' and 'solution' keys

    Returns:
        Dataset object
    """
    return Dataset.from_dict({
        "prompt": [f"Problem: {ex['problem']}\nSolution:" for ex in examples],
        "completion": [ex["solution"] for ex in examples]
    })

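# Example (illustrative data):
#   math_ds = create_math_dataset([
#       {"problem": "What is 7 * 8?", "solution": "7 * 8 = 56"},
#   ])
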
def main():
    """Main training script"""
    import argparse

    parser = argparse.ArgumentParser(description="Train Helion-OSC model")

    # Model arguments
    parser.add_argument("--model_name_or_path", type=str, default="DeepXR/Helion-OSC")
    parser.add_argument("--use_lora", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--load_in_8bit", action="store_true")
    parser.add_argument("--load_in_4bit", action="store_true")

    # Data arguments
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--dataset_path", type=str, default=None)
    parser.add_argument("--train_file", type=str, default=None)
    parser.add_argument("--validation_file", type=str, default=None)
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--preprocessing_num_workers", type=int, default=4)

    # Training arguments
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--eval_steps", type=int, default=500)
    parser.add_argument("--save_total_limit", type=int, default=3)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--use_wandb", action="store_true")

    args = parser.parse_args()

    if not (args.dataset_name or args.dataset_path or args.train_file):
        parser.error("Provide one of --dataset_name, --dataset_path, or --train_file")

    model_args = ModelArguments(
        model_name_or_path=args.model_name_or_path,
        use_lora=args.use_lora,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit
    )

    data_args = DataArguments(
        dataset_name=args.dataset_name,
        dataset_path=args.dataset_path,
        train_file=args.train_file,
        validation_file=args.validation_file,
        max_seq_length=args.max_seq_length,
        preprocessing_num_workers=args.preprocessing_num_workers
    )

    # Evaluation-dependent options are only enabled when a validation file is
    # supplied; otherwise the Trainer would fail at eval time.
    has_validation = args.validation_file is not None

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_total_limit=args.save_total_limit,
        fp16=args.fp16,
        bf16=args.bf16,
        gradient_checkpointing=args.gradient_checkpointing,
        report_to="wandb" if args.use_wandb else "none",
        load_best_model_at_end=has_validation,
        metric_for_best_model="eval_loss" if has_validation else None,
        greater_is_better=False,
        evaluation_strategy="steps" if has_validation else "no",
        save_strategy="steps",
        logging_dir=f"{args.output_dir}/logs",
        remove_unused_columns=False
    )

    helion_trainer = HelionOSCTrainer(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args
    )

    trainer, metrics = helion_trainer.train()

    if args.validation_file:
        eval_metrics = helion_trainer.evaluate(trainer)
        logger.info(f"Evaluation metrics: {eval_metrics}")

    logger.info("Training pipeline completed!")


if __name__ == "__main__":
    main()