import json

import fastapi
import markdown
import uvicorn
from ctransformers import AutoModelForCausalLM
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
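
# Load the 4-bit quantized GGML build of MPT-7B-StoryWriter from the Hugging
# Face Hub. ctransformers runs GGML models on CPU, so no GPU is required.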
llm = AutoModelForCausalLM.from_pretrained(
    'TheBloke/MPT-7B-Storywriter-GGML',
    model_file='mpt-7b-storywriter.ggmlv3.q4_0.bin',
    model_type='mpt',
)
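
# Open CORS policy so the demo page (or any external client) can call the
# SSE endpoints directly from the browser.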
app = fastapi.FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
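
# Landing page: render the Space's README as HTML.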
@app.get("/")
async def index():
with open("README.md", "r", encoding="utf-8") as readme_file:
md_template_string = readme_file.read()
html_content = markdown.markdown(md_template_string)
return HTMLResponse(content=html_content, status_code=200)
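
# Request body for the OpenAI-style /v1/chat/completions endpoint below.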
class ChatCompletionRequest(BaseModel):
    prompt: str
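
# Browser demo: opens an EventSource against this Space's public /stream
# endpoint and appends each streamed chunk to the page as it arrives.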
@app.get("/demo")
async def demo():
html_content = """
<!DOCTYPE html>
<html>
<head>
<style>
body {
align-items: center;
background-color: #d9b99b;
display: flex;
height: 100vh;
justify-content: center;
}
#content {
align-items: center;
background-color: #fff0db;
box-shadow:
12px 12px 16px 0 rgba(0, 0, 0, 0.25),
-8px -8px 12px 0 rgba(255, 255, 255, 0.3);
border-radius: 50px;
display: flex;
padding: 50px;
justify-content: center;
margin-right: 4rem;
font-size: 16px;
}
</style>
</head>
<body>
<div id="content"></div>
<script>
var source = new EventSource("https://matthoffner-storywriter.hf.space/stream");
source.onmessage = function(event) {
document.getElementById("content").innerHTML += event.data
};
</script>
</body>
</html>
"""
return HTMLResponse(content=html_content, status_code=200)
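
# Basic GET completion endpoint: llm(prompt) blocks until the full completion
# is generated, then the result is replayed chunk by chunk over SSE.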
@app.get("/v1")
async def flow(prompt = ""):
completion = llm(prompt)
async def server_sent_events(chat_chunks):
yield prompt
for chat_chunk in chat_chunks:
yield chat_chunk
yield ""
return EventSourceResponse(server_sent_events(completion))
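
# Token-level streaming: tokenize the prompt, let llm.generate() yield token
# ids one at a time, and detokenize each id before emitting it as an SSE event.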
@app.get("/stream")
async def chat(prompt = "Once upon a time there was a "):
tokens = llm.tokenize(prompt)
async def server_sent_events(chat_chunks, llm):
yield prompt
for chat_chunk in llm.generate(chat_chunks):
yield llm.detokenize(chat_chunk)
yield ""
return EventSourceResponse(server_sent_events(tokens, llm))
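
# OpenAI-flavored POST endpoint: streams JSON-encoded chunks followed by a
# final "[DONE]" sentinel, mirroring the OpenAI SSE wire format.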
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    completion = llm(request.prompt)

    async def server_sent_events(chat_chunks):
        for chat_chunk in chat_chunks:
            yield dict(data=json.dumps(chat_chunk))
        yield dict(data="[DONE]")

    return EventSourceResponse(server_sent_events(completion))

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
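
# Example usage against a local run (assumed host/port defaults from above);
# curl's -N flag disables buffering so SSE chunks print as they arrive:
#   curl -N "http://localhost:8000/stream?prompt=Once%20upon%20a%20time"
#   curl -N -X POST http://localhost:8000/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Once upon a time"}'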