"""Model wrapper for LiteLLM"""
import os
import json
from typing import List, Dict, Any, Optional

try:
    import litellm
except ImportError:
    print("⚠️ litellm not installed. Install with: pip install litellm")
    litellm = None


class LiteLLMModel:
    """Wrapper for LiteLLM models.

    Sends chat requests through ``litellm.completion`` and normalizes the
    reply into a plain dict with a ``"content"`` string and, when the model
    requested tool invocations, a ``"tool_calls"`` list.
    """

    def __init__(self, model_id: str):
        """Store the model id and fail fast when a Groq key is missing.

        Args:
            model_id: LiteLLM model identifier. Any id containing "groq"
                (case-insensitive) is treated as a Groq model.

        Raises:
            RuntimeError: If the model is a Groq model and ``GROQ_API_KEY``
                is not set in the environment.
        """
        self.model_id = model_id
        # Check for Groq API key up front so misconfiguration surfaces at
        # construction time rather than on the first request.
        if self._is_groq() and not os.getenv("GROQ_API_KEY"):
            print("⚠️ GROQ_API_KEY not set in environment")
            raise RuntimeError(
                "GROQ_API_KEY not set. Please add it to your Space secrets."
            )

    def _is_groq(self) -> bool:
        """Return True if ``model_id`` refers to a Groq model."""
        return "groq" in self.model_id.lower()

    @staticmethod
    def _format_tools(tools: Optional[List]) -> Optional[List[Dict[str, Any]]]:
        """Convert tool objects into the function-tool schema sent to LiteLLM.

        Each tool must expose ``name``, ``description`` and ``parameters``
        attributes. Returns None when ``tools`` is None or empty.
        """
        if not tools:
            return None
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                },
            }
            for tool in tools
        ]

    @staticmethod
    def _parse_arguments(raw: Any) -> Dict[str, Any]:
        """Decode tool-call arguments into a dict.

        Arguments usually arrive as a JSON string. Malformed JSON, ``None``,
        or JSON that does not decode to an object all yield ``{}``, so the
        result is always a dict.
        """
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                return {}
        return raw if isinstance(raw, dict) else {}

    @classmethod
    def _parse_tool_calls(cls, tool_calls: List[Any]) -> List[Dict[str, Any]]:
        """Flatten response tool-call objects into ``id``/``name``/``arguments`` dicts.

        A missing or empty ``id`` is replaced by ``call_<function name>`` so
        every entry carries a usable identifier.
        """
        parsed = []
        for tc in tool_calls:
            name = tc.function.name
            parsed.append({
                "id": getattr(tc, "id", None) or f"call_{name}",
                "name": name,
                "arguments": cls._parse_arguments(tc.function.arguments),
            })
        return parsed

    def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
        """Run one chat completion and return a normalized result dict.

        Args:
            messages: Chat messages forwarded unchanged to ``litellm.completion``.
            tools: Optional tool objects (see ``_format_tools``).

        Returns:
            ``{"content": str}``, plus ``"tool_calls"`` when the model asked
            for tool invocations. If litellm is not installed, returns
            ``{"content": "Unknown - litellm not installed"}``. On any error
            during the request or response parsing, the error is printed and
            ``{"content": "Unknown"}`` is returned instead of raising.
        """
        if litellm is None:
            return {"content": "Unknown - litellm not installed"}

        try:
            request: Dict[str, Any] = {
                "model": self.model_id,
                "messages": messages,
                "tools": self._format_tools(tools),
                "temperature": 0.1,
            }

            # Groq configuration: pass the key explicitly. The environment may
            # have changed since __init__, so re-check it here.
            if self._is_groq():
                api_key = os.getenv("GROQ_API_KEY")
                if not api_key:
                    raise RuntimeError("GROQ_API_KEY not set in environment")
                print(f"DEBUG: Using Groq model: {self.model_id}")
                request["api_key"] = api_key

            response = litellm.completion(**request)
            message = response.choices[0].message

            result: Dict[str, Any] = {"content": message.content or ""}
            tool_calls = getattr(message, "tool_calls", None)
            if tool_calls:
                result["tool_calls"] = self._parse_tool_calls(tool_calls)
            return result
        except Exception as e:
            # Best-effort boundary: any failure (including the RuntimeError
            # above) degrades to a placeholder answer instead of propagating.
            print(f"Model error: {e}")
            return {"content": "Unknown"}


def get_model(model_type: str, model_id: str) -> LiteLLMModel:
    """Construct a model wrapper by type name.

    Args:
        model_type: Wrapper class name; only "LiteLLMModel" is supported.
        model_id: Identifier forwarded to the wrapper's constructor.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    if model_type == "LiteLLMModel":
        return LiteLLMModel(model_id)
    raise ValueError(f"Unknown model type: {model_type}")