File size: 4,263 Bytes
ad1da67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196244e
ad1da67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136

from controller.Rag import RAG
from controller.controller import *
from entity.entity import *
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import CharacterTextSplitter
import os
import wikipediaapi
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from transformers import VitsModel, AutoTokenizer
import torch
from IPython.display import Audio

from fastapi.responses import FileResponse
import warnings

warnings.filterwarnings(action="ignore")


# FastAPI application serving the RAG chatbot, Wikipedia lookup and TTS endpoints.
app = FastAPI()

# Permissive CORS: any origin/method/header is accepted.
# NOTE(review): fine for development; tighten allow_origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Cache sentence-transformer weights locally instead of the user home directory.
os.environ['SENTENCE_TRANSFORMERS_HOME'] = './.cache'

# Embedding model used both to index documents and to embed queries.
embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L12-v2")
persist_directory = "chroma"
# Persistent Chroma vector store backing retrieval for the whole app.
db3 = Chroma(persist_directory=persist_directory, embedding_function=embeddings_model)

retriever = db3.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={
        "k": 5, # number of passages to retrieve
        "score_threshold": 0.01 # minimum similarity score to keep a passage
    }
)

# French text-to-speech model (Meta MMS-TTS) and its tokenizer, loaded once at startup.
model = VitsModel.from_pretrained("facebook/mms-tts-fra")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-fra")


@app.get("/")
async def read_root():
    """Health-check / landing endpoint greeting API callers."""
    greeting = "Bienvenu sur notre API AI."
    return {"Message": greeting}


@app.put("/add/file")
async def update_db(file: UploadFile = File(...)):
    """Ingest an uploaded .pdf or .txt file into the Chroma vector store.

    The upload is written to a temporary local file, loaded with the matching
    LangChain loader, split into overlapping chunks, embedded + persisted,
    and the temporary file is always removed afterwards.

    Returns a status dict; errors are reported as a message rather than raised
    (matching the original best-effort contract).
    """
    root = f"./{file.filename}"
    try:
        contents = await file.read()
        with open(root, 'wb') as f:
            f.write(contents)
        # Lowercase so .PDF/.TXT uploads are accepted too.
        ext = file.filename.split(".")[-1].lower()
        if ext == "pdf":
            loader = PyPDFLoader(root)
        elif ext == "txt":
            loader = TextLoader(root)
        else:
            return {"error": "please send pdf or txt type files."}
        raw_documents = loader.load()
        text_splitter = CharacterTextSplitter(
            separator="\n\n", chunk_size=400, chunk_overlap=100, length_function=len,
        )
        documents = text_splitter.split_documents(raw_documents)
        # BUG FIX: split_documents already returns a list of Documents; the
        # original passed [documents], handing Chroma a single nested list
        # instead of Document objects.
        db3.add_documents(documents)
        db3.persist()
        return {"message": "base mise à jour avec le nouveau document"}
    except Exception:
        return {"message": "Erreur d'ajout du fichier dans la base."}
    finally:
        # Always clean up the temp file — the original leaked it on the
        # unsupported-extension path and on any exception.
        if os.path.exists(root):
            os.remove(root)


@app.post("/wiki")
async def key_words(message: Message):
    """Look up the message text as a French Wikipedia page title.

    The message is stripped of punctuation and used directly as the page
    title; returns "title + summary" when the page exists, otherwise a
    French not-found message. (Removed a stale block of commented-out
    embedding-similarity keyword matching.)
    """
    wiki_wiki = wikipediaapi.Wikipedia('Psychologie-API (test@example.com)', 'fr')
    query = remove_punctuation(message.message)
    page_py = wiki_wiki.page(query)
    if page_py.exists():
        res = page_py.title + "\n " + page_py.summary
    else:
        res = "Wikipédia n'a rien trouvé d'intéressant qui corresponde à cette page."
    return {"message": res}

@app.post("/message")
async def message_to_chatbot(message: Message):
    """Answer a user message through the RAG pipeline backed by Mixtral.

    NOTE: a fresh RAG instance is built per request, as in the original.
    """
    llm_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    chatbot = RAG(retriever, llm_name)
    answer = chatbot.generate_response(message.message)
    return {"message": answer}


@app.post("/match")
async def get_doc(message: Message):
    """Return the content of the stored document closest to the query.

    Best-effort: any search failure (including an empty store) yields a
    "no available document" message instead of an error.
    """
    fallback = {"message": "no available document"}
    try:
        top_docs = db3.similarity_search(message.message, k=3)
        best_content = top_docs[0].page_content
    except Exception:
        return fallback
    return {"message": best_content}
    


@app.post("/text_to_speech")
async def text_to_speech(message: Message):
    """Synthesize French speech for the message and save it to ./audio.wav.

    The waveform is produced by the MMS-TTS model and WAV-encoded via
    IPython's Audio helper; fetch the result with GET /send/audio.
    (Removed a redundant f.close() after the `with` block and a stale
    commented-out read-back of the file.)
    """
    inputs = tokenizer(message.message, return_tensors="pt")
    with torch.no_grad():  # inference only — no gradients needed
        output = model(**inputs).waveform
    # Audio(...).data yields the WAV-encoded bytes for the raw waveform.
    audio = Audio(output, rate=model.config.sampling_rate)
    with open('./audio.wav', 'wb') as f:
        f.write(audio.data)
    return {"message": "File Saved"}


@app.get("/send/audio")
async def send_audio():
    """Serve the last audio file synthesized by /text_to_speech.

    Robustness: if no audio has been generated yet, return an explanatory
    message instead of letting FileResponse fail at response time on the
    missing file.
    """
    path = "./audio.wav"
    if not os.path.exists(path):
        return {"message": "no audio file available; call /text_to_speech first"}
    return FileResponse(path, media_type="audio/wav")