Spaces:
Runtime error
Runtime error
| from controller.Rag import RAG | |
| from controller.controller import * | |
| from entity.entity import * | |
| from langchain.vectorstores import Chroma | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.document_loaders import PyPDFLoader, TextLoader | |
| from langchain.text_splitter import CharacterTextSplitter | |
| import os | |
| import wikipediaapi | |
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from transformers import VitsModel, AutoTokenizer | |
| import torch | |
| from IPython.display import Audio | |
| from fastapi.responses import FileResponse | |
| import warnings | |
# Silence every warning category for the lifetime of the process.
warnings.filterwarnings(action="ignore")

# FastAPI application with a fully permissive CORS policy: any origin,
# method and header is accepted, and credentials are allowed.
# NOTE(review): no route decorators are visible on the handlers in this
# view of the file — confirm how they are registered on `app`.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Set before the embedding model is constructed.
# Presumably redirects the sentence-transformers cache to ./.cache — confirm.
os.environ['SENTENCE_TRANSFORMERS_HOME'] = './.cache'
embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L12-v2")

# Chroma vector store persisted in the "chroma" directory, plus a retriever
# over it that filters results by similarity score.
persist_directory = "chroma"
db3 = Chroma(
    persist_directory=persist_directory,
    embedding_function=embeddings_model,
)
retriever = db3.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={
        "k": 5,  # number of passages to bring back
        "score_threshold": 0.01,  # the score threshold
    },
)

# Text-to-speech model and its tokenizer, both loaded from the same checkpoint.
_TTS_CHECKPOINT = "facebook/mms-tts-fra"
model = VitsModel.from_pretrained(_TTS_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_TTS_CHECKPOINT)
async def read_root():
    """Return the API's fixed welcome payload."""
    welcome = {"Message": "Bienvenu sur notre API AI."}
    return welcome
async def update_db(file: UploadFile = File(...)):
    """Index an uploaded PDF or text file into the Chroma store.

    The upload is written to a file in the working directory, loaded with the
    loader matching its extension, split into chunks, added to ``db3`` and
    persisted. The on-disk copy is removed afterwards whether or not indexing
    succeeded.

    Args:
        file: The uploaded document. Only ``pdf`` and ``txt`` extensions are
            accepted (matched case-insensitively).

    Returns:
        A dict with a ``"message"`` key on success or on indexing failure, or
        with an ``"error"`` key when the extension is not supported.
    """
    import logging  # local import: this module does not import logging

    # Keep only the final path component so a crafted filename such as
    # "../../x.txt" cannot escape the working directory.
    # NOTE(review): a name matching an existing file in the working directory
    # is still overwritten and then deleted — consider a temporary directory.
    safe_name = os.path.basename(file.filename or "")
    extension = safe_name.split(".")[-1].lower()
    loader_classes = {"pdf": PyPDFLoader, "txt": TextLoader}

    # Reject unsupported types before anything is written to disk, so no
    # stray file is left behind.
    if extension not in loader_classes:
        return {"error": "please send pdf or txt type files."}

    local_path = f"./{safe_name}"
    try:
        contents = await file.read()
        with open(local_path, 'wb') as out:
            out.write(contents)

        raw_documents = loader_classes[extension](local_path).load()
        text_splitter = CharacterTextSplitter(
            separator="\n\n", chunk_size=400, chunk_overlap=100, length_function=len,
        )
        documents = text_splitter.split_documents(raw_documents)

        # Pass the chunk list itself: wrapping it in another list hands the
        # store a list object where a document is expected.
        db3.add_documents(documents)
        db3.persist()
    except Exception:
        # Best-effort endpoint: keep the generic reply for the client, but
        # record the real cause instead of discarding it.
        logging.getLogger(__name__).exception(
            "Failed to index uploaded file %r", safe_name
        )
        return {"message": "Erreur d'ajout du fichier dans la base."}
    finally:
        # Remove the temporary copy on success and on failure alike.
        if os.path.exists(local_path):
            os.remove(local_path)

    return {"message": "base mise à jour avec le nouveau document"}
async def key_words(message: Message):
    """Look the message text up as a Wikipedia page title.

    Punctuation is stripped from the message with ``remove_punctuation`` and
    the result is requested from a Wikipedia client created with the ``'fr'``
    language argument.

    Returns:
        A dict whose ``"message"`` value is the page title, a newline and a
        space, then the page summary when the page exists; otherwise a fixed
        not-found notice.
    """
    # An earlier, disabled approach embedded the message and a fixed list of
    # keywords and picked the keyword with the highest dot product; it is not
    # part of the current behavior.
    client = wikipediaapi.Wikipedia('Psychologie-API (test@example.com)', 'fr')
    query = remove_punctuation(message.message)
    page = client.page(query)

    if page.exists():
        text = page.title + "\n " + page.summary
    else:
        text = "Wikipédia n'a rien trouvé d'intéressant qui corresponde à cette page."
    return {"message": text}
async def message_to_chatbot(message: Message):
    """Generate a chatbot reply to the message text.

    A new ``RAG`` instance is built on every call from the module-level
    ``retriever`` and a fixed model identifier, then asked for a response.

    Returns:
        A dict whose ``"message"`` value is whatever ``generate_response``
        returned.
    """
    model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    answer = RAG(retriever, model_id).generate_response(message.message)
    return {"message": answer}
async def get_doc(message: Message):
    """Return the content of the first similarity-search hit for the message.

    The store is queried for three results but only the first one is used.
    Any exception — including an empty result list — yields the fallback
    reply instead of propagating.

    Returns:
        A dict whose ``"message"`` value is the first hit's ``page_content``,
        or the fixed fallback text.
    """
    try:
        matches = db3.similarity_search(message.message, k=3)
        first_hit = matches[0]
        return {"message": first_hit.page_content}
    except Exception:
        return {"message": "no available document"}
async def text_to_speech(message: Message):
    """Synthesize the message text and save the audio to ``./audio.wav``.

    The text is tokenized, run through the module-level TTS ``model`` with
    gradients disabled, and the resulting waveform is written to disk.

    Returns:
        ``{"message": "File Saved"}`` once the file has been written.
    """
    inputs = tokenizer(message.message, return_tensors="pt")
    with torch.no_grad():  # inference only, no gradient tracking
        output = model(**inputs).waveform

    # IPython's Audio is used here only to turn the waveform into bytes;
    # presumably ``.data`` is WAV-encoded — confirm.
    audio = Audio(output, rate=model.config.sampling_rate)

    # NOTE(review): the path is fixed, so every call overwrites the same file.
    # The ``with`` block closes the file; no explicit close() is needed.
    with open('./audio.wav', 'wb') as f:
        f.write(audio.data)
    return {"message": "File Saved"}
async def send_audio():
    """Return ``./audio.wav`` as a file response with the ``audio/wav`` type."""
    wav_path = "./audio.wav"
    return FileResponse(wav_path, media_type="audio/wav")