# nlp-back / app.py
# Source: Hugging Face Space (author: Yameogo123, commit 196244e "Update app.py")
from controller.Rag import RAG
from controller.controller import *
from entity.entity import *
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import CharacterTextSplitter
import os
import wikipediaapi
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from transformers import VitsModel, AutoTokenizer
import torch
from IPython.display import Audio
from fastapi.responses import FileResponse
import warnings
# Silence library deprecation chatter in the Space logs.
warnings.filterwarnings(action="ignore")
# FastAPI app with fully open CORS (any origin/method/header) — acceptable for
# a public demo Space, too permissive for production.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Cache sentence-transformers downloads inside the Space's working directory.
os.environ['SENTENCE_TRANSFORMERS_HOME'] = './.cache'
embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L12-v2")
# Persistent Chroma vector store backing /match, /message and /add/file.
persist_directory = "chroma"
db3 = Chroma(persist_directory=persist_directory, embedding_function=embeddings_model)
retriever = db3.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={
        "k": 5,  # number of passages to retrieve
        "score_threshold": 0.01  # minimum similarity score to keep a passage
    }
)
# French text-to-speech model/tokenizer used by the /text_to_speech endpoint.
model = VitsModel.from_pretrained("facebook/mms-tts-fra")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-fra")
@app.get("/")
async def read_root():
    """Health-check endpoint: greets the caller to confirm the API is up."""
    greeting = {"Message": "Bienvenu sur notre API AI."}
    return greeting
@app.put("/add/file")
async def update_db(file: UploadFile = File(...)):
    """Ingest an uploaded PDF or TXT file into the Chroma vector store.

    The upload is written to a temporary local file, loaded and split into
    overlapping chunks, embedded into the existing ``db3`` collection, and
    the temporary file is removed afterwards.

    Returns a status dict; on failure a generic error message (the original
    behavior) so the client never sees a stack trace.
    """
    root = f"./{file.filename}"
    try:
        contents = await file.read()
        # The file is closed by the context manager (the original also called
        # f.close() after the `with` block, which was redundant).
        with open(root, 'wb') as f:
            f.write(contents)
        ext = file.filename.split(".")[-1]
        if ext == "pdf":
            loader = PyPDFLoader(root)
        elif ext == "txt":
            loader = TextLoader(root)
        else:
            os.remove(root)  # fix: don't leak the temp file on unsupported types
            return {"error": "please send pdf or txt type files."}
        raw_documents = loader.load()
        text_splitter = CharacterTextSplitter(
            separator="\n\n", chunk_size=400, chunk_overlap=100, length_function=len,
        )
        documents = text_splitter.split_documents(raw_documents)
        # BUG FIX: add_documents expects a list of Documents; the original
        # wrapped the already-split list in a second list ([documents]).
        db3.add_documents(documents)
        db3.persist()
        os.remove(root)
        return {"message": "base mise à jour avec le nouveau document"}
    except Exception:
        # Boundary handler: keep the endpoint from surfacing internals.
        return {"message": "Erreur d'ajout du fichier dans la base."}
@app.post("/wiki")
async def key_words(message: Message):
    """Look up the message (punctuation stripped) as a French Wikipedia title.

    Returns the page title + summary when the page exists, otherwise a
    French "nothing found" message. (Removed a block of commented-out
    embedding-similarity code that was dead.)
    """
    wiki_wiki = wikipediaapi.Wikipedia('Psychologie-API (test@example.com)', 'fr')
    mess = remove_punctuation(message.message)
    page_py = wiki_wiki.page(mess)
    res = (
        "Wikipédia n'a rien trouvé d'intéressant qui corresponde à cette page."
        if not page_py.exists()
        else page_py.title + "\n " + page_py.summary
    )
    return {"message": res}
@app.post("/message")
async def message_to_chatbot(message: Message):
    """Answer a user message through the retrieval-augmented chatbot."""
    # NOTE(review): a fresh RAG instance is built on every request; confirm
    # whether it could be constructed once at startup instead.
    rag_engine = RAG(retriever, "mistralai/Mixtral-8x7B-Instruct-v0.1")
    answer = rag_engine.generate_response(message.message)
    return {"message": answer}
@app.post("/match")
async def get_doc(message: Message):
    """Return the content of the document closest to the message.

    Fetches the top-3 matches from the vector store and returns the best
    one's text; answers "no available document" when nothing matches or
    the search fails.
    """
    try:
        docs = db3.similarity_search(message.message, k=3)
        # Fix: handle the empty-result case explicitly instead of letting
        # docs[0] raise IndexError into the broad except below.
        if not docs:
            return {"message": "no available document"}
        return {"message": docs[0].page_content}
    except Exception:
        # Boundary handler: the vector store may fail for many reasons.
        return {"message": "no available document"}
@app.post("/text_to_speech")
async def text_to_speech(message: Message):
    """Synthesize French speech from the message and save it to ./audio.wav.

    Uses the module-level MMS-TTS model/tokenizer; the WAV bytes are produced
    by IPython's Audio helper and written to a fixed path that /send/audio
    later serves.
    """
    inputs = tokenizer(message.message, return_tensors="pt")
    with torch.no_grad():  # inference only — no gradients needed
        output = model(**inputs).waveform
    # Audio(...) encodes the waveform tensor as WAV bytes in .data.
    audio = Audio(output, rate=model.config.sampling_rate)
    # Fix: drop the redundant f.close() the original called inside the
    # `with` block, and the commented-out debug readback code.
    with open('./audio.wav', 'wb') as f:
        f.write(audio.data)
    return {"message": "File Saved"}
@app.get("/send/audio")
async def send_audio():
    """Stream the most recently generated ./audio.wav back to the client."""
    audio_path = "./audio.wav"
    return FileResponse(audio_path, media_type="audio/wav")