matthoffner committed
Commit bf31376
1 Parent: 25e9e92

Update main.py
Files changed (1): main.py (+4, -4)
main.py CHANGED
```diff
@@ -9,7 +9,8 @@ from fastapi import HTTPException, Depends, Request
 from fastapi.responses import HTMLResponse, StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 from sse_starlette.sse import EventSourceResponse
-from anyio import create_memory_object_stream, run_in_threadpool
+from anyio import create_memory_object_stream
+from anyio.to_thread import run_sync
 from transformers import AutoModelForCausalLM
 from pydantic import BaseModel
 
@@ -101,7 +102,7 @@ async def chatV2(request: Request, body: ChatCompletionRequest):
     async def event_publisher(inner_send_chan):
         async with inner_send_chan:
             try:
-                iterator: Generator = await run_in_threadpool(llm.generate, tokens)
+                iterator: Generator = await run_sync(llm.generate, tokens)
                 for chat_chunk in iterator:
                     response = {
                         'choices': [
@@ -118,10 +119,9 @@ async def chatV2(request: Request, body: ChatCompletionRequest):
             await inner_send_chan.send("event: done\ndata: {}\n\n")
         except Exception as e:
             print(f"Exception in event publisher: {str(e)}")
-
+
     return StreamingResponse(recv_chan, media_type="text/event-stream", data_sender_callable=partial(event_publisher, send_chan))
 
-
 @app.post("/v0/chat/completions")
 async def chat(request: ChatCompletionRequestV0, response_mode=None):
     tokens = llm.tokenize(request.prompt)
```
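For context on the one-line fix in this commit: `anyio` does not export `run_in_threadpool` (that helper lives in `starlette.concurrency`), so the removed import would fail at startup. anyio's equivalent for running blocking code in a worker thread is `anyio.to_thread.run_sync`, which the commit switches to. Below is a minimal, self-contained sketch of that call pattern, not code from this repo; `blocking_generate` is a hypothetical stand-in for `llm.generate`:

```python
import anyio
from anyio.to_thread import run_sync

def blocking_generate(prompt: str):
    """Hypothetical stand-in for llm.generate: a synchronous generator."""
    for token in prompt.split():
        yield token

async def main() -> None:
    # run_sync executes the callable in a worker thread and returns its
    # result; here that result is the generator object itself.
    iterator = await run_sync(blocking_generate, "streaming a few tokens")
    for token in iterator:
        print(token)

anyio.run(main)
```

One caveat worth noting: if `llm.generate` is itself a generator function, calling it only creates the generator, so `run_sync` offloads almost no work; each step of the `for chat_chunk in iterator` loop would still execute on the event-loop thread.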
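The endpoint also pairs `create_memory_object_stream` with a `data_sender_callable` so a producer coroutine can push SSE chunks while the response drains them. Here is a minimal sketch of that producer/consumer pattern in isolation, assuming nothing from the server code (`producer` and `main` are hypothetical names):

```python
import anyio
from anyio import create_memory_object_stream

async def main() -> None:
    # A bounded stream: send() blocks once 10 items are buffered.
    send_chan, recv_chan = create_memory_object_stream(10)

    async def producer() -> None:
        # Closing the send side (via the context manager) ends
        # iteration on the receive side.
        async with send_chan:
            for i in range(3):
                await send_chan.send(f"data: chunk {i}\n\n")

    async with anyio.create_task_group() as tg:
        tg.start_soon(producer)
        async with recv_chan:
            async for message in recv_chan:
                print(message, end="")

anyio.run(main)
```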