File size: 3,553 Bytes
4962437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os

import re
from multiprocessing import Process
from tempfile import NamedTemporaryFile
from typing import List, TypedDict

import uvicorn
from fastapi import FastAPI, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from api.olds.container import agent_manager, file_handler, reload_dirs, templates, uploader
from api.olds.worker import get_task_result, start_worker, task_execute
# from env import settings

# The ASGI application; every route below is registered on it via decorators.
app = FastAPI()

# Expose the uploader's storage directory (uploader.path) as static files
# under the /static prefix.
app.mount(path="/static", app=StaticFiles(directory=uploader.path), name="static")


class ExecuteRequest(BaseModel):
    # Request body shared by POST /api/execute and POST /api/execute/async.
    # Comments (not a docstring) are used so the generated JSON schema is unchanged.

    # Session identifier; passed to agent_manager.create_executor (sync
    # endpoint) or task_execute.delay (async endpoint).
    session: str
    # User prompt; appended after the text produced for the attached files.
    prompt: str
    # File references; each one is passed to file_handler.handle and the results
    # are newline-joined and placed before the prompt.
    # NOTE(review): presumably URLs/paths returned by /upload — confirm.
    files: List[str]


class ExecuteResponse(TypedDict):
    # Response shape of POST /api/execute (it is that route's return annotation).
    # Comments (not a docstring) are used so the generated schema is unchanged.

    # The executor's "output" text, or the exception message if the executor raised.
    answer: str
    # uploader.upload results for each "[file://...]" path referenced in the
    # output; empty when the executor raised.
    files: List[str]


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the ``index.html`` template as the landing page."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)


@app.get("/dashboard", response_class=HTMLResponse)
async def dashboard(request: Request):
    """Render the ``dashboard.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("dashboard.html", context)


@app.post("/upload")
async def create_upload_file(files: List[UploadFile]):
    """Store each uploaded file through ``uploader`` and return the URLs.

    Every upload is first written to a named temporary file that keeps the
    original extension, because ``uploader.upload`` is called with a
    filesystem path.

    Args:
        files: The multipart files sent by the client.

    Returns:
        ``{"urls": [...]}`` with one ``uploader.upload`` result per file, in
        request order.
    """
    urls = []
    for file in files:
        # The client-supplied filename is untrusted and may be missing.
        # splitext yields "" for names without an extension (instead of
        # turning "README" into ".README"). It also never returns a suffix
        # containing a path separator, so the temp-file path cannot be redirected.
        extension = os.path.splitext(file.filename or "")[1]
        with NamedTemporaryFile(suffix=extension) as tmp_file:
            tmp_file.write(await file.read())
            # Flush so the data is on disk before the uploader reopens it by name.
            tmp_file.flush()
            urls.append(uploader.upload(tmp_file.name))
    return {"urls": urls}


@app.post("/api/execute")
async def execute(request: ExecuteRequest) -> ExecuteResponse:
    """Run the session's agent executor on the prompt and return its answer.

    The text produced by ``file_handler.handle`` for each referenced file is
    prefixed to the prompt. Any ``[file://<path>]`` references in the agent's
    output are uploaded, and their URLs are returned alongside the answer.

    If the executor raises, the exception message is returned as the answer
    with no files, rather than failing the request.
    """
    query = request.prompt
    files = request.files
    session = request.session

    executor = agent_manager.create_executor(session)

    # NOTE(review): the handled file text is glued directly onto the query with
    # no separator. Confirm that file_handler.handle output ends with a delimiter.
    prompted_query = "\n".join([file_handler.handle(file) for file in files])
    prompted_query += query

    # NOTE(review): the executor is called synchronously inside this async
    # handler, so it blocks the event loop while it runs.
    try:
        res = executor({"input": prompted_query})
    except Exception as e:
        # Deliberate best-effort: surface the error text as the answer.
        return {"answer": str(e), "files": []}

    output = res["output"]
    # Capture the path inside each "[file://<path>]". The class stops at the
    # first "]" or whitespace, so adjacent references like
    # "[file://a][file://b]" yield two paths instead of one merged bogus path.
    paths = re.findall(r"\[file://([^\]\s]*)\]", output)

    return {
        "answer": output,
        "files": [uploader.upload(path) for path in paths],
    }


@app.post("/api/execute/async")
async def execute_async(request: ExecuteRequest):
    """Queue the prompt for background execution and return the task id.

    The text produced by ``file_handler.handle`` for each referenced file is
    newline-joined and prefixed to the prompt before it is handed to
    ``task_execute.delay``.
    """
    handled_files = [file_handler.handle(ref) for ref in request.files]
    prompted_query = "\n".join(handled_files) + request.prompt

    task = task_execute.delay(request.session, prompted_query)
    return {"id": task.id}


@app.get("/api/execute/async/{execution_id}")
async def execute_async(execution_id: str):
    """Report the status of a queued execution and, on success, its result.

    Args:
        execution_id: The id returned by POST /api/execute/async.

    Returns:
        A dict with ``status``, ``info`` and ``result`` keys. ``result`` is
        ``{}`` until the task status is ``"SUCCESS"`` with a non-empty result.
        After that it holds the answer text and the uploaded URLs of any
        ``[file://<path>]`` references found in it.
    """
    # NOTE(review): this def re-binds the module-level name ``execute_async``
    # (the POST handler above also uses it). Both routes are still registered,
    # because each decorator runs when its def executes, but only this function
    # stays reachable by that name. Consider renaming.
    execution = get_task_result(execution_id)

    result = {}
    if execution.status == "SUCCESS" and execution.result:
        output = execution.result.get("output", "")
        # Capture each "[file://<path>]" path, stopping at the first "]" so
        # adjacent references are not merged into one bogus path.
        paths = re.findall(r"\[file://([^\]\s]*)\]", output)
        result = {
            "answer": output,
            "files": [uploader.upload(path) for path in paths],
        }

    info = execution.info
    # Assumes Celery AsyncResult semantics (task_execute.delay suggests Celery):
    # for a failed task ``info`` holds the raised exception. The JSON encoder
    # would reduce that to an empty object, so send the message instead.
    if isinstance(info, BaseException):
        info = str(info)

    return {
        "status": execution.status,
        "info": info,
        "result": result,
    }


def _eval_port() -> int:
    """Return the server port from the ``EVAL_PORT`` environment variable.

    uvicorn's ``port`` is an int. Passing the raw environment string can fail
    when uvicorn binds the socket itself (e.g. with ``reload=True``).

    Raises:
        KeyError: if ``EVAL_PORT`` is not set.
        ValueError: if ``EVAL_PORT`` is not an integer.
    """
    return int(os.environ["EVAL_PORT"])


def _start_worker_process() -> Process:
    """Launch ``start_worker`` in a separate process and return its handle."""
    worker = Process(target=start_worker, args=[])
    worker.start()
    return worker


def serve():
    """Start the background worker, then run the API server without reload."""
    # Resolve the port first so a missing or bad EVAL_PORT fails before a
    # worker process has been spawned.
    port = _eval_port()
    _start_worker_process()
    uvicorn.run("api.main:app", host="0.0.0.0", port=port)


def dev():
    """Start the background worker, then run the API server with auto-reload.

    The server reloads on changes under ``reload_dirs``.
    """
    port = _eval_port()
    _start_worker_process()
    uvicorn.run(
        "api.main:app",
        host="0.0.0.0",
        port=port,
        reload=True,
        reload_dirs=reload_dirs,
    )