File size: 3,828 Bytes
5e8fd8b
 
73f27dc
608245a
5cd5ef2
4929aba
 
5e8fd8b
 
ab4a5ae
5e8fd8b
 
 
 
 
 
 
 
 
 
 
 
 
9f0a9ca
5e8fd8b
ab4a5ae
4929aba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82c3144
 
d04ea2c
fd6a7b2
 
d04ea2c
9e59b9c
fd6a7b2
 
 
 
 
 
 
 
4929aba
 
 
 
608245a
 
 
 
ce54787
608245a
ce54787
608245a
 
 
ce54787
608245a
 
 
3f857b9
608245a
 
 
 
5e8fd8b
 
 
4c88907
1ff6584
453bbe9
1ff6584
 
9f0a9ca
24992a3
9f0a9ca
2cfa0e6
1ff6584
 
 
453bbe9
 
5e8fd8b
 
 
 
4c88907
 
 
 
 
 
5e8fd8b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import asyncio
import json
import os
import shutil
import tempfile
from datetime import datetime
from typing import List

from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from .rag import ChatPDF

# One CORS layer that accepts requests from any origin, with any HTTP method
# and any request header.
middleware = [
    Middleware(
        CORSMiddleware,
        allow_headers=["*"],
        allow_methods=["*"],
        allow_origins=["*"],
    ),
]
app = FastAPI(middleware=middleware)

# Directory (under the current user's home) used when staging uploaded files.
files_dir = os.path.expanduser("~/wtp_be_files/")
# Single assistant instance shared by the chat, query, upload and clear
# handlers in this module.
session_assistant = ChatPDF()

class ConnectionManager:
    """Keep track of open WebSocket connections and send text to them."""

    def __init__(self):
        # Sockets accepted by connect() that have not been disconnected yet.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking *websocket*."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*.

        A socket that is not currently tracked is ignored, so cleanup paths
        may call this more than once for the same connection without raising.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send *message* to the single socket *websocket*."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send *message* to every tracked socket.

        Iterates over a snapshot of the list: every ``await`` yields to the
        event loop, where another handler may call disconnect() and mutate
        ``active_connections``, which would otherwise skip sockets.
        """
        for connection in list(self.active_connections):
            await connection.send_text(message)

manager = ConnectionManager()

@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Chat with the ingested PDFs over a WebSocket connection.

    Every non-blank text frame from the client is treated as a question.
    Replies are JSON objects with keys ``time``, ``clientId`` and ``message``,
    sent to this socket only:

    * a hint message when ``session_assistant.pdf_count`` is not positive;
    * otherwise one message per chunk of the answer streamed by
      ``session_assistant.ask``.

    When the client disconnects, an "Offline" message is broadcast to the
    connections that remain.
    """
    end_of_stream = object()  # sentinel returned by next() when the answer ends

    def payload(stamp, text):
        # Key order matches the messages the client already receives.
        return json.dumps({"time": stamp, "clientId": client_id, "message": text})

    await manager.connect(websocket)
    went_offline = False
    try:
        while True:
            data = (await websocket.receive_text()).strip()
            if not data:
                continue
            # Stamp with the time this question arrived rather than once per
            # connection, so long-lived sockets report current times.
            stamp = datetime.now().strftime("%H:%M")
            if not session_assistant.pdf_count > 0:
                await manager.send_personal_message(
                    payload(stamp, "Please, add a PDF document first."), websocket
                )
                continue
            print("FETCHING STREAM")
            # ask() and its token generator are synchronous and may block for
            # a long time; run them in worker threads so the event loop keeps
            # serving the other connections meanwhile.
            streaming_response = await asyncio.to_thread(session_assistant.ask, data)
            tokens = iter(streaming_response.response_gen)
            print("STARTING STREAM")
            while (text := await asyncio.to_thread(next, tokens, end_of_stream)) is not end_of_stream:
                await manager.send_personal_message(payload(stamp, text), websocket)
            print("ENDING STREAM")
    except WebSocketDisconnect:
        went_offline = True
    finally:
        # Stop tracking this socket on every exit path (including errors from
        # ask() or a failed send), so later broadcasts skip the dead connection.
        manager.disconnect(websocket)
    if went_offline:
        await manager.broadcast(payload(datetime.now().strftime("%H:%M"), "Offline"))


async def astreamer(generator, delay: float = .1):
    """Re-yield the items of a synchronous iterable from an async generator.

    Each ``next()`` on *generator* runs in a worker thread, so a slow producer
    does not block the event loop. After every item the generator pauses for
    *delay* seconds (0.1 s by default).

    If the consuming task is cancelled, ``asyncio.CancelledError`` is
    re-raised rather than swallowed. Swallowing it would make the task
    continue as if nothing happened.

    Args:
        generator: any iterable of chunks to forward.
        delay: seconds to sleep after each yielded item.
    """
    end = object()  # sentinel returned by next() once the iterable is exhausted
    items = iter(generator)
    print("streaming........")
    try:
        while (item := await asyncio.to_thread(next, items, end)) is not end:
            print(item)
            yield item
            await asyncio.sleep(delay)
    except asyncio.CancelledError:
        print("cancelled")
        raise


@app.get("/query")
async def process_input(text: str):
    """Stream the assistant's answer to *text* as ``text/event-stream``.

    Leading and trailing whitespace is stripped before the question is asked.

    Raises:
        HTTPException: 400 when *text* is empty or only whitespace. Rejecting
            it explicitly avoids an empty 200 response with a ``null`` body.
    """
    text = text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Query text must not be empty.")
    # ask() is synchronous; run it in a worker thread so it does not block
    # the event loop while the answer is being prepared.
    streaming_response = await asyncio.to_thread(session_assistant.ask, text)
    return StreamingResponse(astreamer(streaming_response.response_gen), media_type='text/event-stream')


@app.post("/upload")
def upload(files: list[UploadFile]):
    """Write the uploaded files to disk and ingest them into the assistant.

    Each request stages its files in a fresh temporary directory under
    ``files_dir``. Concurrent uploads therefore cannot collide, and a failed
    write or ingest can no longer leave behind a directory that breaks every
    later upload. The staging directory is removed on every exit path.
    Ingestion runs only after all files were written successfully; any error
    propagates to the caller.

    Returns:
        The confirmation string "Files inserted!".
    """
    os.makedirs(files_dir, exist_ok=True)
    with tempfile.TemporaryDirectory(dir=files_dir) as staging_dir:
        for file in files:
            try:
                # The filename is client-supplied: keep only its final path
                # component so "../" sequences cannot escape staging_dir.
                name = os.path.basename(file.filename or "")
                if name in ("", ".", ".."):
                    continue
                file.file.seek(0)
                with open(os.path.join(staging_dir, name), 'wb') as destination:
                    shutil.copyfileobj(file.file, destination)
            finally:
                file.file.close()
        # Keep the trailing separator, matching the form of files_dir that
        # ingest() received before.
        session_assistant.ingest(os.path.join(staging_dir, ""))

    return "Files inserted!"


@app.get("/clear")
def clear_files():
    """Remove every ingested document from the shared assistant.

    Previously this handler was also named ``ping``, so the root handler
    defined later silently rebound that module-level name. The route still
    works because FastAPI registers it at decoration time, but the name
    collision was misleading.

    Returns:
        The confirmation string "All files have been cleared.".
    """
    session_assistant.clear()
    return "All files have been cleared."


@app.get("/")
def ping():
    """Health-check endpoint that always answers with the same fixed string."""
    reply = "Pong!"
    return reply