agent_RAG / app.py
umaiku's picture
Update app.py
90d1e52 verified
raw
history blame
5.61 kB
import gradio as gr
from transformers import pipeline
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS, DistanceStrategy
from langchain_huggingface import HuggingFaceEmbeddings
import os
import pandas as pd
from datetime import datetime
from smolagents import Tool, HfApiModel, ToolCallingAgent
from langchain_core.vectorstores import VectorStore
class RetrieverTool(Tool):
    """smolagents tool that retrieves case-law passages from a vector store.

    Each retrieved chunk is located inside its full case text (looked up by
    ``case_url`` in a CSV). It is returned together with up to
    ``context_chars`` characters of surrounding text on each side.
    """

    name = "retriever"
    description = "Using semantic similarity in German, French, English and Italian, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(
        self,
        vectordb: VectorStore,
        *,
        csv_path: str = "bger_cedh_db 1954-2024.csv",
        top_k: int = 7,
        context_chars: int = 100,
        **kwargs,
    ):
        """Create the tool.

        Args:
            vectordb: Vector store queried with ``similarity_search``.
            csv_path: CSV with ``case_url`` and ``case_text`` columns, used to
                recover the text surrounding each retrieved chunk.
            top_k: Number of documents retrieved per query.
            context_chars: Characters of case text kept on each side of the
                retrieved chunk.
            **kwargs: Forwarded to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.vectordb = vectordb
        self.csv_path = csv_path
        self.top_k = top_k
        self.context_chars = context_chars
        # Lazily built case_url -> case_text mapping. Previously the whole CSV
        # was re-read and mask-filtered on every query.
        self._case_texts = None

    def _lookup_case_text(self, case_url):
        """Return the full text of the case at ``case_url``, or None if unavailable."""
        if self._case_texts is None:
            cases = pd.read_csv(self.csv_path)
            # keep="first" preserves the earlier `.values[0]` choice on duplicates.
            cases = cases.drop_duplicates(subset="case_url", keep="first")
            self._case_texts = dict(zip(cases["case_url"], cases["case_text"]))
        case_text = self._case_texts.get(case_url)
        # Empty CSV cells come back from pandas as float NaN, not str.
        return case_text if isinstance(case_text, str) else None

    @staticmethod
    def _extract_passage(case_text, passage, context_chars):
        """Return ``passage`` widened by up to ``context_chars`` characters of ``case_text`` per side.

        Falls back to ``passage`` itself when there is no case text or the
        passage cannot be found in it.
        """
        if case_text is None:
            return passage
        index = case_text.find(passage)
        if index == -1:
            # Slicing around -1 would return an unrelated prefix of the case.
            return passage
        start = max(0, index - context_chars)
        return case_text[start:index + len(passage) + context_chars]

    def forward(self, query: str) -> str:
        """Return a formatted listing of the ``top_k`` documents closest to ``query``.

        Raises:
            TypeError: if ``query`` is not a string.
        """
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        docs = self.vectordb.similarity_search(query, k=self.top_k)

        spacer = " \n"
        sections = []
        for doc in docs:
            meta = doc.metadata
            # Only "ATF" references are labelled Swiss Federal Court; any other
            # case_ref is labelled European Court of Human Rights.
            source = "Swiss Federal Court" if meta["case_ref"] == "ATF" else "European Court of Human Rights"
            extract = self._extract_passage(
                self._lookup_case_text(meta["case_url"]),
                doc.page_content,
                self.context_chars,
            )
            sections.append(
                f"#######{spacer}"
                f"# Case number: {meta['case_ref']} {meta['case_nb']}{spacer}"
                f"# Case source: {source}{spacer}"
                f"# Case date: {meta['case_date']}{spacer}"
                f"# Case url: {meta['case_url']}{spacer}"
                f"# Case extract: {extract}{spacer}"
            )
        return "\nRetrieved documents:\n" + "".join(sections)
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
HF_TOKEN=os.getenv('TOKEN')
login(HF_TOKEN)
model = "meta-llama/Meta-Llama-3-8B-Instruct"
client = InferenceClient(model)
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE)
df = pd.read_csv("faiss_index/bger_cedh_db 1954-2024.csv")
retriever_tool = RetrieverTool(vector_db)
agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model))
def _build_messages(system_message, history, prompt):
    """Assemble the chat transcript: system prompt, prior turns, then the RAG prompt."""
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    for turn in history:
        if isinstance(turn, dict):
            # "messages"-style history entries ({"role": ..., "content": ...}).
            content = turn.get("content")
            if isinstance(content, str) and content:
                messages.append({"role": turn["role"], "content": content})
            continue
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    # The prompt already contains the question, so it is the final user turn.
    # The question is not sent a second time on its own.
    messages.append({"role": "user", "content": prompt})
    return messages


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    score,
):
    """Stream a retrieval-augmented answer for the Gradio chat UI.

    Args:
        message: The user's latest question; also used as the retrieval query.
        history: Previous (user, assistant) turns from the chat UI. Dicts with
            "role"/"content" keys are also accepted.
        system_message: System prompt sent ahead of the conversation.
        max_tokens, temperature, top_p: Forwarded to ``client.chat_completion``.
        score: "Score Threshold" UI value. Currently unused, because the
            retriever always returns a fixed number of documents.

    Yields:
        The accumulated response text after each streamed chunk.
    """
    print(datetime.now())
    question = message
    context = retriever_tool(question)
    # NOTE(review): no tool-calling loop runs here, so the model cannot act on the
    # "try calling your retriever again" instruction.
    prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
Respond only to the question asked, response should be concise and relevant to the question and answer in the same language as the question.
Provide the number of the source document when relevant, as well as the link to the document.
If you cannot find information, do not give up and try calling your retriever again with different arguments!
Question:
{question}
{context}
"""
    messages = _build_messages(system_message, history, prompt)

    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # Some stream chunks carry no choices or a None delta (e.g. the final one).
        token = chunk.choices[0].delta.content if chunk.choices else None
        response += token or ""
        yield response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
respond,
additional_inputs=[
gr.Textbox(value="You are assisting a jurist or a layer in finding relevant Swiss Jurisprudence cases to their question.", label="System message"),
gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"),
],
description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
)
if __name__ == "__main__":
demo.launch(debug=True)