matthoffner committed on
Commit
3c54391
1 Parent(s): 9f2e161

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +12 -9
main.py CHANGED
@@ -5,11 +5,13 @@ import uvicorn
5
  from fastapi.responses import HTMLResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from sse_starlette.sse import EventSourceResponse
8
- from ctransformers.langchain import CTransformers
9
  from pydantic import BaseModel
10
 
11
- llm = CTransformers(model='ggml-model-q4_0.bin', model_type='starcoder')
12
- app = fastapi.FastAPI(title="Santacoder")
 
 
13
  app.add_middleware(
14
  CORSMiddleware,
15
  allow_origins=["*"],
@@ -30,13 +32,14 @@ class ChatCompletionRequest(BaseModel):
30
 
31
  @app.post("/v1/chat/completions")
32
  async def chat(request: ChatCompletionRequest, response_mode=None):
33
- completion = llm(request.prompt)
34
- async def server_sent_events(chat_chunks):
35
- for chat_chunk in chat_chunks:
36
- yield dict(data=json.dumps(chat_chunk))
37
- yield dict(data="[DONE]")
 
38
 
39
- return EventSourceResponse(server_sent_events(completion))
40
 
41
  if __name__ == "__main__":
42
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
5
  from fastapi.responses import HTMLResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from sse_starlette.sse import EventSourceResponse
8
+ from ctransformers import AutoModelForCausalLM
9
  from pydantic import BaseModel
10
 
11
+ llm = AutoModelForCausalLM.from_pretrained("TheBloke/WizardCoder-15B-1.0-GGML",
12
+ model_file="WizardCoder-15B-1.0.ggmlv3.q4_0.bin",
13
+ model_type="starcoder")
14
+ app = fastapi.FastAPI(title="WizardCoder")
15
  app.add_middleware(
16
  CORSMiddleware,
17
  allow_origins=["*"],
 
32
 
33
  @app.post("/v1/chat/completions")
34
  async def chat(request: ChatCompletionRequest, response_mode=None):
35
+ tokens = llm.tokenize(prompt)
36
+ async def server_sent_events(chat_chunks, llm):
37
+ yield prompt
38
+ for chat_chunk in llm.generate(chat_chunks):
39
+ yield llm.detokenize(chat_chunk)
40
+ yield ""
41
 
42
+ return EventSourceResponse(server_sent_events(tokens, llm))
43
 
44
  if __name__ == "__main__":
45
  uvicorn.run(app, host="0.0.0.0", port=8000)