#!/usr/bin/env python3
"""
Complete LLM Testing Script
Supports Groq and local HuggingFace LLMs with proper LangChain integration.
"""
import os
import sys
from dotenv import load_dotenv
# LangChain & LangGraph imports
try:
    from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
    from langchain_groq import ChatGroq
    from langgraph.graph import START, StateGraph, MessagesState
    from langgraph.prebuilt import ToolNode, tools_condition
    print("✅ LangChain imports successful")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("💡 Install missing packages: pip install langchain-groq langgraph")
    sys.exit(1)

load_dotenv()
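
# The Groq tests below expect GROQ_API_KEY in the environment (loaded from the
# .env file by load_dotenv above). Illustrative .env entry, placeholder value only:
#   GROQ_API_KEY=gsk_your_key_here
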
class LocalHuggingFaceLLM:
    """Custom wrapper for local HuggingFace models"""

    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    def invoke(self, messages):
        """Generate a response from the local model and return it as an AIMessage."""
        # torch is imported lazily so the Groq-only path does not require it
        import torch

        # Flatten the chat messages into a single prompt string
        if isinstance(messages, list):
            text = ""
            for msg in messages:
                if hasattr(msg, 'content'):
                    if hasattr(msg, 'type'):
                        if msg.type == "system":
                            text += f"System: {msg.content}\n"
                        elif msg.type == "human":
                            text += f"Human: {msg.content}\n"
                        else:
                            text += f"{msg.content}\n"
                    else:
                        text += f"Human: {msg.content}\n"
                else:
                    text += str(msg) + "\n"
            text += "Assistant:"
        else:
            text = str(messages)

        try:
            inputs = self.tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
            if self.device == "cuda" and torch.cuda.is_available():
                inputs = inputs.to(self.device)
                self.model = self.model.to(self.device)
            outputs = self.model.generate(
                inputs,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.7,
                pad_token_id=self.tokenizer.eos_token_id,
                attention_mask=torch.ones_like(inputs),
                no_repeat_ngram_size=2,
            )
            # Decode only the newly generated tokens, skipping the prompt
            response_text = self.tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True).strip()
            return AIMessage(content=response_text if response_text else "I understand.")
        except Exception as e:
            return AIMessage(content=f"Error generating response: {str(e)}")

def create_local_huggingface_llm():
    """Initialize local HuggingFace model"""
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        import torch

        model_name = "microsoft/DialoGPT-small"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        )
        return LocalHuggingFaceLLM(model, tokenizer, device)
    except Exception as e:
        print(f"❌ Failed to load local HuggingFace model: {e}")
        return None

def create_minimal_graph(provider: str = "groq"):
    """Create a minimal graph for testing"""
    try:
        if provider == "groq":
            if not os.getenv("GROQ_API_KEY"):
                raise ValueError("GROQ_API_KEY not found")
            llm = ChatGroq(model="qwen/qwen3-32b", temperature=0)

            def assistant(state: MessagesState):
                return {"messages": [llm.invoke(state["messages"])]}

            builder = StateGraph(MessagesState)
            builder.add_node("assistant", assistant)
            builder.add_edge(START, "assistant")
            return builder.compile()
        elif provider == "huggingface_local":
            llm = create_local_huggingface_llm()
            if llm is None:
                raise ValueError("Failed to create local HuggingFace model")

            def assistant(state: MessagesState):
                # The wrapper returns an AIMessage directly
                return {"messages": [llm.invoke(state["messages"])]}

            builder = StateGraph(MessagesState)
            builder.add_node("assistant", assistant)
            builder.add_edge(START, "assistant")
            return builder.compile()
        else:
            raise ValueError(f"Unknown provider: {provider}")
    except Exception as e:
        print(f"❌ Failed to create minimal graph: {e}")
        return None

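
# ToolNode and tools_condition are imported above but unused by the minimal
# graph. The sketch below shows one typical way to wire them in, assuming
# `tools` is a list of LangChain tools and a Groq API key is configured; it is
# not exercised by the tests in this script.
def create_tool_graph_sketch(tools):
    """Sketch only: a tool-calling variant of the minimal Groq graph."""
    llm = ChatGroq(model="qwen/qwen3-32b", temperature=0)
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # Route to the tool node when the model emitted tool calls, otherwise end
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
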
def test_basic_llm_response(provider: str = "groq"):
    """Test basic LLM response"""
    print(f"\n🧪 Testing Basic LLM Response ({provider})")
    try:
        if provider == "groq":
            if not os.getenv("GROQ_API_KEY"):
                return {"status": "error", "error": "GROQ_API_KEY not found"}
            llm = ChatGroq(model="qwen/qwen3-32b", temperature=0)
        elif provider == "huggingface_local":
            llm = create_local_huggingface_llm()
            if llm is None:
                return {"status": "error", "error": "Failed to create local HuggingFace model"}
        else:
            return {"status": "error", "error": f"Unknown provider: {provider}"}

        test_message = "Hello! Please respond with 'LLM is working correctly'"
        response = llm.invoke([HumanMessage(content=test_message)])
        print(f"📥 Response: {response.content[:200]}")
        return {"status": "success", "provider": provider, "response": response.content}
    except Exception as e:
        return {"status": "error", "error": str(e)}

def test_llm_with_system_prompt(provider: str = "groq"):
    """Test LLM with system prompt"""
    print(f"\n🧪 Testing LLM with System Prompt ({provider})")
    try:
        if provider == "groq":
            if not os.getenv("GROQ_API_KEY"):
                return {"status": "error", "error": "GROQ_API_KEY not found"}
            llm = ChatGroq(model="qwen/qwen3-32b", temperature=0)
        elif provider == "huggingface_local":
            llm = create_local_huggingface_llm()
            if llm is None:
                return {"status": "error", "error": "Failed to create local HuggingFace model"}
        else:
            return {"status": "error", "error": f"Unknown provider: {provider}"}

        system_msg = SystemMessage(content="You are a helpful assistant. Answer briefly and clearly.")
        user_msg = HumanMessage(content="What is 2+2? Just give me the number.")
        response = llm.invoke([system_msg, user_msg])
        print(f"📥 Response: {response.content}")
        return {"status": "success", "provider": provider, "response": response.content}
    except Exception as e:
        return {"status": "error", "error": str(e)}

def test_graph_workflow(provider: str = "groq"):
    """Test graph workflow"""
    print(f"\n🧪 Testing Graph Workflow ({provider})")
    try:
        graph = create_minimal_graph(provider)
        if graph is None:
            return {"status": "error", "error": "Failed to create graph"}

        test_query = "What is 5 + 3? Just give me the answer."
        result = graph.invoke({"messages": [HumanMessage(content=test_query)]})
        if result and "messages" in result:
            last_message = result["messages"][-1]
            print(f"📥 Final response: {last_message.content}")
            return {"status": "success", "response": last_message.content, "message_count": len(result["messages"])}
        else:
            return {"status": "error", "error": "No valid response from graph"}
    except Exception as e:
        return {"status": "error", "error": str(e)}

def run_all_tests():
    """Run all LLM tests"""
    results = {}

    # Groq tests
    results["groq_basic"] = test_basic_llm_response("groq")
    results["groq_system_prompt"] = test_llm_with_system_prompt("groq")
    results["groq_graph"] = test_graph_workflow("groq")

    # HuggingFace local tests
    results["huggingface_local_basic"] = test_basic_llm_response("huggingface_local")
    results["huggingface_local_system_prompt"] = test_llm_with_system_prompt("huggingface_local")
    results["huggingface_local_graph"] = test_graph_workflow("huggingface_local")

    return results


if __name__ == "__main__":
    test_results = run_all_tests()
    print("\n📊 Test Results:")
    for k, v in test_results.items():
        print(f"{k}: {v}")