GAIA-Hackathon / chatbot.py
TD9991's picture
second commit
03df0fa
raw
history blame
No virus
3.6 kB
import os
import gradio as gr
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.readers.web import SimpleWebPageReader
from llama_index.llms.mistralai import MistralAI
from llama_index.embeddings.mistralai import MistralAIEmbedding
from llama_index.core import Settings
from llama_index.core.query_engine import RetrieverQueryEngine
from dotenv import load_dotenv
def print_conversation(chat_history):
"""
Fonction de vérification.
Args:
chat_history (List[List[str]]): historique d'échanges avec le chatbot
"""
for question, response in chat_history:
print('Question :', end = '\n')
print(question, end = '\n')
print('Response :', end = '\n')
print(response, end = '\n\n')
def main(
query_engine,
elevage,
temperature,
humidite,
meteo
):
"""
Fonction qui crée le bloc du chatbot avec gradio.
Couplage entre le
Args:
query_engine (_type_): index de recherche vectorielle.
elevage (str):
temperature (str):
humidite (str):
meteo (str):
Returns:
None
"""
title = "Gaia Mistral Chat RAG URL Demo"
description = "Example of an assistant with Gradio, RAG from url and Mistral AI via its API"
placeholder = "Vous pouvez me posez une question sur ce contexte, appuyer sur Entrée pour valider"
with gr.Blocks() as demo:
chatbot = gr.Chatbot()
msg = gr.Textbox(placeholder = placeholder)
clear = gr.ClearButton([msg, chatbot])
def respond(message, chat_history):
global_message = f"""
Tu es un chatbot qui réponds en français, commence et qui dois aider à déterminer les risques parasitaires d'élevage en fonction des conditions météo, voici les conditions météo :
- élevage : {elevage}
- température moyenne : {temperature}
- humidité moyenne : {humidite}
- météo : {meteo}
{message}
Donne-moi ensuite les solutions de prévention et de traitement pour chacun d'eux indépendamment du tableau.
Tu dois impérativement répondre en français.
"""
response = query_engine.query(global_message)
chat_history.append((message, str(response)))
print_conversation(chat_history)
return '', chat_history
msg.submit(respond, [msg, chatbot], [msg, chatbot])
demo.title = title
demo.launch()
if __name__ == "__main__":
#### Loading de la clef Mistral
load_dotenv()
env_api_key = os.getenv('MISTRAL_API_KEY')
### Type du modèle
llm_model = 'mistral-small-2312'
### Config modèle
Settings.llm = MistralAI(
max_tokens = 1024,
api_key = env_api_key
)
Settings.embed_model = MistralAIEmbedding(
model_name="mistral-embed",
api_key=env_api_key
)
### Mise en place de l'index
documents = SimpleDirectoryReader("documents").load_data()
index = VectorStoreIndex.from_documents(documents)
query_engine = index.as_query_engine(similarity_top_k=15)
### Précision du type d'élevage et des conditions météo
elevage = 'bovin'
temperature = '15°C'
humidite = '40%'
meteo = 'pluvieux'
### Lancement du Gradio
main(
query_engine,
elevage,
temperature,
humidite,
meteo
)