# (removed: file-viewer scrape artifacts — file-size header, git blame hashes,
#  and a line-number gutter that are not part of the source)
import fastapi
import json
import markdown
import uvicorn
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from ctransformers import AutoModelForCausalLM
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

# Load the quantized MPT-7B StoryWriter model (GGML q4_0 weights) via
# ctransformers. NOTE(review): this runs at import time, so server startup
# blocks until the weights are downloaded/loaded.
llm = AutoModelForCausalLM.from_pretrained('TheBloke/MPT-7B-Storywriter-GGML',
                                           model_file='mpt-7b-storywriter.ggmlv3.q4_0.bin',
                                           model_type='mpt')
app = fastapi.FastAPI()
# Fully open CORS so the browser demo (and any external page) can consume the
# SSE endpoints cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def index():
    """Serve the project README, rendered from Markdown to HTML."""
    with open("README.md", "r", encoding="utf-8") as fh:
        raw_markdown = fh.read()
    rendered = markdown.markdown(raw_markdown)
    return HTMLResponse(content=rendered, status_code=200)

class ChatCompletionRequest(BaseModel):
    """Request body for ``POST /v1/chat/completions``."""

    # Raw text prompt, passed verbatim to the language model.
    prompt: str


@app.get("/demo")
async def demo():
    """Serve a self-contained HTML page that consumes the /stream SSE feed.

    The inline <script> opens an EventSource against the deployed Space URL
    and appends every received event to the #content div.
    NOTE(review): the EventSource URL is hard-coded to a specific
    huggingface.co deployment; the demo will not reflect a locally running
    instance of this server.
    """
    html_content = """
    <!DOCTYPE html>
    <html>
        <head>
          <style>
            body {
              align-items: center;
              background-color: #d9b99b;
              display: flex;
              height: 100vh;
              justify-content: center;
            }
            #content {
              align-items: center;
              background-color: #fff0db;
              box-shadow: 
                12px 12px 16px 0 rgba(0, 0, 0, 0.25),
                -8px -8px 12px 0 rgba(255, 255, 255, 0.3);
              border-radius: 50px;
              display: flex;
              padding: 50px;
              justify-content: center;
              margin-right: 4rem;
              font-size: 16px;
            }
          </style>
        </head>
    
        <body>
            <div id="content"></div>
            
            <script>
              var source = new EventSource("https://matthoffner-storywriter.hf.space/stream");
              source.onmessage = function(event) {
                document.getElementById("content").innerHTML += event.data
              };
            </script>
        
        </body>
    </html>
    """
    return HTMLResponse(content=html_content, status_code=200)


@app.get("/v1")
async def flow(prompt = ""):
    """Stream a completion of *prompt* as server-sent events.

    The prompt itself is emitted first, then the completion one piece at a
    time, and finally an empty sentinel event.
    """
    # NOTE(review): llm(prompt) runs synchronously inside this async handler,
    # blocking the event loop for the full generation.
    generated = llm(prompt)

    async def event_stream(pieces):
        # Echo the prompt so the client can render it immediately.
        yield prompt
        for piece in pieces:
            yield piece
        # Trailing empty event marks end-of-stream.
        yield ""

    return EventSourceResponse(event_stream(generated))


@app.get("/stream")
async def chat(prompt = "Once upon a time there was a "):
    """Token-stream a story continuation of *prompt* as server-sent events."""
    prompt_tokens = llm.tokenize(prompt)

    async def event_stream(token_ids, model):
        # Send the prompt first so the client sees it right away.
        yield prompt
        # Generate ids one at a time, decoding each back to text.
        for token_id in model.generate(token_ids):
            yield model.detokenize(token_id)
        # Empty sentinel marks end-of-stream.
        yield ""

    return EventSourceResponse(event_stream(prompt_tokens, llm))

@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest, response_mode=None):
    """OpenAI-style chat-completions endpoint, streamed as server-sent events.

    Each chunk of the completion is JSON-encoded into an SSE ``data`` field,
    terminated by a literal ``[DONE]`` event (mirroring the OpenAI API).

    NOTE(review): this function shadows the ``chat`` handler defined for
    ``GET /stream``; FastAPI registers both routes before the name is
    rebound, so both endpoints work, but the duplicate name should be
    renamed. ``response_mode`` is currently unused and kept only for caller
    compatibility.
    """
    completion = llm(request.prompt)

    async def server_sent_events(chat_chunks):
        # llm() returns a plain string here, so iterating yields it one
        # character at a time — presumably intentional for a typewriter
        # effect; confirm before changing.
        for chat_chunk in chat_chunks:
            yield dict(data=json.dumps(chat_chunk))
        # OpenAI-compatible end-of-stream sentinel.
        yield dict(data="[DONE]")

    # Fix: pass the completion directly instead of the redundant
    # ``chunks = completion`` alias the original kept.
    return EventSourceResponse(server_sent_events(completion))


if __name__ == "__main__":
    # Launch the API on all interfaces, port 8000 (PEP 8 4-space indent).
    uvicorn.run(app, host="0.0.0.0", port=8000)