import fastapi
import json
import markdown
import uvicorn
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from ctransformers import AutoModelForCausalLM
from ctransformers.langchain import CTransformers
from pydantic import BaseModel

# Load the quantized Gorilla-7B model (GGML, llama architecture) from the
# Hugging Face Hub. This downloads the weights on first run and blocks module
# import until the model is fully loaded into memory.
llm = AutoModelForCausalLM.from_pretrained("TheBloke/gorilla-7B-GGML",
                                           model_file="Gorilla-7B.ggmlv3.q4_0.bin",
                                           model_type="llama")
app = fastapi.FastAPI()
# Wide-open CORS so the demo page (and any browser client on any origin) can
# reach the SSE endpoints cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def index():
    """Serve a minimal HTML demo page for the model.

    The page opens an EventSource against the ``/stream`` endpoint and
    appends each server-sent chunk to the ``#content`` div.

    Returns:
        HTMLResponse: the demo page with status 200.
    """
    html_content = """
    <!DOCTYPE html>
    <html>
        <style>
            body {
                font-family: "Arial";
            }
            h1 {
                text-align: "center";
            }
        </style>
        <body>
            <h1>gorilla</h1>
            <input id="prompt" type="text">
            <button id="search">I'm feeling lucky</button>
            <div id="content"></div>
            <script>
                document.getElementById("search").addEventListener("click", () => {
                    let prompt = document.getElementById("prompt").value;
                    // FIX: use a relative URL so the page works on any deployment
                    // (the original hardcoded one Hugging Face Space), and
                    // URL-encode the prompt so spaces, '&' and '#' survive the
                    // query string instead of truncating it.
                    let source = new EventSource(`/stream?prompt=${encodeURIComponent(prompt)}`);
                    source.onmessage = function(event) {
                        console.log(event);
                        let eventData = event.data;
                        document.getElementById("content").innerHTML += eventData
                    };
                });
            </script>
        </body>
    </html>
    """
    return HTMLResponse(content=html_content, status_code=200)

@app.get("/stream")
async def chat(prompt = "I want to download a dataset from GCS"):
    """Stream a completion for ``prompt`` as server-sent events.

    The prompt text is emitted as the first event, followed by one event per
    generated token, and a final empty event to signal completion.
    """
    token_ids = llm.tokenize(prompt)

    async def event_stream(ids, model):
        # Echo the prompt back first so the client can display it.
        yield prompt
        # Detokenize and forward each token as soon as it is generated.
        for token in model.generate(ids):
            yield model.detokenize(token)
        # Trailing empty event marks end-of-stream.
        yield ""

    return EventSourceResponse(event_stream(token_ids, llm))


class ChatCompletionRequest(BaseModel):
    """Request body for the ``/v1/chat/completions`` endpoint."""

    # NOTE(review): despite the plural name, this is a single string — not a
    # list of role/content message objects as in the OpenAI chat schema.
    # Confirm with clients before changing the shape.
    messages: str

@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest, response_mode=None):
    """Stream a completion for the request's ``messages`` text as SSE.

    Args:
        request: body containing the prompt text in ``messages`` (a str).
        response_mode: accepted for API compatibility; currently unused.

    Returns:
        EventSourceResponse yielding one event per generated token, then an
        empty event to signal end-of-stream.
    """
    # BUG FIX: the original called ``request.messages.join(' ')``.  Because
    # ``messages`` is a str, str.join treats ``' '`` as the iterable, so the
    # whole expression evaluated to a single space and the user's prompt was
    # silently discarded.  Tokenize the message text directly instead.
    tokens = llm.tokenize(request.messages)

    async def server_sent_events(chat_chunks, llm):
        for chat_chunk in llm.generate(chat_chunks):
            yield llm.detokenize(chat_chunk)
        # Trailing empty event marks end-of-stream for the client.
        yield ""

    return EventSourceResponse(server_sent_events(tokens, llm))

if __name__ == "__main__":
    # Serve the API on all interfaces, port 8000 (PEP 8: 4-space indent).
    uvicorn.run(app, host="0.0.0.0", port=8000)