# agent_RAG / app.py
import gradio as gr
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS, DistanceStrategy
from langchain_huggingface import HuggingFaceEmbeddings
import os
import pandas as pd
from datetime import datetime
from smolagents import Tool, HfApiModel, ToolCallingAgent
from langchain_core.vectorstores import VectorStore
class RetrieverTool(Tool):
    name = "retriever"
    description = "Using semantic similarity in German, French, English and Italian, retrieves the documents from the knowledge base whose embeddings are closest to the input query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, vectordb: VectorStore, **kwargs):
        super().__init__(**kwargs)
        self.vectordb = vectordb

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"
        docs = self.vectordb.similarity_search(
            query,
            k=7,
        )
        spacer = " \n"
        context = ""
        nb_char = 100  # characters of surrounding case text kept on each side of the matched chunk
        for doc in docs:
            # Look up the full case text in the module-level DataFrame loaded below.
            case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
            index = case_text.find(doc.page_content)
            if index == -1:
                # Fall back to the chunk itself if it cannot be located in the full text.
                case_text_summary = doc.page_content
            else:
                start = max(0, index - nb_char)
                end = min(len(case_text), index + len(doc.page_content) + nb_char)
                case_text_summary = case_text[start:end]
            context += "#######" + spacer
            context += "# Case number: " + doc.metadata["case_ref"] + " " + doc.metadata["case_nb"] + spacer
            context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
            context += "# Case date: " + doc.metadata["case_date"] + spacer
            context += "# Case url: " + doc.metadata["case_url"] + spacer
            # context += "# Case text: " + doc.page_content + spacer
            context += "# Case extract: " + case_text_summary + spacer
        return "\nRetrieved documents:\n" + context
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
HF_TOKEN = os.getenv("TOKEN")
login(HF_TOKEN)
model = "meta-llama/Meta-Llama-3-8B-Instruct"
#model = "swiss-ai/Apertus-8B-Instruct-2509"
client = InferenceClient(model)
# Download the prebuilt FAISS index files from the umaiku/faiss_index dataset into the working directory.
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE)
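# A sketch of how an index compatible with the load above could be built with the same
# embeddings (illustrative; not necessarily how umaiku/faiss_index was actually created):
#
#   docs = ...  # langchain Documents carrying case_url/case_ref/case_nb/case_date metadata
#   index = FAISS.from_documents(docs, embeddings, distance_strategy=DistanceStrategy.COSINE)
#   index.save_local("faiss_index_mpnet_cos")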
# Full case texts and metadata (Swiss Federal Court and ECtHR cases, 1954-2024), used by RetrieverTool.forward.
df = pd.read_csv("bger_cedh_db 1954-2024.csv")
retriever_tool = RetrieverTool(vector_db)
agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model))
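# Note: the agent is instantiated here, but the Gradio handler below calls the retriever
# tool directly. To let the agent drive retrieval instead, something like this would work
# (sketch; the question is illustrative):
#
#   answer = agent.run("Which cases deal with the right to respect for private life?")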
def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, score):
    # `score` receives the "Score Threshold" slider value but is currently unused.
    print(datetime.now())
    context = retriever_tool(message)
    print(message)

    # is_law = client.text_generation(f"""Given the user question below, classify it as either being about "Law" or "Other".
    # Do NOT respond with more than one word.
    # Question:
    # {message}""")
    # print(is_law)

    if True:  # is_law.lower() != "other":
        prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
Respond only to the question asked; the response should be relevant to the question and in the same language as the question.
Provide the number of the source document when relevant, as well as the link to the document.
If you cannot find information, do not give up and try calling your retriever again with different arguments!
Always give the URLs of the sources at the end, and answer only in the language the question is asked in.

Question:
{message}

{context}
"""
    else:
        prompt = f"""A user wrote the following message, please answer them to the best of your knowledge in the language of their message:
{message}"""

    messages = [{"role": "system", "content": system_message}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": prompt})

    response = ""

    # Stream the completion; the loop variable is named `chunk` to avoid shadowing the
    # `message` argument.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # the final streamed chunk may carry no content
            response += token
        yield response
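# `respond` is a generator, so outside Gradio it could be consumed like this (sketch;
# the argument values are illustrative):
#
#   for partial in respond("Ma question", [], "You are a legal assistant.", 1024, 0.1, 0.95, 0.75):
#       pass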
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are assisting a jurist or a lawyer in finding relevant Swiss jurisprudence cases for their question.", label="System message"),
        gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"),
    ],
    description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
)
if __name__ == "__main__":
    print("Ready!")
    demo.launch(debug=True)