| |
| """ |
| Helion-2.5-Rnd Batch Inference |
| Efficient batch processing for large-scale inference tasks |
| """ |
|
|
| import argparse |
| import json |
| import logging |
| import time |
| from pathlib import Path |
| from typing import Dict, List, Optional, Union |
|
|
| import pandas as pd |
| from tqdm import tqdm |
|
|
| from inference.client import HelionClient |
|
|
# Configure root logging once at import time so every logger in the
# process inherits the same timestamped format and INFO threshold.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
|
|
class BatchProcessor:
    """Process large batches of inference requests.

    Wraps a HelionClient with chunked sequential processing, per-request
    retry with a fixed delay, and aggregate timing statistics collected
    into ``self.stats``.
    """

    def __init__(
        self,
        client: HelionClient,
        batch_size: int = 10,
        max_retries: int = 3,
        retry_delay: float = 1.0
    ):
        """
        Initialize batch processor

        Args:
            client: HelionClient instance
            batch_size: Number of prompts grouped per chunk (values < 1
                are treated as 1 when iterating)
            max_retries: Maximum attempts for failed requests (values < 1
                are treated as a single attempt)
            retry_delay: Delay between retries in seconds
        """
        self.client = client
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.retry_delay = retry_delay

        # Aggregate statistics for the most recent process_prompts() run.
        self.stats = {
            'total': 0,
            'successful': 0,
            'failed': 0,
            'total_time': 0.0,
            'avg_time_per_request': 0.0
        }

    def process_prompts(
        self,
        prompts: List[str],
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> List[Dict]:
        """
        Process a list of prompts sequentially, in chunks of ``batch_size``.

        Args:
            prompts: List of input prompts
            temperature: Sampling temperature
            max_tokens: Maximum tokens per response
            **kwargs: Additional generation parameters forwarded to the client

        Returns:
            One result dict per prompt, in input order, each containing
            'prompt', 'response', 'success', 'attempts', and either
            'duration' (on success) or 'error' (on failure).
        """
        results = []
        start_time = time.time()

        logger.info(f"Processing {len(prompts)} prompts...")

        # Guard against a non-positive batch_size, which would make
        # range() raise ValueError (step argument must not be zero).
        step = max(1, self.batch_size)

        for i in tqdm(range(0, len(prompts), step)):
            batch = prompts[i:i + step]

            for prompt in batch:
                result = self._process_single_with_retry(
                    prompt,
                    temperature,
                    max_tokens,
                    **kwargs
                )
                results.append(result)

        self.stats['total'] = len(prompts)
        self.stats['successful'] = sum(1 for r in results if r['success'])
        self.stats['failed'] = len(prompts) - self.stats['successful']
        self.stats['total_time'] = time.time() - start_time
        # Avoid ZeroDivisionError when called with an empty prompt list.
        self.stats['avg_time_per_request'] = (
            self.stats['total_time'] / len(prompts) if prompts else 0.0
        )

        logger.info(f"Batch processing complete. Success rate: {self.stats['successful']}/{self.stats['total']}")

        return results

    def _process_single_with_retry(
        self,
        prompt: str,
        temperature: float,
        max_tokens: int,
        **kwargs
    ) -> Dict:
        """Process a single prompt, retrying on failure.

        Always returns a result dict: at least one attempt is made even
        when ``max_retries`` < 1 (previously the loop body never ran and
        the method fell through returning None, which crashed callers
        that index the result).
        """
        attempts = max(1, self.max_retries)
        last_error = None

        for attempt in range(attempts):
            try:
                start = time.time()
                response = self.client.complete(
                    prompt=prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    **kwargs
                )
                duration = time.time() - start

                return {
                    'prompt': prompt,
                    'response': response,
                    'success': True,
                    'duration': duration,
                    'attempts': attempt + 1
                }

            except Exception as e:
                last_error = e
                logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")

                # Sleep only when another attempt will follow.
                if attempt < attempts - 1:
                    time.sleep(self.retry_delay)

        # All attempts exhausted; report the last error seen.
        return {
            'prompt': prompt,
            'response': None,
            'success': False,
            'error': str(last_error),
            'attempts': attempts
        }

    def process_chat_conversations(
        self,
        conversations: List[List[Dict]],
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> List[Dict]:
        """
        Process chat conversations in batch (no retry; one attempt each).

        Args:
            conversations: List of message lists
            temperature: Sampling temperature
            max_tokens: Maximum tokens per response
            **kwargs: Additional generation parameters forwarded to the client

        Returns:
            One result dict per conversation, in input order, each with
            'conversation', 'response', 'success', and either 'duration'
            (on success) or 'error' (on failure).
        """
        results = []
        start_time = time.time()

        logger.info(f"Processing {len(conversations)} conversations...")

        for conv in tqdm(conversations):
            try:
                start = time.time()
                response = self.client.chat(
                    messages=conv,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    **kwargs
                )
                duration = time.time() - start

                results.append({
                    'conversation': conv,
                    'response': response,
                    'success': True,
                    'duration': duration
                })

            except Exception as e:
                logger.error(f"Conversation processing failed: {str(e)}")
                results.append({
                    'conversation': conv,
                    'response': None,
                    'success': False,
                    'error': str(e)
                })

        total_time = time.time() - start_time
        successful = sum(1 for r in results if r['success'])

        logger.info(f"Processed {successful}/{len(conversations)} conversations in {total_time:.2f}s")

        return results

    def process_file(
        self,
        input_file: str,
        output_file: str,
        prompt_column: str = "prompt",
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> pd.DataFrame:
        """
        Process prompts read from a CSV/JSON file and write annotated results.

        Args:
            input_file: Input CSV/JSON file path (extension is case-insensitive)
            output_file: Output CSV/JSON file path; parent dirs are created
            prompt_column: Column name containing prompts
            temperature: Sampling temperature
            max_tokens: Maximum tokens per response
            **kwargs: Additional generation parameters forwarded to the client

        Returns:
            The input DataFrame with added 'response', 'success',
            'duration', and 'error' columns.

        Raises:
            ValueError: On an unsupported file extension or a missing
                prompt column.
        """
        input_path = Path(input_file)
        # Compare extensions case-insensitively so '.CSV'/'.Json' also work.
        input_suffix = input_path.suffix.lower()

        if input_suffix == '.csv':
            df = pd.read_csv(input_path)
        elif input_suffix == '.json':
            df = pd.read_json(input_path)
        else:
            raise ValueError(f"Unsupported file format: {input_path.suffix}")

        if prompt_column not in df.columns:
            raise ValueError(f"Column '{prompt_column}' not found in input file")

        prompts = df[prompt_column].tolist()
        results = self.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )

        # Results are in prompt order, so columns align row-for-row.
        df['response'] = [r['response'] for r in results]
        df['success'] = [r['success'] for r in results]
        df['duration'] = [r.get('duration', None) for r in results]
        df['error'] = [r.get('error', None) for r in results]

        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        output_suffix = output_path.suffix.lower()
        if output_suffix == '.csv':
            df.to_csv(output_path, index=False)
        elif output_suffix == '.json':
            df.to_json(output_path, orient='records', indent=2)
        else:
            raise ValueError(f"Unsupported output format: {output_path.suffix}")

        logger.info(f"Results saved to {output_path}")

        return df

    def get_statistics(self) -> Dict:
        """Return a copy of the statistics from the last batch run."""
        return self.stats.copy()
|
|
|
|
class DatasetProcessor:
    """Build task-specific prompt templates and run them via BatchProcessor."""

    def __init__(self, client: HelionClient):
        self.client = client
        self.processor = BatchProcessor(client)

    def process_qa_dataset(
        self,
        questions: List[str],
        contexts: Optional[List[str]] = None,
        temperature: float = 0.3,
        max_tokens: int = 512
    ) -> List[Dict]:
        """Run a question-answering set, attaching per-question context when available."""
        ctx = contexts or []
        prompts = [
            f"Context: {ctx[idx]}\n\nQuestion: {question}\n\nAnswer:"
            if idx < len(ctx)
            else f"Question: {question}\n\nAnswer:"
            for idx, question in enumerate(questions)
        ]

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )

    def process_code_dataset(
        self,
        tasks: List[str],
        languages: Optional[List[str]] = None,
        temperature: float = 0.2,
        max_tokens: int = 1024
    ) -> List[Dict]:
        """Run code-generation tasks; the language falls back to python."""
        langs = languages or []
        prompts = []
        for position, task in enumerate(tasks):
            lang = langs[position] if position < len(langs) else "python"
            prompts.append(f"Write a {lang} function to: {task}\n\n```{lang}\n")

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )

    def process_translation_dataset(
        self,
        texts: List[str],
        source_lang: str,
        target_lang: str,
        temperature: float = 0.3,
        max_tokens: int = 1024
    ) -> List[Dict]:
        """Run translation tasks for a fixed source/target language pair."""
        prompts = [
            f"Translate the following text from {source_lang} to {target_lang}:\n\n{text}\n\nTranslation:"
            for text in texts
        ]

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )

    def process_summarization_dataset(
        self,
        documents: List[str],
        max_summary_length: int = 150,
        temperature: float = 0.5,
        max_tokens: int = 512
    ) -> List[Dict]:
        """Run summarization with a word-count cap embedded in each prompt."""
        prompts = [
            f"Summarize the following document in {max_summary_length} words or less:\n\n{doc}\n\nSummary:"
            for doc in documents
        ]

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )
|
|
|
|
def main():
    """CLI entry point: read prompts from a file, run batch inference, log stats."""
    arg_parser = argparse.ArgumentParser(description="Batch inference with Helion")
    arg_parser.add_argument("--base-url", type=str, default="http://localhost:8000")
    arg_parser.add_argument("--input", type=str, required=True, help="Input file (CSV/JSON)")
    arg_parser.add_argument("--output", type=str, required=True, help="Output file (CSV/JSON)")
    arg_parser.add_argument("--prompt-column", type=str, default="prompt")
    arg_parser.add_argument("--temperature", type=float, default=0.7)
    arg_parser.add_argument("--max-tokens", type=int, default=1024)
    arg_parser.add_argument("--batch-size", type=int, default=10)
    opts = arg_parser.parse_args()

    runner = BatchProcessor(
        HelionClient(base_url=opts.base_url),
        batch_size=opts.batch_size
    )

    runner.process_file(
        input_file=opts.input,
        output_file=opts.output,
        prompt_column=opts.prompt_column,
        temperature=opts.temperature,
        max_tokens=opts.max_tokens
    )

    # Summarize the run for the operator.
    summary = runner.get_statistics()
    logger.info("\nProcessing Statistics:")
    logger.info(f"Total requests: {summary['total']}")
    logger.info(f"Successful: {summary['successful']}")
    logger.info(f"Failed: {summary['failed']}")
    logger.info(f"Total time: {summary['total_time']:.2f}s")
    logger.info(f"Avg time per request: {summary['avg_time_per_request']:.2f}s")
|
|
|
|
| if __name__ == "__main__": |
| main() |