import gradio as gr
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS, DistanceStrategy
from langchain_huggingface import HuggingFaceEmbeddings
import os
import pandas as pd
from datetime import datetime

from smolagents import Tool, HfApiModel, ToolCallingAgent
from langchain_core.vectorstores import VectorStore
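
# ALexI: a Gradio chat app that retrieves Swiss Federal Court (ATF) and European
# Court of Human Rights case law from a FAISS index and asks a hosted Llama 3
# model to answer with citations. Retrieval is exposed to the LLM as a smolagents
# Tool so the model can call it with its own search queries.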


class RetrieverTool(Tool):
    name = "retriever"
    description = "Using semantic similarity in German, French, English and Italian, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, vectordb: VectorStore, df: pd.DataFrame, **kwargs):
        super().__init__(**kwargs)
        self.vectordb = vectordb
        self.df = df  # full case texts, loaded once instead of on every query

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        docs = self.vectordb.similarity_search(query, k=7)

        spacer = " \n"
        context = ""
        nb_char = 100  # characters of surrounding text to keep around each matched chunk

        for doc in docs:
            # Locate the matched chunk inside the full case text and extract it with context
            case_text = self.df[self.df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
            index = case_text.find(doc.page_content)
            if index == -1:
                index = 0  # chunk not found verbatim; fall back to the start of the case

            start = max(0, index - nb_char)
            end = min(len(case_text), index + len(doc.page_content) + nb_char)
            case_text_summary = case_text[start:end]

            context += "#######" + spacer
            context += "# Case number: " + doc.metadata["case_ref"] + " " + doc.metadata["case_nb"] + spacer
            context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
            context += "# Case date: " + doc.metadata["case_date"] + spacer
            context += "# Case url: " + doc.metadata["case_url"] + spacer
            context += "# Case extract: " + case_text_summary + spacer

        return "\nRetrieved documents:\n" + context


"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Authenticate with the Hub using the TOKEN environment variable (e.g. a Space secret)
HF_TOKEN = os.getenv("TOKEN")
login(HF_TOKEN)

model = "meta-llama/Meta-Llama-3-8B-Instruct"

client = InferenceClient(model)

# Fetch the prebuilt FAISS index and case database from the Hub into the working directory
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())

# Multilingual sentence embeddings (DE/FR/EN/IT), matching those used to build the index
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

vector_db = FAISS.load_local(
    "faiss_index_mpnet_cos",
    embeddings,
    allow_dangerous_deserialization=True,
    distance_strategy=DistanceStrategy.COSINE,
)

# Full case texts, loaded once here and shared with the retriever tool
df = pd.read_csv("bger_cedh_db 1954-2024.csv")

retriever_tool = RetrieverTool(vector_db, df)
agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model))
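
# Note: the ToolCallingAgent is instantiated but respond() below calls the retriever
# tool directly. A sketch of driving the agent instead (hypothetical task text),
# via smolagents' run() entry point:
#
#   answer = agent.run("Which cases address freedom of expression in Switzerland?")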

def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, score):
    # NOTE: the Score Threshold slider ("score") is not yet used by the retriever.
    print(datetime.now())

    # Retrieve supporting case extracts for the current question
    context = retriever_tool(message)

    prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
    Respond only to the question asked; keep the answer concise, relevant, and in the same language as the question.
    Provide the number of the source document when relevant, as well as the link to the document.
    If you cannot find information, do not give up and try calling your retriever again with different arguments!

    Question:
    {message}

    {context}
    """

    messages = [{"role": "system", "content": system_message}]

    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": prompt})

    # Stream the completion token by token so the UI updates as text arrives
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:
            response += token
            yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are assisting a jurist or a lawyer in finding Swiss jurisprudence cases relevant to their question.", label="System message"),
        gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"),
    ],
    description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
)
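
# Gradio passes the additional_inputs to respond() after (message, history), in the
# order declared above: system_message, max_tokens, temperature, top_p, score.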


if __name__ == "__main__":
    demo.launch(debug=True)