Spaces:
Running
Running
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
Created on Sun Oct 13 10:30:56 2024 | |
@author: legalchain | |
""" | |
from typing import Literal, Optional, List, Union, Any | |
from langchain_openai import ChatOpenAI | |
import pandas as pd | |
from langchain_core.prompts import ChatPromptTemplate | |
from langgraph.graph import END, StateGraph, START | |
from langchain_core.output_parsers import StrOutputParser | |
from pydantic import BaseModel, Field | |
from models import NatureJugement | |
from prompts import df_prompt, feed_back_prompt, reflection_prompt | |
# Chat model used by the LLM-backed nodes (query generation, results grading,
# query feedback).
llm = ChatOpenAI(model="gpt-4o-mini")

# evaluate_results_node stops the loop ("max_generation_reached") once
# generation_num is strictly greater than this value.
MAX_GENERATIONS: int = 2
# Maximum number of matching rows serialized into `results`.
MAX_ROWS: int = 10
# Structured output requested from the LLM: a pandas filtering expression of
# the form df[<condition>].
# NOTE(review): documented with comments rather than a class docstring on
# purpose -- with_structured_output presumably forwards a docstring to the
# model as the schema description, which would change the prompt.
class Query(BaseModel):
    query: str = Field(..., title="Requête pour filtrer les résultats du dataframe entourée avec des gullemets de type \" ")

    def clean_query(self) -> str:
        """Return the boolean condition found inside the outer ``df[...]``.

        The greedy match spans from the first ``df[`` to the last ``]`` on the
        line, so nested brackets such as ``df['col']`` stay in the condition.

        Returns:
            The text between the outer ``df[`` and ``]``.

        Raises:
            ValueError: if the query contains no ``df[...]`` expression. The
                message states what was expected, since the caller in this
                file stores the exception and sends it back to the prompt.
        """
        import re

        match = re.search(r"df\[(.*)\]", self.query)
        if match is None:
            raise ValueError(
                "La requête doit être de la forme df[<condition>], "
                f"requête reçue : {self.query!r}"
            )
        return match.group(1)
# Binary grade produced by the results-evaluation LLM call.
# NOTE(review): deliberately no class docstring -- with_structured_output
# presumably sends it to the model as the schema description; verify before
# adding one.
class GradeResults(BaseModel):
    binary_score: Literal["yes", "no"] = Field(
        ...,
        description="Les résultats sont satisfaisants -> 'yes' ou il y une erreur ou pas de résultats ou les résultats sont améliorables -> 'no'",
    )
class GraphState(BaseModel):
    """State carried between the nodes of the query / evaluate / feedback graph.

    ``df`` holds the data to filter. The nodes in this module index it
    directly (``state.df[...]``), so it is expected to be a pandas DataFrame;
    ``get_df`` additionally accepts a JSON string or anything ``pd.read_json``
    accepts. ``region`` and ``dep`` are recomputed from ``df`` each time the
    state is constructed and are passed to the query-generation prompt.
    """

    df: Any
    df_head: str
    instructions: Optional[str] = None
    # Defaults to a comma-separated string of every NatureJugement value, which
    # is how the prompt consumes it; a list is still accepted.
    nature_jugement: Union[str, List] = ', '.join([e.value for e in NatureJugement])
    region: str = ''
    dep: str = ''
    query: Optional[str] = None
    results: Union[str, List[str]] = Field(default_factory=list)
    query_feedbacks: Optional[str] = None
    results_feedbacks: Optional[bool] = None
    generation_num: int = 0
    retrieval_num: int = 0
    search_mode: Literal["vectorstore", "websearch", "QA_LM"] = "QA_LM"
    # "" means "no error"; otherwise the exception raised by the last attempt.
    error_query: Optional[Any] = ""
    error_results: Optional[Any] = ""
    truncated: bool = False

    def get_df(self) -> pd.DataFrame:
        """Return ``df`` as a DataFrame.

        A DataFrame is returned unchanged. A string that looks like literal
        JSON is wrapped in ``StringIO`` (passing literal JSON to
        ``pd.read_json`` is deprecated since pandas 2.1). Anything else
        (path, URL, file-like object) is forwarded to ``pd.read_json``.
        """
        from io import StringIO

        if isinstance(self.df, pd.DataFrame):
            return self.df
        if isinstance(self.df, str) and self.df.lstrip().startswith(("{", "[")):
            return pd.read_json(StringIO(self.df))
        return pd.read_json(self.df)

    def __init__(self, **data):
        """Validate the fields, then derive ``region`` and ``dep`` from ``df``.

        Raises:
            KeyError: if ``df`` has no ``region_nom_officiel`` or
                ``departement_nom_officiel`` column.
        """
        super().__init__(**data)
        frame = self.get_df()
        distinct_regions = frame['region_nom_officiel'].dropna().unique().tolist()
        distinct_departements = frame['departement_nom_officiel'].dropna().unique().tolist()
        # str() so that non-string labels cannot make join() raise TypeError.
        self.region = ', '.join(str(r) for r in distinct_regions)
        self.dep = ', '.join(str(d) for d in distinct_departements)
def generate_query_node(state: GraphState):
    """Ask the LLM for a DataFrame filtering query and extract its condition.

    Args:
        state: current graph state. Its df_head, instructions,
            query_feedbacks, error_query, nature_jugement, dep and region
            fields are forwarded to ``df_prompt``.

    Returns:
        A partial state update:
        * ``{"query": <condition>, "error_query": ""}`` on success. This
          clears any error left by a previous attempt.
        * ``{"error_query": <exception>}`` when the call fails, the structured
          output cannot be parsed into ``Query``, or no ``df[...]`` condition
          can be extracted. The stored exception is passed to the prompt as
          ``error`` the next time this node runs.
    """
    prompt = ChatPromptTemplate.from_messages(messages=df_prompt)
    # include_raw=True makes the chain return {"raw", "parsed", "parsing_error"}
    # instead of raising on malformed structured output.
    generate_df_query = prompt | llm.with_structured_output(
        Query,
        include_raw=True,
    )
    try:
        query_generate = generate_df_query.invoke({
            'df_head': state.df_head,
            'instructions': state.instructions,
            'feedback': state.query_feedbacks,
            'error': state.error_query,
            'nature_jugement': state.nature_jugement,
            'dep': state.dep,
            'region': state.region,
        })
        parsing_error = query_generate.get('parsing_error')
        if parsing_error is not None:
            # Report the real parser error rather than the AttributeError that
            # calling clean_query() on a None 'parsed' value would produce.
            return {'error_query': parsing_error}
        parsed = query_generate.get('parsed')
        if parsed is None:
            return {'error_query': ValueError("Le modèle n'a renvoyé aucune requête.")}
        query_final = parsed.clean_query()
    except Exception as e:
        return {'error_query': e}
    return {
        "query": query_final,
        "error_query": "",
    }
def evaluate_query_node(state: GraphState):
    """Routing check after query generation.

    Returns "ok" when ``error_query`` is exactly the empty string (no error
    recorded). Otherwise it returns a French message asking the user to retry.
    """
    failed = state.error_query != ""
    return (
        "Il y a une erreur dans la requête. Je me suis sûrement trompé. Veuillez réessayer."
        if failed
        else "ok"
    )
def generate_results_node(state: GraphState):
    """Apply the generated condition to ``state.df`` and store the matching rows.

    Every returned update increments ``generation_num``. It also overwrites
    ``results``, ``truncated`` and ``error_results``, so values from a
    previous generation never leak into the next evaluation:

    * no matching row: ``results=[]``;
    * matches: ``results`` is a JSON ``records`` string with at most
      ``MAX_ROWS`` rows, and ``truncated`` is True when rows were dropped;
    * any exception: ``results=[]`` and ``error_results`` holds the exception.
    """
    next_generation = state.generation_num + 1
    try:
        query = state.query
        print("query ", query)
        print('je suis dans generate', type(state.df))
        # SECURITY(review): `query` is LLM-generated text and eval() runs it
        # with full builtins available. Treat it as untrusted; consider a
        # restricted evaluator (e.g. DataFrame.query) if inputs are exposed.
        mask = eval(query, {"df": state.df})
        new_df = state.df[mask]
        print("new_df", new_df.empty)
        if new_df.empty:
            return {
                'results': [],
                'truncated': False,
                'error_results': "",
                "generation_num": next_generation,
            }
        # head() returns every row when there are at most MAX_ROWS of them.
        return {
            'results': new_df.head(MAX_ROWS).to_json(orient='records'),
            'truncated': len(new_df) > MAX_ROWS,
            'error_results': "",
            "generation_num": next_generation,
        }
    except Exception as e:
        return {
            'results': [],
            'truncated': False,
            'error_results': e,
            "generation_num": next_generation,
        }
def evaluate_results_node(state: GraphState):
    """Decide whether the current results satisfy the instructions.

    Returns:
        "max_generation_reached" when ``generation_num`` exceeds
        ``MAX_GENERATIONS``. This is checked first, so no LLM call is spent on
        a grade that would be discarded. Otherwise it returns the grader's
        ``binary_score`` ("yes" or "no").
    """
    if state.generation_num > MAX_GENERATIONS:
        return "max_generation_reached"
    prompt_eval = ChatPromptTemplate.from_messages(messages=reflection_prompt)
    # include_raw=False: the parsed GradeResults is returned directly and a
    # malformed model output raises instead of being reported in a dict.
    generate_eval = prompt_eval | llm.with_structured_output(
        GradeResults,
        include_raw=False,
    )
    evaluation = generate_eval.invoke({
        'df_head': state.df_head,
        'results': state.results,
        'instructions': state.instructions,
    })
    return evaluation.binary_score
def query_feedback_node(state: GraphState):
    """Ask the LLM to critique the current query and its results.

    The raw critique is prefixed with "Evaluation de la recherche : ",
    printed, and returned as the ``query_feedbacks`` state update.
    """
    feedback_chain = (
        ChatPromptTemplate.from_messages(messages=feed_back_prompt)
        | llm
        | StrOutputParser()
    )
    critique = feedback_chain.invoke({
        "df_head": state.df_head,
        "instructions": state.instructions,
        "results": state.results,
        "query": state.query,
    })
    message = f"Evaluation de la recherche : {critique}"
    print(message)
    return {"query_feedbacks": message}