|
|
import os
from datetime import datetime

import gradio as gr
import pandas as pd
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS, DistanceStrategy
from langchain_core.vectorstores import VectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from smolagents import Tool, HfApiModel, ToolCallingAgent
from transformers import pipeline  # currently unused

class RetrieverTool(Tool):
    name = "retriever"
    description = "Using semantic similarity in German, French, English and Italian, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, vectordb: VectorStore, **kwargs):
        super().__init__(**kwargs)
        self.vectordb = vectordb

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        docs = self.vectordb.similarity_search(
            query,
            k=7,
        )

        spacer = " \n"
        context = ""
        nb_char = 100  # characters of surrounding case text kept on each side of the matched chunk

        for doc in docs:
            # Look up the full case text in the module-level DataFrame `df` loaded below.
            case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
            index = case_text.find(doc.page_content)
            if index == -1:
                # Chunk not found verbatim in the case text (e.g. whitespace
                # normalisation differences); fall back to the chunk itself.
                case_text_summary = doc.page_content
            else:
                start = max(0, index - nb_char)
                end = min(len(case_text), index + len(doc.page_content) + nb_char)
                case_text_summary = case_text[start:end]

            context += "#######" + spacer
            context += "# Case number: " + doc.metadata["case_ref"] + " " + doc.metadata["case_nb"] + spacer
            context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
            context += "# Case date: " + doc.metadata["case_date"] + spacer
            context += "# Case url: " + doc.metadata["case_url"] + spacer
            context += "# Case extract: " + case_text_summary + spacer

        return "\nRetrieved documents:\n" + context
|
|
|
|
|
|
|
|
""" |
|
|
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference |
|
|
""" |
|
|
HF_TOKEN=os.getenv('TOKEN') |
|
|
login(HF_TOKEN) |
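# The token is read from an environment variable named TOKEN (presumably set as
# a Space secret); login() makes it available to Hub downloads and inference calls.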
|
|
|
|
|
model = "meta-llama/Meta-Llama-3-8B-Instruct" |
|
|
|
|
|
|
|
|
client = InferenceClient(model) |
|
|
|
|
|
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd()) |
|
|
|
|
|
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2") |
|
|
|
|
|
vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE) |
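# A sketch of how such an index is typically built with the same embedding model
# (illustrative; the actual build script is not part of this file):
#   vector_db = FAISS.from_documents(chunks, embeddings, distance_strategy=DistanceStrategy.COSINE)
#   vector_db.save_local("faiss_index_mpnet_cos")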
|
|
|
|
|
df = pd.read_csv("bger_cedh_db 1954-2024.csv") |
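# The CSV must provide at least the columns used below:
# case_url, case_text, case_ref, case_nb, case_date.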
|
|
|
|
|
retriever_tool = RetrieverTool(vector_db)
agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model))
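# Note: `agent` is instantiated but never called; `respond` below invokes
# retriever_tool directly and builds the prompt itself.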
|
|
|
|
|
def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, score):
    print(datetime.now())

    # Retrieve supporting case extracts for the user's question.
    # Note: `score` (the "Score Threshold" slider) is received but not yet
    # passed on to the retriever.
    context = retriever_tool(message)

    print(message)

    use_rag = True  # the fallback branch is kept for plain, non-retrieval answers
    if use_rag:
        prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
Respond only to the question asked; the response should be relevant to the question and in the same language as the question.
Provide the number of the source document when relevant, as well as the link to the document.
If you cannot find information, do not give up and try calling your retriever again with different arguments!
Always give the URLs of the sources at the end, and only answer in the language the question is asked in.

Question:
{message}

{context}
"""
    else:
        prompt = f"""A user wrote the following message; please answer it to the best of your knowledge, in the language of the message:
{message}"""

    messages = [{"role": "system", "content": system_message}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": prompt})

    response = ""

    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # streamed chunks can carry an empty or None delta
            response += token
            yield response
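
# A minimal non-streaming sketch of the same call, handy for debugging
# (assumes the `client` defined above; max_tokens value is arbitrary):
#   out = client.chat_completion(messages, max_tokens=256)
#   print(out.choices[0].message.content)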
|
|
|
|
|
|
|
|
""" |
|
|
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface |
|
|
""" |
|
|
demo = gr.ChatInterface( |
|
|
respond, |
|
|
additional_inputs=[ |
|
|
gr.Textbox(value="You are assisting a jurist or a layer in finding relevant Swiss Jurisprudence cases to their question.", label="System message"), |
|
|
gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"), |
|
|
gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"), |
|
|
gr.Slider( |
|
|
minimum=0.1, |
|
|
maximum=1.0, |
|
|
value=0.95, |
|
|
step=0.05, |
|
|
label="Top-p (nucleus sampling)", |
|
|
), |
|
|
gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"), |
|
|
], |
|
|
description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence", |
|
|
) |
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    print("Ready!")
    demo.launch(debug=True)