matthoffner committed
Commit 1044c29
1 Parent(s): 5d8b6f6

Update main.py

Files changed (1):
  1. main.py +12 -8
main.py CHANGED
@@ -2,14 +2,19 @@ import fastapi
 import json
 import markdown
 import uvicorn
+from ctransformers import AutoModelForCausalLM
 from fastapi.responses import HTMLResponse
 from fastapi.middleware.cors import CORSMiddleware
 from sse_starlette.sse import EventSourceResponse
 from ctransformers.langchain import CTransformers
 from pydantic import BaseModel
+from typing import List, Any
 
-llm = CTransformers(model='ggml-model-q4_1.bin', model_type='starcoder')
-app = fastapi.FastAPI()
+llm = AutoModelForCausalLM.from_pretrained("starchat-alpha-GGML",
+                                           model_file="starchat-alpha-ggml-q4_0.bin",
+                                           model_type="starcoder")
+
+app = fastapi.FastAPI(title="Starchat Alpha")
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -26,17 +31,16 @@ async def index():
     return HTMLResponse(content=html_content, status_code=200)
 
 class ChatCompletionRequest(BaseModel):
-    prompt: str
+    messages: List[Any]
 
 @app.post("/v1/chat/completions")
 async def chat(request: ChatCompletionRequest, response_mode=None):
-    completion = llm(request.prompt)
+    tokens = llm.tokenize(request.messages)
     async def server_sent_events(chat_chunks):
-        for chat_chunk in chat_chunks:
-            yield dict(data=json.dumps(chat_chunk))
-        yield dict(data="[DONE]")
+        for token in llm.generate(chat_chunks):
+            yield llm.detokenize(token)
 
-    return EventSourceResponse(server_sent_events(completion))
+    return EventSourceResponse(server_sent_events(tokens))
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8000)
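
For context, the new handler drops the LangChain CTransformers wrapper and uses the ctransformers low-level streaming API directly. A minimal standalone sketch of that tokenize/generate/detokenize loop (model name and file taken from the diff; note that tokenize() takes a plain string, so this sketch passes a single prompt rather than the request's messages list):

from ctransformers import AutoModelForCausalLM

# Load the quantized StarChat model, mirroring the diff.
llm = AutoModelForCausalLM.from_pretrained(
    "starchat-alpha-GGML",
    model_file="starchat-alpha-ggml-q4_0.bin",
    model_type="starcoder",
)

prompt = "def fibonacci(n):"
tokens = llm.tokenize(prompt)       # str -> list of token ids
for token in llm.generate(tokens):  # yields one token id at a time
    print(llm.detokenize(token), end="", flush=True)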
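
A hypothetical client for the updated endpoint, not part of the commit; it assumes the server above is running locally on port 8000 and that httpx is installed. The payload shape mirrors the ChatCompletionRequest model from the diff:

import httpx

payload = {"messages": ["def fibonacci(n):"]}  # validated by ChatCompletionRequest

with httpx.stream("POST", "http://localhost:8000/v1/chat/completions",
                  json=payload, timeout=None) as response:
    for line in response.iter_lines():
        # sse-starlette frames each yielded chunk as a "data: ..." line
        if line.startswith("data:"):
            print(line[len("data:"):].strip(), end="", flush=True)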