import os
from dotenv import load_dotenv
from typing import List, Dict, Any, Optional
import tempfile
import re
import json
import requests
from urllib.parse import urlparse
import pytesseract
from PIL import Image, ImageDraw, ImageFont, ImageEnhance, ImageFilter
import cmath
import pandas as pd
import uuid
import numpy as np
from code_interpreter import CodeInterpreter
import logging

interpreter_instance = CodeInterpreter()

from image_processing import *

"""LangGraph"""
from langgraph.graph import START, StateGraph, MessagesState
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import (
    ChatHuggingFace,
    HuggingFaceEndpoint,
    HuggingFaceEmbeddings,
)
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client

load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("agent")
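# Note (assumption, not in the original file): load_dotenv() is expected to supply the
# API keys used further down — typically GROQ_API_KEY (ChatGroq), TAVILY_API_KEY
# (TavilySearchResults), HUGGINGFACEHUB_API_TOKEN (HuggingFaceEndpoint), plus
# SUPABASE_URL and SUPABASE_KEY, which are read explicitly below.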
def tool_response(success: bool, data=None, error=None):
    """Standardized response format for tools."""
    return {
        "status": "success" if success else "error",
        "data": data,
        "error": error,
    }
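# Illustrative examples of the response shape (not part of the original file):
#   tool_response(True, data=42)        -> {"status": "success", "data": 42, "error": None}
#   tool_response(False, error="boom")  -> {"status": "error", "data": None, "error": "boom"}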
# =========================
# 🔢 Math Tools
# =========================

@tool
def multiply(a: Any, b: Any):
    """Multiply two numbers and return the product."""
    logger.info("multiply called with a=%s, b=%s", a, b)
    try:
        a = float(a)
        b = float(b)
        result = a * b
        return tool_response(True, result)
    except Exception as e:
        logger.error("multiply failed: %s", str(e))
        return tool_response(False, error=f"Invalid input: {e}")
@tool
def add(a: Any, b: Any):
    """Add two numbers and return the sum."""
    logger.info("add called with a=%s, b=%s", a, b)
    try:
        a = float(a)
        b = float(b)
        return tool_response(True, a + b)
    except Exception as e:
        logger.error("add failed: %s", str(e))
        return tool_response(False, error=f"Invalid input: {e}")


@tool
def subtract(a: Any, b: Any):
    """Subtract b from a and return the result."""
    logger.info("subtract called with a=%s, b=%s", a, b)
    try:
        a = float(a)
        b = float(b)
        return tool_response(True, a - b)
    except Exception as e:
        logger.error("subtract failed: %s", str(e))
        return tool_response(False, error=f"Invalid input: {e}")


@tool
def divide(a: Any, b: Any):
    """Divide a by b and return the quotient."""
    logger.info("divide called with a=%s, b=%s", a, b)
    try:
        a = float(a)
        b = float(b)
        if b == 0:
            return tool_response(False, error="Division by zero")
        return tool_response(True, a / b)
    except Exception as e:
        logger.error("divide failed: %s", str(e))
        return tool_response(False, error=f"Invalid input: {e}")


@tool
def modulus(a: Any, b: Any):
    """Return the remainder of a divided by b."""
    logger.info("modulus called with a=%s, b=%s", a, b)
    try:
        a = float(a)
        b = float(b)
        return tool_response(True, a % b)
    except Exception as e:
        logger.error("modulus failed: %s", str(e))
        return tool_response(False, error=f"Invalid input: {e}")


@tool
def power(a: Any, b: Any):
    """Raise a to the power of b."""
    logger.info("power called with a=%s, b=%s", a, b)
    try:
        a = float(a)
        b = float(b)
        return tool_response(True, a ** b)
    except Exception as e:
        logger.error("power failed: %s", str(e))
        return tool_response(False, error=f"Invalid input: {e}")


@tool
def square_root(a: Any):
    """Return the square root of a number."""
    logger.info("square_root called with a=%s", a)
    try:
        a = float(a)
        if a < 0:
            # use complex math if negative
            return tool_response(True, str(cmath.sqrt(a)))
        return tool_response(True, a ** 0.5)
    except Exception as e:
        logger.error("square_root failed: %s", str(e))
        return tool_response(False, error=f"Invalid input: {e}")
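# Quick sketch of how a decorated tool is called (assumes the @tool decorators above;
# the numbers are illustrative only):
#   multiply.invoke({"a": 6, "b": 7})
#   # -> {"status": "success", "data": 42.0, "error": None}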
# =========================
# 📂 File Tools
# =========================

@tool
def save_and_read_file(filename: str, content: str):
    """Save content to a file and return the content back."""
    logger.info("save_and_read_file called with filename=%s", filename)
    try:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(content)
        with open(filename, "r", encoding="utf-8") as f:
            result = f.read()
        return tool_response(True, result)
    except Exception as e:
        logger.error("save_and_read_file failed: %s", str(e))
        return tool_response(False, error=f"File error: {e}")


@tool
def download_file_from_url(url: str):
    """Download a file from a URL and return its local path."""
    logger.info("download_file_from_url called with url=%s", url)
    try:
        if url.startswith("file://"):
            raise ValueError("Local file:// URLs not allowed")
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        filename = os.path.basename(urlparse(url).path) or f"download_{uuid.uuid4()}"
        with open(filename, "wb") as f:
            f.write(response.content)
        return tool_response(True, filename)
    except Exception as e:
        logger.error("download_file_from_url failed: %s", str(e))
        return tool_response(False, error=f"Download error: {e}")
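# Note (observation, not in the original): both file tools use relative paths, so saved
# and downloaded files land in the process's current working directory.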
# =========================
# 🖼️ Image Tools
# =========================

@tool
def extract_text_from_image(image_path: str):
    """Extract text from an image using OCR."""
    logger.info("extract_text_from_image called with image_path=%s", image_path)
    try:
        text = pytesseract.image_to_string(Image.open(image_path))
        return tool_response(True, text.strip())
    except Exception as e:
        logger.error("extract_text_from_image failed: %s", str(e))
        return tool_response(False, error=f"OCR error: {e}")


@tool
def analyze_image(image_path: str):
    """Return basic analysis (format, mode, size) of an image."""
    logger.info("analyze_image called with image_path=%s", image_path)
    try:
        with Image.open(image_path) as img:
            data = {"format": img.format, "mode": img.mode, "size": img.size}
        return tool_response(True, data)
    except Exception as e:
        logger.error("analyze_image failed: %s", str(e))
        return tool_response(False, error=f"Image analysis error: {e}")


@tool
def transform_image(image_path: str, operation: str):
    """Apply a simple transform (grayscale, blur, sharpen)."""
    logger.info("transform_image called with image_path=%s operation=%s", image_path, operation)
    try:
        img = Image.open(image_path)
        if operation == "grayscale":
            img = img.convert("L")
        elif operation == "blur":
            img = img.filter(ImageFilter.BLUR)
        elif operation == "sharpen":
            img = img.filter(ImageFilter.SHARPEN)
        else:
            raise ValueError(f"Unsupported operation: {operation}")
        output_path = f"transformed_{uuid.uuid4()}.png"
        img.save(output_path)
        return tool_response(True, output_path)
    except Exception as e:
        logger.error("transform_image failed: %s", str(e))
        return tool_response(False, error=f"Transform error: {e}")


@tool
def draw_on_image(image_path: str, text: str):
    """Draw text on an image."""
    logger.info("draw_on_image called with image_path=%s text=%s", image_path, text)
    try:
        img = Image.open(image_path)
        draw = ImageDraw.Draw(img)
        draw.text((10, 10), text, fill="black")
        output_path = f"drawn_{uuid.uuid4()}.png"
        img.save(output_path)
        return tool_response(True, output_path)
    except Exception as e:
        logger.error("draw_on_image failed: %s", str(e))
        return tool_response(False, error=f"Draw error: {e}")


@tool
def generate_simple_image(text: str):
    """Generate a simple image with text."""
    logger.info("generate_simple_image called with text=%s", text)
    try:
        img = Image.new("RGB", (200, 100), color="white")
        draw = ImageDraw.Draw(img)
        draw.text((10, 40), text, fill="black")
        output_path = f"generated_{uuid.uuid4()}.png"
        img.save(output_path)
        return tool_response(True, output_path)
    except Exception as e:
        logger.error("generate_simple_image failed: %s", str(e))
        return tool_response(False, error=f"Image generation error: {e}")


@tool
def combine_images(image1_path: str, image2_path: str):
    """Combine two images side by side."""
    logger.info("combine_images called with %s and %s", image1_path, image2_path)
    try:
        img1 = Image.open(image1_path)
        img2 = Image.open(image2_path)
        combined = Image.new("RGB", (img1.width + img2.width, max(img1.height, img2.height)))
        combined.paste(img1, (0, 0))
        combined.paste(img2, (img1.width, 0))
        output_path = f"combined_{uuid.uuid4()}.png"
        combined.save(output_path)
        return tool_response(True, output_path)
    except Exception as e:
        logger.error("combine_images failed: %s", str(e))
        return tool_response(False, error=f"Combine error: {e}")
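# Illustrative chaining of the image tools (assumes the @tool decorators; the path passed
# to the second call is whatever the first call returned in "data"):
#   generated = generate_simple_image.invoke({"text": "hello"})
#   transform_image.invoke({"image_path": generated["data"], "operation": "grayscale"})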
# =========================
# 📊 Data Tools
# =========================

@tool
def analyze_csv_file(file_path: str):
    """Analyze a CSV file and return basic info."""
    logger.info("analyze_csv_file called with file_path=%s", file_path)
    try:
        df = pd.read_csv(file_path)
        summary = {"shape": df.shape, "columns": df.columns.tolist(), "head": df.head(3).to_dict()}
        return tool_response(True, summary)
    except Exception as e:
        logger.error("analyze_csv_file failed: %s", str(e))
        return tool_response(False, error=f"CSV analysis error: {e}")


@tool
def analyze_excel_file(file_path: str):
    """Analyze an Excel file and return basic info."""
    logger.info("analyze_excel_file called with file_path=%s", file_path)
    try:
        df = pd.read_excel(file_path)
        summary = {"shape": df.shape, "columns": df.columns.tolist(), "head": df.head(3).to_dict()}
        return tool_response(True, summary)
    except Exception as e:
        logger.error("analyze_excel_file failed: %s", str(e))
        return tool_response(False, error=f"Excel analysis error: {e}")
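# Shape of the summary returned by both analyzers (example values are hypothetical):
#   {"shape": (120, 4), "columns": ["id", "name", ...], "head": {...first 3 rows...}}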
# =========================
# 💻 Code Tool
# =========================

@tool
def execute_code_multilang(code: str, language: str = "python"):
    """Execute code in multiple languages using CodeInterpreter."""
    logger.info("execute_code_multilang called with language=%s", language)
    try:
        result = interpreter_instance.execute_code(code, language)
        return tool_response(True, result)
    except Exception as e:
        logger.error("execute_code_multilang failed: %s", str(e))
        return tool_response(False, error=f"Code execution error: {e}")
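# Note: CodeInterpreter comes from the local code_interpreter module; the
# execute_code(code, language) signature is inferred from the call above, and the set of
# supported languages depends on that module's implementation.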
# =========================
# 🌍 Search Tools
# =========================

@tool
def web_search(query: str, max_results: int = 3):
    """Perform a web search using TavilySearchResults."""
    logger.info("web_search called with query=%s", query)
    try:
        tavily = TavilySearchResults(max_results=max_results)
        results = tavily.invoke(query)
        return tool_response(True, results)
    except Exception as e:
        logger.error("web_search failed: %s", str(e))
        return tool_response(False, error=f"Web search error: {e}")


@tool
def wiki_search(query: str):
    """Search Wikipedia and return documents."""
    logger.info("wiki_search called with query=%s", query)
    try:
        loader = WikipediaLoader(query=query, load_max_docs=3)
        docs = loader.load()
        results = [doc.page_content for doc in docs]
        return tool_response(True, results)
    except Exception as e:
        logger.error("wiki_search failed: %s", str(e))
        return tool_response(False, error=f"Wikipedia error: {e}")


@tool
def arxiv_search(query: str):
    """Search arXiv and return documents."""
    logger.info("arxiv_search called with query=%s", query)
    try:
        loader = ArxivLoader(query=query, load_max_docs=3)
        docs = loader.load()
        results = [doc.page_content for doc in docs]
        return tool_response(True, results)
    except Exception as e:
        logger.error("arxiv_search failed: %s", str(e))
        return tool_response(False, error=f"Arxiv error: {e}")
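# Note (assumption): TavilySearchResults needs a TAVILY_API_KEY in the environment;
# the Wikipedia and arXiv loaders fetch public data and need no API key.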
if __name__ == "__main__":
    logger.info("=== Running Tool Tests ===")

    # =========================
    # Manual tool tests (uncomment to run)
    # =========================

    # 🌍 Search Tools
    # print("\n--- web_search ---")
    # print(web_search.invoke({"query": "latest AI research", "max_results": 2}))
    # print("\n--- wiki_search ---")
    # print(wiki_search.invoke({"query": "LangChain"}))
    # print("\n--- arxiv_search ---")
    # print(arxiv_search.invoke({"query": "transformers"}))

    # 💻 Code Execution
    # print("\n--- execute_code_multilang ---")
    # print(execute_code_multilang.invoke({"code": "print(2+3)", "language": "python"}))

# load the system prompt from the file
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
print(system_prompt)
# System message
sys_msg = SystemMessage(content=system_prompt)

# build a retriever
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)  # dim=768

supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_KEY")
supabase: Client = create_client(supabase_url, supabase_key)

vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents2",
    query_name="match_documents_2",
)

retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)
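# Note (assumption about the backing database): SupabaseVectorStore expects a
# "documents2" table and a "match_documents_2" SQL function in the Supabase project,
# with an embedding column sized for all-mpnet-base-v2 vectors (768 dimensions).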
tools = [
    web_search,
    wiki_search,
    arxiv_search,
    multiply,
    add,
    subtract,
    divide,
    modulus,
    power,
    square_root,
    save_and_read_file,
    download_file_from_url,
    extract_text_from_image,
    analyze_csv_file,
    analyze_excel_file,
    execute_code_multilang,
    analyze_image,
    transform_image,
    draw_on_image,
    generate_simple_image,
    combine_images,
]
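# Note (observation): retriever_tool is not part of this list, so it is never exposed to
# the LLM; similar questions are injected by the dedicated retriever node instead.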
# Build graph function
def build_graph(provider: str = "groq"):
    """Build the graph"""
    if provider == "groq":
        # Groq https://console.groq.com/docs/models
        llm = ChatGroq(model="qwen/qwen3-32b", temperature=0)
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                task="text-generation",  # chat-style use still goes through "text-generation"
                max_new_tokens=1024,
                do_sample=False,
                repetition_penalty=1.03,
                temperature=0,
            ),
            verbose=True,
        )
    else:
        raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")

    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)
    # Node
    def assistant(state: MessagesState):
        """Assistant node"""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node"""
        similar_question = vector_store.similarity_search(state["messages"][0].content)
        if similar_question:  # Check if the list is not empty
            example_msg = HumanMessage(
                content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
            )
            return {"messages": [sys_msg] + state["messages"] + [example_msg]}
        else:
            # Handle the case when no similar questions are found
            return {"messages": [sys_msg] + state["messages"]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
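# Flow of the compiled graph (as wired above): START -> retriever -> assistant;
# tools_condition then routes to the "tools" node whenever the assistant emits tool
# calls (looping back to the assistant) and otherwise ends the run.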
# test
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    graph = build_graph(provider="groq")
    messages = [HumanMessage(content=question)]
    messages = graph.invoke({"messages": messages})
    for m in messages["messages"]:
        m.pretty_print()