File size: 12,107 Bytes
7715c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00c926b
383d8e9
7715c35
 
 
 
b4b62c6
7715c35
 
 
 
 
 
d5df257
00c926b
 
d5df257
 
00c926b
 
7715c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
from langchain_community.utilities import SQLDatabase
from dotenv import load_dotenv
from langchain_community.agent_toolkits import create_sql_agent
from langchain.agents import AgentType
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
import os
import ast
import Modules
import csv
import re
import json
from operator import itemgetter
from typing import List
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langchain_core.memory import BaseMemory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
# A remettre pour python 10
#from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel , Field
from langchain_core.runnables import (
    RunnableLambda,
    ConfigurableFieldSpec,
    RunnablePassthrough,
)

from langchain_core.runnables.history import RunnableWithMessageHistory
from typing import Optional
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import sqlalchemy
from sqlalchemy.exc import DatabaseError

# Load the API keys: load_dotenv() reads a .env file into the process
# environment, then the Groq key is looked up from it.
load_dotenv()
#langsmith_api_key = os.environ.get("LANGSMITH_API_key")
# os.getenv returns None when GROQ_API_KEY is not defined.
groq_api_key = os.getenv("GROQ_API_KEY")

# Chat model handed to the SQL agent created later in this module.
model = ChatGroq(temperature=0,
                 groq_api_key=groq_api_key,
                 #model_name="llama3-8b-8192"
                 model_name="llama3-70b-8192" 
                 )
# Open the SQLite database used by the agent. DatabaseError is listed before
# the more general SQLAlchemyError so the more specific message (including the
# database path) is printed when it applies.
try:
    db = SQLDatabase.from_uri(f"sqlite:///{Modules.DATA_BASE}", sample_rows_in_table_info=5)
except DatabaseError as e:
    print(f"Database error: {e}")
    print(f"Database path: {Modules.DATA_BASE}")
    # Swallowing the error would leave `db` unbound and the module would fail
    # later with an unrelated NameError; re-raise the real cause instead.
    raise
except sqlalchemy.exc.SQLAlchemyError as e:
    print(f"SQLAlchemy error: {e}")
    raise
# Prompt used by the SQL agent. It describes the census database, lists the
# available tools ({tools} / {tool_names} are pre-filled below) and imposes a
# Question / Thought / Action / Action Input / Observation / Final Answer format.
# {input}, {history} and {agent_scratchpad} are left to be filled at run time.
custom_prompt = PromptTemplate(
    input_variables=['agent_scratchpad', 'input', 'history'],  # 'history' added as an input variable
    partial_variables={
        'tools': "sql_db_query - Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.\nsql_db_schema - Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3\nsql_db_list_tables - Input is an empty string, output is a comma-separated list of tables in the database.\nsql_db_query_checker - Use this tool to double-check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!", 
        'tool_names': 'sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker'
    }, 
    # The template is a single string literal; the trailing backslashes only
    # continue the literal onto the next source line.
    template="You are an expert in statistics designed to interact with a SQL database containing Moroccan census data.\n\
There is only one primary table in the database called `Table_recensement`, which contains population census data. To get the distinct modalities of any field in this table,\n\
you must query the table with the same name as the field. For example, if the field is `Age`, you must query the table `Age`.\nGiven an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n\
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 50 results.\n\
If you can't find the query, return empty. You can order the results by a relevant column to return the most interesting examples in the database.\n\
Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n\
You have access to tools for interacting with the database.\nOnly use the tools below. Only use the information returned by the tools below to construct your final answer.\n\
You MUST double-check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n\n\
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP, etc.) to the database.\n\n\
If the question does not seem related to the database, just return I don't know as the answer.\n\n\
If relevant, you can refer to previous parts of the conversation to enrich your answer.\n\n\n{tools}\n\n\
Use the following format:\n\nQuestion: the input question you must answer\n\n\
Thought: you should always think about what to do\n\
Action: the action to take, should be one of [{tool_names}]\n\n\
Action Input: the input to the action\n\
Observation: the result of the action\n... (this Thought/Action/Action Input/Observation can repeat N times)\n\n\
Thought: I now know the final answer\n\n\
Final Answer: The final answer must be in the same language as the input. Detect the language of the input and respond in the same language.\n\n\
Begin!\n\n\
Relevant pieces of previous conversation (optional): {history}\n\
Question: {input}\n\
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant fields and retrieve\n\
the distinct modalities for any given field using `SELECT DISTINCT XXXX FROM FieldName`.\n\
{agent_scratchpad}"
)



# Create Chatbot

class InMemoryHistory(BaseChatMessageHistory, BaseMemory):
    """Chat history held in process memory.

    Implements both the chat-message-history interface (``add_messages``,
    ``clear``) and the memory interface (``memory_variables``,
    ``load_memory_variables``, ``save_context``).
    """

    # Messages of the conversation, in the order they were added.
    messages: List[BaseMessage] = Field(default_factory=list)

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """Append all given messages to the stored history."""
        self.messages.extend(messages)

    def clear(self) -> None:
        """Forget every stored message."""
        self.messages = []

    @property
    def memory_variables(self) -> List[str]:
        """Names of the variables returned by ``load_memory_variables``."""
        return ["history"]

    def load_memory_variables(self, inputs: dict) -> dict:
        """Return the stored message contents joined by newlines under ``"history"``."""
        contents = (stored.content for stored in self.messages)
        return {"history": "\n".join(contents)}

    def save_context(self, inputs: dict, outputs: dict) -> None:
        """Store the ``"input"`` and ``"output"`` texts of one exchange.

        The input is stored as a human message and the output as an AI
        message; a missing or empty value is skipped.
        """
        exchange = (
            (HumanMessage, inputs.get("input")),
            (AIMessage, outputs.get("output")),
        )
        for message_type, text in exchange:
            if text:
                self.add_messages([message_type(content=text)])


# Module-level registry of chat histories, keyed by (user_id, conversation_id).
# Keeping it global makes the stored conversations easy to inspect.
store = {}


def get_session_history(
    user_id: str, conversation_id: str
) -> BaseChatMessageHistory:
    """Return the history of a user/conversation pair, creating it on first use."""
    session_key = (user_id, conversation_id)
    if session_key in store:
        return store[session_key]
    store[session_key] = InMemoryHistory()
    return store[session_key]

# Fetch (creating it if needed) the history of the current user/conversation
# and seed it with an initial AI message.
history = get_session_history(str(Modules.USER_ID),str(Modules.CONVERSATION_ID))
history.add_message(AIMessage(content="hello"))


# creation of agent
# Builds the SQL agent from the chat model, the database handle and the custom
# prompt defined earlier in this module.
# NOTE(review): get_session_history and history_factory_config look like
# RunnableWithMessageHistory options; it is not visible here whether
# create_sql_agent actually uses them — confirm against the LangChain API.
agent_executor = create_sql_agent(
    model, 
    db=db, 
    get_session_history=get_session_history,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, 
    verbose=True,
    return_intermediate_steps=True,
    handle_parsing_errors=True,  # Enable handling of parsing errors
    prompt=custom_prompt,
    history_factory_config=[
        ConfigurableFieldSpec(
            id="user_id",
            annotation=str,
            name="User ID",
            description="Unique identifier for the user.",
            default=str(Modules.USER_ID),
            is_shared=True,
        ),
        ConfigurableFieldSpec(
            id="conversation_id",
            annotation=str,
            name="Conversation ID",
            description="Unique identifier for the conversation.",
            default=str(Modules.CONVERSATION_ID),
            is_shared=True,
        ),
    ],
    
    #max_iterations=3,
)

# Attach the current conversation history as the executor's memory, and set
# return_intermediate_steps on the executor itself (same flag as passed above).
agent_executor.memory = history
agent_executor.return_intermediate_steps=True


def incrementer_conversation_id():
    """Increase ``Modules.CONVERSATION_ID`` by one."""
    next_id = Modules.CONVERSATION_ID + 1
    Modules.CONVERSATION_ID = next_id

def get_response(query):
    """Run the agent on ``query`` and return ``(result, history)``.

    When ``query`` is empty (falsy) the agent is not invoked and
    ``("", history)`` is returned instead.
    """
    if query:
        payload = {"input": query}
        session = {
            "user_id": str(Modules.USER_ID),
            "conversation_id": str(Modules.CONVERSATION_ID),
        }
        answer = agent_executor.invoke(payload, config={"configurable": session})
        return answer, history

    return "", history
    
def clear_memory():
    """Reset the memory by switching to a new conversation.

    The conversation id is incremented, the history registered under the new
    id becomes the module-level ``history``, and it is attached to the agent
    executor as its memory.
    """
    global history

    incrementer_conversation_id()
    user_key = str(Modules.USER_ID)
    conversation_key = str(Modules.CONVERSATION_ID)
    history = get_session_history(user_key, conversation_key)
    agent_executor.memory = history
    

def Parser_log_action(log):
    """Extract the labelled fields of an agent log into a dictionary.

    For each of 'Question', 'Thought', 'Action' and 'Action Input', the text
    following the first "<label>: " occurrence up to the end of that line is
    returned; a label that is not found maps to None.
    """
    # One regex per field; only 'Action Input' accepts an empty value.
    field_patterns = (
        ('Question', r'Question: (.+)'),
        ('Thought', r'Thought: (.+)'),
        ('Action', r'Action: (.+)'),
        ('Action Input', r'Action Input: (.*)'),
    )

    parsed_log = {}
    for label, pattern in field_patterns:
        found = re.search(pattern, log)
        parsed_log[label] = found.group(1) if found else None
    return parsed_log
    
def save_sql_query_input_output_steps(result_agent, filename=Modules.FILE_HISTORY):
    """Log the SQL queries of one agent run to a CSV file and return the parsed steps.

    Every intermediate step is parsed with ``Parser_log_action`` and the tool
    result is attached to it under ``'Result'``. For each ``sql_db_query``
    step one row is appended to ``filename``:
    ``[input, sql_query, sql_result, output, result_agent]``.

    Args:
        result_agent: Agent result mapping; the keys ``'intermediate_steps'``
            (iterable of ``(action, result)`` pairs), ``'input'`` and
            ``'output'`` are read.
        filename: Path of the CSV file, opened in append mode.

    Returns:
        The list of parsed step dictionaries, one per intermediate step.
    """
    steps = result_agent['intermediate_steps']
    # Named user_input rather than input so the builtin is not shadowed.
    user_input = result_agent['input']
    output = result_agent['output']
    logs = []
    for action, result in steps:
        log = Parser_log_action(action.log)
        log['Result'] = result
        logs.append(log)
        if action.tool != 'sql_db_query':
            continue

        # Put the query on a single line. Line breaks are replaced by a space
        # (not simply removed) so that tokens on consecutive lines, e.g.
        # "SELECT a" / "FROM t", are not glued together into invalid SQL.
        sql_query = ' '.join(action.tool_input.splitlines())

        # Turn the textual tool result into a Python structure safely.
        try:
            # Try JSON first, then fall back to a Python literal.
            try:
                result_list = json.loads(result)
            except json.JSONDecodeError:
                result_list = ast.literal_eval(result)

            # For a non-empty list of tuples keep the first value of the first
            # tuple; anything else is stored as its string form.
            if isinstance(result_list, list) and len(result_list) > 0 and isinstance(result_list[0], tuple):
                sql_result = result_list[0][0]
            else:
                sql_result = str(result_list)

        except (SyntaxError, ValueError) as e:
            # The result could not be parsed: report it and log a placeholder.
            print(f"Erreur lors de la conversion de la chaîne 'result': {e}")
            sql_result = "Erreur lors de l'analyse du résultat SQL"

        with open(filename, 'a', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            # One row per executed SQL query.
            writer.writerow([user_input, sql_query, sql_result, output, result_agent])

    return logs