File size: 2,455 Bytes
6f179e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import hashlib
import os
import time
from fastapi import Body, FastAPI, File, Form, HTTPException, Response, UploadFile
from fastapi.responses import JSONResponse
import uvicorn

from src import ChatWorld
from src.Response import ChatResponse, FileResponse
from src.logging import logging_info
from src.user import UUID, Role, User
from src.utils import convertToUTF8

# ASGI application that the endpoints below are registered on.
app = FastAPI()
# Single ChatWorld instance shared by every request handler in this module.
chatWorld = ChatWorld()

# Directory that contains this source file; uploads are stored beneath it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Sub-directory of BASE_DIR that receives uploaded story files.
SAVE_DIR = "uploads"

# In-memory per-user state, keyed by the client-supplied uuid string that
# /uploadFile registers and /chatWithCharacter looks up.
user_info: dict[UUID, User] = {}


def _safe_path_component(value: str) -> str:
    """Return *value* with path separators and NUL bytes replaced by ``_``.

    Both the client-supplied ``uuid`` and the uploaded file's name become part
    of the on-disk file name, so they must not be able to introduce directory
    components (e.g. ``../``) or characters that ``open()`` rejects.
    """
    for bad in ("/", "\\", "\x00"):
        value = value.replace(bad, "_")
    return value


@app.post("/uploadFile", response_model=FileResponse)
def upload_file(uuid: str = Form(), file: UploadFile = File(...)):
    """Store an uploaded story file and register *uuid* as a fresh user.

    The raw bytes are written under ``BASE_DIR/SAVE_DIR`` as
    ``<time_ns>_<uuid>_<original filename>`` (client-controlled parts are
    sanitised), decoded with ``convertToUTF8`` and passed to
    ``chatWorld.setStory``. Any existing ``user_info`` entry for *uuid* is
    replaced with a new ``User()``.

    Args:
        uuid: Client identifier, sent as a form field.
        file: The uploaded story file.

    Returns:
        FileResponse carrying the stored file name, the role names that
        ``chatWorld.getRoleNameFromFile`` extracts from the story, and the
        MD5 hex digest of the raw upload.
    """
    save_dir = os.path.join(BASE_DIR, SAVE_DIR)
    # exist_ok avoids the check-then-create race between concurrent uploads.
    os.makedirs(save_dir, exist_ok=True)

    # Sanitise client-controlled parts so the upload cannot escape save_dir.
    safe_uuid = _safe_path_component(uuid)
    safe_filename = _safe_path_component(file.filename or "")
    file_name = f"{time.time_ns()}_{safe_uuid}_{safe_filename}"
    file_path = os.path.join(save_dir, file_name)

    file_content = file.file.read()

    with open(file_path, "wb") as f:
        f.write(file_content)

    file_content_utf8 = convertToUTF8(file_content)

    chatWorld.setStory(
        stories=file_content_utf8,
        metas={
            "uuid": uuid,
        },
    )

    # A new upload always starts the user over with empty state.
    user_info[uuid] = User()
    logging_info(f"user_info: {user_info}")

    role_name_list, role_name_dict = chatWorld.getRoleNameFromFile(file_content_utf8)
    return FileResponse(
        filename=file_name,
        role_name_list=role_name_list,
        role_name_dict=role_name_dict,
        # MD5 serves only as a content checksum here, not for security.
        md5=hashlib.md5(file_content, usedforsecurity=False).hexdigest(),
    )


@app.post("/chatWithCharacter", response_model=ChatResponse)
def chatWithCharacter(
    uuid: str = Body(...),
    text: str = Body(...),
    use_local_model: bool = Body(False),
    top_k: int = Body(5),
    role_info: Role = Body(...),
):
    """Chat with a story character on behalf of a registered user.

    Args:
        uuid: Identifier the client previously registered via ``/uploadFile``.
        text: The user's message.
        use_local_model: Forwarded to ``chatWorld.chatWithCharacter``.
        top_k: Forwarded to ``chatWorld.chatWithCharacter``.
        role_info: Role fields; merged into the stored ``User`` and also
            expanded as keyword arguments to ``chatWorld.chatWithCharacter``.

    Returns:
        ChatResponse wrapping whatever ``chatWorld.chatWithCharacter`` returns.

    Raises:
        HTTPException: 400 if *uuid* has no entry in ``user_info``.
    """
    user = user_info.get(uuid)

    # Test for absence explicitly: a User that defines __len__/__bool__
    # (e.g. a dict-like object) could be falsy while freshly created and
    # would otherwise be reported as "not found".
    if user is None:
        raise HTTPException(status_code=400, detail="User not found")

    # Accept both a fluent update() returning the (possibly new) user and an
    # in-place update() returning None, so the registry never stores None.
    # NOTE(review): confirm User.update's return contract in src.user.
    updated = user.update(role_info.model_dump())
    user_info[uuid] = user if updated is None else updated
    logging_info(f"user_info: {user_info}")

    response = chatWorld.chatWithCharacter(
        text=text,
        use_local_model=use_local_model,
        top_k=top_k,
        **role_info.model_dump(),
        metas={"uuid": uuid},
    )

    return ChatResponse(response=response)


# @app.post("/chatWithoutCharacter")
# def chatWithoutCharacter(
#     uuid: str = Body(...),
#     text: str = Body(...),
#     use_local_model: bool = Body(...),
# ):
#     pass


if __name__ == "__main__":
    # Pass the already-constructed app object instead of the "main:app"
    # import string. The string form makes uvicorn import this file a second
    # time under the module name "main", which re-runs module setup such as
    # ChatWorld() and creates a second app. It also fails if the file is not
    # named main.py.
    uvicorn.run(app, host="0.0.0.0", port=8000)