File size: 5,606 Bytes
5c60ed2 0bf6060 f8adcff 90d1e52 0635997 7c12ef4 d5c54ef 037376c 0635997 90d1e52 0635997 f846748 35f9142 60ac7f7 90d1e52 60ac7f7 f846748 0635997 79c456d 0635997 90d1e52 0635997 8661441 d5c54ef 90d1e52 e6d12c5 90d1e52 a6051b9 90d1e52 a1e734a 90d1e52 f7848c9 90d1e52 a6051b9 90d1e52 4e67249 90d1e52 34f414b a2933d7 f846748 90d1e52 f846748 b7d6ba3 d2eb5fb a6051b9 f846748 4c92796 f846748 cddcba8 f846748 cdec1a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import gradio as gr
from transformers import pipeline
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS, DistanceStrategy
from langchain_huggingface import HuggingFaceEmbeddings
import os
import pandas as pd
from datetime import datetime
from smolagents import Tool, HfApiModel, ToolCallingAgent
from langchain_core.vectorstores import VectorStore
class RetrieverTool(Tool):
    """Retrieve case-law extracts that are semantically close to a query.

    Wraps a LangChain ``VectorStore``. The ``top_k`` chunks returned by
    ``similarity_search`` are located inside the full case text (read from
    ``csv_path``). Each chunk is padded with up to ``context_chars`` characters
    on each side. The results are rendered as one plain-text block for the LLM.
    """

    name = "retriever"
    description = "Using semantic similarity in German, French, English and Italian, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    # CSV holding the full case texts. Only the "case_url" and "case_text"
    # columns are read here.
    # NOTE(review): module-level code in this file reads the same CSV from
    # "faiss_index/bger_cedh_db 1954-2024.csv"; confirm which location is right.
    csv_path = "bger_cedh_db 1954-2024.csv"
    # Number of chunks requested from the vector store.
    top_k = 7
    # Characters of surrounding case text kept on each side of a matched chunk.
    context_chars = 100

    def __init__(self, vectordb: VectorStore, **kwargs):
        super().__init__(**kwargs)
        self.vectordb = vectordb
        # Maps case_url -> full case text. Built lazily on first lookup so the
        # CSV is read once instead of on every search.
        self._case_texts = None

    def _lookup_case_text(self, case_url):
        """Return the full text of the first CSV row with this case_url, or None."""
        if self._case_texts is None:
            cases = pd.read_csv(self.csv_path)
            # keep="first" matches the previous `.values[0]` choice on duplicates.
            cases = cases.drop_duplicates(subset="case_url", keep="first")
            self._case_texts = dict(zip(cases["case_url"], cases["case_text"]))
        return self._case_texts.get(case_url)

    @staticmethod
    def _extract(case_text, passage: str, nb_char: int) -> str:
        """Return *passage* padded with up to *nb_char* chars of *case_text* per side.

        Falls back to *passage* itself when the case text is missing or not a
        string (e.g. an empty CSV cell read as NaN), or does not contain the
        passage. In those cases no meaningful surrounding context exists.
        """
        if not isinstance(case_text, str):
            return passage
        index = case_text.find(passage)
        if index == -1:
            return passage
        start = max(0, index - nb_char)
        end = min(len(case_text), index + len(passage) + nb_char)
        return case_text[start:end]

    def forward(self, query: str) -> str:
        """Search the vector store for *query* and format the hits as text.

        Each hit becomes a "#######" header followed by case number, source
        court, date, url and a text extract, one per " \\n"-terminated line.
        """
        assert isinstance(query, str), "Your search query must be a string"
        docs = self.vectordb.similarity_search(query, k=self.top_k)
        spacer = " \n"
        lines = []
        for doc in docs:
            meta = doc.metadata
            case_text = self._lookup_case_text(meta["case_url"])
            extract = self._extract(case_text, doc.page_content, self.context_chars)
            # "ATF" references are Swiss Federal Court rulings; everything else
            # in this knowledge base is attributed to the ECHR.
            source = (
                "Swiss Federal Court"
                if meta["case_ref"] == "ATF"
                else "European Court of Human Rights"
            )
            lines.extend(
                [
                    "#######",
                    f"# Case number: {meta['case_ref']} {meta['case_nb']}",
                    f"# Case source: {source}",
                    f"# Case date: {meta['case_date']}",
                    f"# Case url: {meta['case_url']}",
                    f"# Case extract: {extract}",
                ]
            )
        context = "".join(line + spacer for line in lines)
        return "\nRetrieved documents:\n" + context
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Hugging Face access token, read from the "TOKEN" environment variable.
HF_TOKEN = os.getenv('TOKEN')
# Only log in when a token is configured. Without a token, `login()` falls back
# to an interactive prompt, which cannot be answered in a headless deployment.
if HF_TOKEN:
    login(HF_TOKEN)

model = "meta-llama/Meta-Llama-3-8B-Instruct"
client = InferenceClient(model)
# Download the "umaiku/faiss_index" dataset repo into the working directory.
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
# SECURITY NOTE(review): allow_dangerous_deserialization lets load_local unpickle
# the stored index. That is only acceptable while the downloaded repo is trusted.
vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE)
# NOTE(review): `df` appears unused in this file. RetrieverTool reads its own
# copy of the CSV from a different path; confirm which one is intended.
df = pd.read_csv("faiss_index/bger_cedh_db 1954-2024.csv")
retriever_tool = RetrieverTool(vector_db)
# NOTE(review): `agent` is constructed but never used by respond() in this file.
agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model))
def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, score,):
    """Stream an answer to *message* grounded in retrieved case-law extracts.

    Callback for ``gr.ChatInterface``. The parameters after *history* are fed by
    the ``additional_inputs`` widgets, in the order they are declared.

    Args:
        message: The user's latest chat message. It is both the retrieval query
            and the question placed in the prompt.
        history: Previous ``(user, assistant)`` turns; empty entries are skipped.
        system_message: Sent as the ``system`` message (skipped if empty).
        max_tokens: Forwarded to ``client.chat_completion``.
        temperature: Forwarded to ``client.chat_completion``.
        top_p: Forwarded to ``client.chat_completion``.
        score: Value of the "Score Threshold" slider. NOTE(review): not used
            yet, because the retriever applies no score cut-off.

    Yields:
        The response accumulated so far, after each non-empty streamed token.
    """
    print(datetime.now())
    context = retriever_tool(message)
    prompt = (
        "Given the question and supporting documents below, give a comprehensive answer to the question.\n"
        "Respond only to the question asked, response should be concise and relevant to the question and answer in the same language as the question.\n"
        "Provide the number of the source document when relevant, as well as the link to the document.\n"
        "If you cannot find information, do not give up and try calling your retriever again with different arguments!\n"
        "Question:\n"
        f"{message}\n"
        f"{context}\n"
    )

    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    # The current turn goes last and carries both the retrieved context and the
    # question, so the raw message is not sent a second time.
    messages.append({"role": "user", "content": prompt})

    response = ""
    # Named `chunk` so the `message` parameter is not shadowed.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # A streamed delta's content can be None; skip it instead of
        # concatenating None onto the string.
        token = chunk.choices[0].delta.content
        if not token:
            continue
        response += token
        yield response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI. Each widget in `additional_inputs` feeds, in this order, the
# parameters of respond() after `history`:
# system_message, max_tokens, temperature, top_p, score.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            label="System message",
            value="You are assisting a jurist or a layer in finding relevant Swiss Jurisprudence cases to their question.",
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=24000,
            step=1,
            value=5000,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.1,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Score Threshold",
            minimum=0,
            maximum=1,
            step=0.05,
            value=0.75,
        ),
    ],
    description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
)
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    demo.launch(debug=True)