Spaces:
Sleeping
Sleeping
from fastapi import BackgroundTasks, APIRouter, Request, UploadFile, Depends, File | |
from typing import List | |
from src.helpers.response import ResponseHandler | |
from src.helpers.json_convertor import convert_to_json | |
from src.middleware.verifyToken import verify_token | |
from src.services.message_services import save_question_and_answer_to_chat_history | |
from src.bot.main import Main | |
from src.bot.OtherFun import delete_chain_after_delay, process_file | |
router = APIRouter() | |
model = Main() | |
async def askQ(req: Request, token: str = Depends(verify_token)): | |
try: | |
form_data = await req.json() | |
answer = model.ask_question( | |
form_data["query"], | |
( | |
token["username"] | |
if form_data["collectionName"] is None | |
else form_data["collectionName"] | |
), | |
) | |
await save_question_and_answer_to_chat_history( | |
token["username"], | |
{ | |
"question": form_data["query"], | |
"answer": answer, | |
"collectionName": ( | |
form_data["query"] if form_data["query"] else "CHAT WITH YOUR PDF" | |
), | |
}, | |
) | |
return ResponseHandler.success(2000, answer) | |
except Exception as error: | |
return ResponseHandler.error(9999, error, 500) | |
async def createEmbedding( | |
collection_name: str, | |
files: List[UploadFile] = File(None), | |
tokenData: str = Depends(verify_token), | |
): | |
try: | |
if not files: | |
return ResponseHandler.error(2003) | |
responses = [] | |
for file in files: | |
response = process_file(model, collection_name, file) | |
responses.append(response) | |
return ResponseHandler.success(2002, response) | |
except Exception as error: | |
return ResponseHandler.error(9999, error, 500) | |
async def createTmpChain( | |
background_tasks: BackgroundTasks, | |
files: List[UploadFile] = File(...), | |
tokenData: str = Depends(verify_token), | |
): | |
try: | |
if not files: | |
return ResponseHandler.error(2003, error, 500) | |
all_contents = b"" | |
for file in files: | |
contents = await file.read() | |
all_contents += contents | |
file_extension = files[0].filename.split(".")[-1] | |
if file_extension == "pdf": | |
chain_name = tokenData["username"] | |
model.generate_tmp_embedding_and_chain(all_contents, chain_name) | |
background_tasks.add_task(delete_chain_after_delay(model, chain_name)) | |
return ResponseHandler.success(2001) | |
elif file_extension == "txt": | |
all_contents.decode("utf-8") | |
return ResponseHandler.error(2004, error, 500) | |
except Exception as error: | |
return ResponseHandler.error(9999, error, 500) | |