LCA's picture
Add stop condition on chat generation
38f1367 verified
import os
import sys
import pandas as pd
import numpy as np
import faiss
import gradio as gr
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient
from datasets import load_dataset
import json
import re
DATASET_REPO = "LCA/HACKATHON_PARTS"
dataset = load_dataset(DATASET_REPO, split="train")
df = dataset.to_pandas()
descriptions = df['DESIGNATION'].tolist()
codes = df["CODE"].astype(str).tolist()
# --- Embedding model ---
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
#--- Load or compute embeddings + FAISS index ---
#For start, test perf without caching this
if os.path.exists("embeddings.npy") and os.path.exists("faiss.index"):
embeddings = np.load("embeddings.npy")
index = faiss.read_index("faiss.index")
else:
embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
faiss.normalize_L2(embeddings)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
# Save embeddings and index for future use
np.save("embeddings.npy", embeddings)
faiss.write_index(index, "faiss.index")
# --- Inference API client ---
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_TOKEN"))
def rechercher_article(articleSource):
print(f"Recherch article pour {articleSource}")
article = {}
source = articleSource["designation"]
query_embedding = embedding_model.encode([source], convert_to_numpy=True)
faiss.normalize_L2(query_embedding)
# Recherche du/des voisin(s) le(s) plus proche(s)
similarity_scores, indices = index.search(query_embedding, k=1)
# Gérer la qualité du retour avec un seuil de similarité
threshold = 0.7 # à ajuster selon vos tests
print(f"Score de similarité ({similarity_scores[0][0]:.2f}) pour '{source}'")
if similarity_scores[0][0] < threshold:
article["code"] = "Inconnu"
article["designation"] = source
article["source"] = source
article["quantite"] = articleSource.get("quantite", None)
print(f"Code non trouvé pour '{source}'")
else:
article["code"] = codes[indices[0][0]]
article["designation"] = descriptions[indices[0][0]]
article["source"] = source
article["quantite"] = articleSource.get("quantite", None)
print(f"Code trouvé pour '{source}': {article['code']} / {article['designation']}")
return article
def extract_json_from_response(response):
"""
Extrait le premier bloc JSON valide d'une chaîne de texte contenant potentiellement du texte en vrac.
Gère les dialogues USER/INST et autres artefacts de modèles de chat.
Retourne un objet Python (dict) ou None si extraction impossible.
"""
# Nettoyer la réponse des balises de dialogue communes
cleaned_response = response
# Supprimer les balises de dialogue courantes
patterns_to_remove = [
r'USER:.*?(?=\{|$)',
r'INST:.*?(?=\{|$)',
r'ASSISTANT:.*?(?=\{|$)',
r'AI:.*?(?=\{|$)',
r'```json',
r'```',
r'Here is the JSON:',
r'The JSON response is:',
r'Response:',
]
for pattern in patterns_to_remove:
cleaned_response = re.sub(pattern, '', cleaned_response, flags=re.IGNORECASE | re.DOTALL)
# Recherche tous les blocs JSON potentiels dans la réponse nettoyée
json_candidates = re.findall(r'({[\s\S]*?})', cleaned_response)
for candidate in json_candidates:
try:
# Nettoyer le candidat des caractères parasites
candidate = candidate.strip()
parsed = json.loads(candidate)
# Vérifier que c'est un objet avec la structure attendue
if isinstance(parsed, dict):
return parsed
except Exception:
continue
# Si aucun bloc JSON valide trouvé, essayer de corriger les crochets manquants
try:
start = cleaned_response.index('{')
end = cleaned_response.rindex('}') + 1
json_str = cleaned_response[start:end]
return json.loads(json_str)
except Exception as e:
print("Erreur lors du parsing JSON extrait:", e)
print("Réponse brute:", response)
print("Réponse nettoyée:", cleaned_response)
return None
def respond(message):
print(" ------------------ ")
print(message)
print(" ------------------ ")
# Prompt par défaut
custom_prompt = """Tu es un analyseur de texte qui extrait des informations d'articles.
Tu dois analyser le message et identifier les articles demandés avec leurs quantités.
IMPORTANT: Réponds UNIQUEMENT avec un objet JSON valide, sans texte supplémentaire.
Format de réponse attendu:
{
"articles": [
{
"designation": "description de l'article",
"quantite": nombre_ou_null
}
]
}
Règles:
- Pas de texte avant ou après le JSON
- Pas de commentaires
- Pas de dialogue USER/INST
- Juste le JSON brut
"""
messages = [{"role": "system", "content": custom_prompt}]
messages += [{"role": "user", "content": message}]
# Utiliser zephyr avec des paramètres plus stricts pour éviter les dialogues
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_TOKEN"))
# client = InferenceClient(
# "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
# token=os.getenv("HF_TOKEN"),
# #provider="auto" # or choose a supported provider from the error message
# )
full_response = ""
for chunk in client.chat_completion(
messages,
max_tokens=256, # Réduire pour éviter les dialogues longs
stream=True,
temperature=0.05, # Très faible pour plus de déterminisme
top_p=0.3, # Plus restrictif
stop=["\n\n", "USER:", "Assistant:", "###"]
):
token = chunk.choices[0].delta.content
if token:
full_response += token
# yield full_response.replace("\n", "\n\n")
print("---- retour de l'analyse")
print(full_response)
print("--")
json_response = extract_json_from_response(full_response)
print(json_response)
# If you expect a JSON response, you can try to parse it here
# import json
# try:
order = {}
try:
if json_response is None:
print("Aucun JSON valide trouvé dans la réponse")
return {"articles": [], "erreur": "Impossible de parser la réponse"}
articles = []
# Vérifier si la réponse a la structure attendue
if "articles" in json_response:
articles_data = json_response["articles"]
else:
# Si pas de clé "articles", essayer d'utiliser la réponse directement si c'est une liste
if isinstance(json_response, list):
articles_data = json_response
else:
print("Structure JSON inattendue:", json_response)
return {"articles": [], "erreur": "Structure JSON inattendue"}
for article in articles_data:
if isinstance(article, dict) and "designation" in article:
found_article = rechercher_article(article)
articles.append(found_article)
else:
print("Article mal formaté:", article)
order["articles"] = articles
# Ajouter les champs destinataire et delai avec des valeurs figées
order["destinataire"] = {
"societe": "Société Exemple",
"nom": "Dupont",
"prenom": "Jean",
"email": "jean.dupont@exemple.com"
}
order["delai"] = "2024-07-15"
except Exception as e:
print("Could not parse articles:", e)
order = {}
return order
with gr.Blocks() as demo:
gr.Markdown("# Part identification Assistant")
#prompt_box = gr.Textbox(label="Prompt système", value=DEFAULT_PROMPT, lines=8)
#temperature_slider = gr.Slider(label="Température", minimum=0.0, maximum=1.0, value=0.1, step=0.01)
#top_p_slider = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
message_box = gr.Textbox(label="Votre question")
response_box = gr.Textbox(label="Réponse de l'assistant", interactive=False, lines=30)
send_btn = gr.Button("Envoyer")
def chat(message):
history = [] # ou récupère l'historique si tu veux le gérer
gen = respond(message)
return json.dumps(gen, indent=2, ensure_ascii=False)
send_btn.click(
chat,
inputs=[message_box],
outputs=[response_box]
)
if __name__ == "__main__":
demo.launch(share=True)