# NOTE: residue from the hosting page, not part of the program.
# Space status shown when this file was captured: "Runtime error".
from controller.Rag import RAG | |
from controller.controller import * | |
from entity.entity import * | |
from langchain.vectorstores import Chroma | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.document_loaders import PyPDFLoader, TextLoader | |
from langchain.text_splitter import CharacterTextSplitter | |
import os | |
import wikipediaapi | |
from fastapi import FastAPI, UploadFile, File | |
from fastapi.middleware.cors import CORSMiddleware | |
from transformers import VitsModel, AutoTokenizer | |
import torch | |
from IPython.display import Audio | |
from fastapi.responses import FileResponse | |
import warnings | |
# Suppress every warning category for the whole process.
warnings.filterwarnings(action="ignore")

# Application object; the CORS middleware below accepts any origin, method
# and header, and allows credentialed requests.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the API publicly.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Point the sentence-transformers cache at a local directory; this is set
# before the embedding model is instantiated.
os.environ['SENTENCE_TRANSFORMERS_HOME'] = './.cache'
embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L12-v2")

# Chroma vector store persisted in the "chroma" directory.
persist_directory = "chroma"
db3 = Chroma(
    persist_directory=persist_directory,
    embedding_function=embeddings_model,
)

# Retriever over the store: "k" is the number of passages to bring back and
# "score_threshold" is the minimum score a passage needs to be returned.
retriever = db3.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": 5, "score_threshold": 0.01},
)

# Text-to-speech model and its tokenizer, both loaded from the same
# "facebook/mms-tts-fra" checkpoint.
model = VitsModel.from_pretrained("facebook/mms-tts-fra")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-fra")
async def read_root():
    """Return the static welcome payload of the API."""
    # NOTE(review): no route decorator is visible in this view of the file.
    greeting = "Bienvenu sur notre API AI."
    return {"Message": greeting}
async def update_db(file: UploadFile = File(...)):
    """Ingest an uploaded PDF or text file into the Chroma vector store.

    The upload is written to a temporary file in the working directory,
    loaded with the loader matching its extension, split into overlapping
    chunks, added to ``db3`` and persisted. The temporary file is always
    removed afterwards, whether ingestion succeeded or not.

    Args:
        file: The uploaded file. Only ``.pdf`` and ``.txt`` (matched
            case-insensitively) are accepted.

    Returns:
        ``{"message": ...}`` on success or on an ingestion failure, and
        ``{"error": ...}`` when the file type is not supported.
    """
    import logging  # local import so this block does not rely on file-level changes

    # Keep only the final path component: the filename comes from the client
    # and must not be able to point outside the working directory.
    filename = os.path.basename(file.filename or "")
    ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""

    loaders = {"pdf": PyPDFLoader, "txt": TextLoader}
    if ext not in loaders:
        # Reject before anything is written to disk, so nothing is left behind.
        return {"error": "please send pdf or txt type files."}

    root = f"./{filename}"
    try:
        contents = await file.read()
        with open(root, 'wb') as f:
            f.write(contents)

        raw_documents = loaders[ext](root).load()
        text_splitter = CharacterTextSplitter(
            separator="\n\n",
            chunk_size=400,
            chunk_overlap=100,
            length_function=len,
        )
        documents = text_splitter.split_documents(raw_documents)

        # `documents` is already a list of chunks; wrapping it in another
        # list (as before) hands the store a list of lists and fails.
        db3.add_documents(documents)
        db3.persist()
        return {"message": "base mise à jour avec le nouveau document"}
    except Exception:
        # Keep the best-effort response for the client, but record the cause.
        logging.getLogger(__name__).exception(
            "Failed to add uploaded file %r to the vector store", filename
        )
        return {"message": "Erreur d'ajout du fichier dans la base."}
    finally:
        # Clean up the temporary copy on every path (success or failure).
        if os.path.exists(root):
            os.remove(root)
async def key_words(message: Message):
    """Look up the message text as a French Wikipedia page title.

    Punctuation is stripped from ``message.message`` with
    ``remove_punctuation`` and the result is used as the page name.

    Returns:
        ``{"message": text}`` where ``text`` is the page title followed by
        its summary, or a fixed "nothing found" sentence when the page does
        not exist.
    """
    # An earlier approach (embedding similarity against a fixed list of
    # keywords) was sketched here as commented-out code and is not in use.
    wikipedia = wikipediaapi.Wikipedia('Psychologie-API (test@example.com)', 'fr')
    query = remove_punctuation(message.message)
    page = wikipedia.page(query)

    if not page.exists():
        answer = "Wikipédia n'a rien trouvé d'intéressant qui corresponde à cette page."
    else:
        answer = page.title + "\n " + page.summary
    return {"message": answer}
async def message_to_chatbot(message: Message):
    """Answer the user's message with the retrieval-augmented chatbot.

    A ``RAG`` instance is built on each call from the module-level
    ``retriever`` and the Mixtral model identifier, then asked to generate
    a response for ``message.message``.

    Returns:
        ``{"message": response}`` with whatever ``generate_response`` returned.
    """
    rag = RAG(retriever, "mistralai/Mixtral-8x7B-Instruct-v0.1")
    answer = rag.generate_response(message.message)
    return {"message": answer}
async def get_doc(message: Message):
    """Return the content of the best-matching stored document chunk.

    Runs a similarity search (``k=3``) on ``db3`` for ``message.message``
    and returns the ``page_content`` of the first hit.

    Returns:
        ``{"message": content}``, or ``{"message": "no available document"}``
        when the search or the access to the first hit raises (presumably
        also the case when the search returns nothing).
    """
    try:
        matches = db3.similarity_search(message.message, k=3)
        top_match = matches[0]
        payload = {"message": top_match.page_content}
    except Exception:
        payload = {"message": "no available document"}
    return payload
async def text_to_speech(message: Message):
    """Synthesize ``message.message`` to speech and save it as ``./audio.wav``.

    The text is tokenized, passed through the module-level VITS ``model``
    with gradients disabled, and the resulting waveform is wrapped in an
    IPython ``Audio`` object whose ``data`` bytes are written to disk.

    Returns:
        ``{"message": "File Saved"}`` once the file has been written.
    """
    # NOTE(review): the output path is fixed, so each call overwrites the
    # file written by the previous one.
    encoded = tokenizer(message.message, return_tensors="pt")
    with torch.no_grad():
        waveform = model(**encoded).waveform

    clip = Audio(waveform, rate=model.config.sampling_rate)
    with open('./audio.wav', 'wb') as wav_file:
        wav_file.write(clip.data)
    return {"message": "File Saved"}
async def send_audio():
    """Serve the file at ``./audio.wav`` as an ``audio/wav`` response."""
    audio_path = "./audio.wav"
    return FileResponse(audio_path, media_type="audio/wav")