matthoffner committed on
Commit 2b6fd3b
1 Parent(s): 2d555b9

Update main.py

Files changed (1)
  1. main.py +40 -6
main.py CHANGED
@@ -1,15 +1,18 @@
-import fastapi
 import json
 import markdown
+from typing import List, Dict, Any, Generator
+from functools import partial
+
+import fastapi
 import uvicorn
-from fastapi import HTTPException
-from fastapi.responses import HTMLResponse
+from fastapi import HTTPException, Depends, Request
+from fastapi.responses import HTMLResponse, StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 from sse_starlette.sse import EventSourceResponse
-from starlette.responses import StreamingResponse
-from ctransformers import AutoModelForCausalLM
+from anyio import create_memory_object_stream, run_in_threadpool
+from transformers import AutoModelForCausalLM
 from pydantic import BaseModel
-from typing import List, Dict, Any, Generator
+
 
 
 llm = AutoModelForCausalLM.from_pretrained("TheBloke/WizardCoder-15B-1.0-GGML",
@@ -88,6 +91,37 @@ async def chat(request: ChatCompletionRequest):
 
     return StreamingResponse(format_response(chat_chunks), media_type="text/event-stream")
 
+@app.post("/v2/chat/completions")
+async def chatV2(request: Request, body: ChatCompletionRequest):
+    combined_messages = ' '.join([message.content for message in body.messages])
+    tokens = llm.tokenize(combined_messages)
+
+    send_chan, recv_chan = create_memory_object_stream(10)
+
+    async def event_publisher(inner_send_chan):
+        async with inner_send_chan:
+            try:
+                iterator: Generator = await run_in_threadpool(llm.generate, tokens)
+                for chat_chunk in iterator:
+                    response = {
+                        'choices': [
+                            {
+                                'message': {
+                                    'role': 'system',
+                                    'content': llm.detokenize(chat_chunk)
+                                },
+                                'finish_reason': 'stop' if llm.detokenize(chat_chunk) == "[DONE]" else 'unknown'
+                            }
+                        ]
+                    }
+                    await inner_send_chan.send(f"data: {json.dumps(response)}\n\n")
+                await inner_send_chan.send("event: done\ndata: {}\n\n")
+            except Exception as e:
+                print(f"Exception in event publisher: {str(e)}")
+
+    return StreamingResponse(recv_chan, media_type="text/event-stream", data_sender_callable=partial(event_publisher, send_chan))
+
+
 @app.post("/v0/chat/completions")
 async def chat(request: ChatCompletionRequestV0, response_mode=None):
     tokens = llm.tokenize(request.prompt)
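
For reference, the new /v2/chat/completions endpoint follows a stream-and-publish pattern: the handler creates an anyio memory object stream, a nested event_publisher coroutine runs the blocking llm.generate in a worker thread and pushes SSE-formatted chunks into the send channel, and the response drains the receive channel. The following is a minimal standalone sketch of that wiring, not the committed code: the starlette.concurrency helpers and the use of sse_starlette's EventSourceResponse (which accepts a data_sender_callable keyword) are assumptions on my part (the commit imports run_in_threadpool from anyio and passes data_sender_callable to StreamingResponse), and fake_generate is a hypothetical stand-in for llm.generate.

    import json
    import time
    from functools import partial
    from typing import List

    from anyio import create_memory_object_stream
    from fastapi import FastAPI
    from pydantic import BaseModel
    from sse_starlette.sse import EventSourceResponse
    from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

    app = FastAPI()


    class Message(BaseModel):
        role: str
        content: str


    class ChatRequest(BaseModel):
        messages: List[Message]


    def fake_generate(prompt: str):
        # Blocking stand-in for llm.generate(tokens): yields one "token" at a time.
        for word in prompt.split():
            time.sleep(0.1)
            yield word


    @app.post("/v2/chat/completions")
    async def chat_v2(body: ChatRequest):
        prompt = " ".join(message.content for message in body.messages)
        send_chan, recv_chan = create_memory_object_stream(10)

        async def event_publisher(inner_send_chan):
            async with inner_send_chan:
                # Build the generator off the event loop, then iterate it in a
                # worker thread so a slow model never stalls other requests.
                iterator = await run_in_threadpool(fake_generate, prompt)
                async for token in iterate_in_threadpool(iterator):
                    payload = {"choices": [{"message": {"role": "system", "content": token}}]}
                    # sse_starlette renders this dict as a "data: ..." frame.
                    await inner_send_chan.send({"data": json.dumps(payload)})
                await inner_send_chan.send({"event": "done", "data": "{}"})

        # EventSourceResponse drains recv_chan while event_publisher fills send_chan.
        return EventSourceResponse(
            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
        )

On the client side, assuming the service is reachable at http://localhost:8000 and emits the payload shape shown in the diff (a choices list with a message and finish_reason per chunk), the stream could be consumed roughly as follows; httpx is only one possible choice of SSE-capable HTTP client.

    import json
    import httpx

    payload = {"messages": [{"role": "user", "content": "Write a quicksort in Python"}]}

    with httpx.stream("POST", "http://localhost:8000/v2/chat/completions",
                      json=payload, timeout=None) as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue            # skip blank separators and the "event: done" marker
            data = line[len("data: "):]
            if data == "{}":
                break               # empty data frame signals end of stream
            chunk = json.loads(data)
            print(chunk["choices"][0]["message"]["content"], end="", flush=True)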