|
import hashlib |
|
import os |
|
import time |
|
from fastapi import Body, FastAPI, File, Form, HTTPException, Response, UploadFile |
|
from fastapi.responses import JSONResponse |
|
import uvicorn |
|
|
|
from src import ChatWorld |
|
from src.Response import ChatResponse, FileResponse |
|
from src.logging import logging_info |
|
from src.user import UUID, Role, User |
|
from src.utils import convertToUTF8 |
|
|
|
# ASGI application object; the route decorators in this module register on it.
app = FastAPI()

# Single ChatWorld instance shared by every request handler in this module.
chatWorld = ChatWorld()

# Directory containing this source file; uploads are stored relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Name of the sub-directory (under BASE_DIR) that receives uploaded files.
SAVE_DIR = "uploads"

# In-memory registry of users, keyed by the uuid supplied by the client.
user_info: dict[UUID, User] = {}
|
|
|
|
|
@app.post("/uploadFile", response_model=FileResponse)
def upload_file(uuid: str = Form(), file: UploadFile = File(...)):
    """Store an uploaded story file and register it with the chat world.

    The raw bytes are written under ``BASE_DIR/SAVE_DIR``, converted with
    ``convertToUTF8`` and passed to ``chatWorld.setStory`` tagged with the
    caller's ``uuid``.  A fresh ``User`` is then stored in ``user_info`` for
    that ``uuid``, replacing any entry that was already there.

    Args:
        uuid: Client identifier, sent as a form field.
        file: The uploaded file.

    Returns:
        A ``FileResponse`` holding the stored file name, the role names
        returned by ``chatWorld.getRoleNameFromFile`` and the MD5 hex digest
        of the uploaded bytes.
    """
    upload_dir = os.path.join(BASE_DIR, SAVE_DIR)
    # exist_ok=True removes the check-then-create race of a separate
    # os.path.exists() test when two uploads arrive at the same time.
    os.makedirs(upload_dir, exist_ok=True)

    # Both values are client-controlled.  Keep only the final path component
    # (treating "\" as a separator too) so neither can point outside
    # upload_dir via "../" or an absolute path.
    safe_uuid = os.path.basename(uuid.replace("\\", "/"))
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))

    # The nanosecond timestamp prefix keeps repeated uploads of the same file
    # from overwriting each other.
    file_name = f"{time.time_ns()}_{safe_uuid}_{safe_name}"
    file_path = os.path.join(upload_dir, file_name)

    file_content = file.file.read()

    with open(file_path, "wb") as f:
        f.write(file_content)

    file_content_utf8 = convertToUTF8(file_content)

    chatWorld.setStory(
        stories=file_content_utf8,
        metas={
            "uuid": uuid,
        },
    )

    # Every upload (re)creates the user entry for this uuid.
    user_info[uuid] = User()
    # Debug trace of the current registry contents.
    print(user_info)

    role_name_list, role_name_dict = chatWorld.getRoleNameFromFile(file_content_utf8)
    return FileResponse(
        filename=file_name,
        role_name_list=role_name_list,
        role_name_dict=role_name_dict,
        # MD5 is used here only as a content checksum of the uploaded bytes.
        md5=hashlib.md5(file_content).hexdigest(),
    )
|
|
|
|
|
@app.post("/chatWithCharacter", response_model=ChatResponse)
def chatWithCharacter(
    uuid: str = Body(...),
    text: str = Body(...),
    use_local_model: bool = Body(False),
    top_k: int = Body(5),
    role_info: Role = Body(...),
):
    """Send one chat message to a character on behalf of a registered user.

    Args:
        uuid: Client identifier; must already be present in ``user_info``.
        text: The message to send.
        use_local_model: Forwarded unchanged to ``chatWorld.chatWithCharacter``.
        top_k: Forwarded unchanged to ``chatWorld.chatWithCharacter``.
        role_info: Role description; its dumped fields update the stored user
            and are forwarded as keyword arguments to the chat call.

    Returns:
        A ``ChatResponse`` wrapping the value returned by
        ``chatWorld.chatWithCharacter``.

    Raises:
        HTTPException: Status 400 when no truthy entry exists for ``uuid``.
    """
    known_user = user_info.get(uuid)
    if not known_user:
        raise HTTPException(status_code=400, detail="User not found")

    # NOTE(review): the registry entry is replaced by whatever ``update``
    # returns -- confirm that it returns the user object rather than None.
    updated_user = known_user.update(role_info.model_dump())
    user_info[uuid] = updated_user
    logging_info(f"user_info: {user_info}")

    # Dump the role again so the chat call receives its own fresh dict.
    role_fields = role_info.model_dump()
    reply = chatWorld.chatWithCharacter(
        text=text,
        use_local_model=use_local_model,
        top_k=top_k,
        **role_fields,
        metas={"uuid": uuid},
    )
    return ChatResponse(response=reply)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the ``app`` object of module ``main`` with
    # uvicorn, listening on every interface on port 8000.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
    )
|
|