matthoffner committed on
Commit 0d49ac1
1 Parent(s): 66c9b7e

Match OpenAI completions API

Files changed (1)
  1. main.py +37 -1
main.py CHANGED
@@ -2,11 +2,14 @@ import fastapi
 import json
 import markdown
 import uvicorn
-from fastapi.responses import HTMLResponse
+from fastapi import HTTPException
+from fastapi.responses import HTMLResponse
 from fastapi.middleware.cors import CORSMiddleware
 from sse_starlette.sse import EventSourceResponse
 from ctransformers import AutoModelForCausalLM
 from pydantic import BaseModel
+from typing import List, Dict, Any
+
 
 llm = AutoModelForCausalLM.from_pretrained("TheBloke/WizardCoder-15B-1.0-GGML",
                                            model_file="WizardCoder-15B-1.0.ggmlv3.q4_0.bin",
@@ -43,11 +46,44 @@ async def index():
 class ChatCompletionRequest(BaseModel):
     prompt: str
 
+class Message(BaseModel):
+    role: str
+    content: str
+
+class ChatCompletionRequestV2(BaseModel):
+    messages: List[Message]
+    max_tokens: int = 100
+
 @app.post("/v1/completions")
 async def completion(request: ChatCompletionRequest, response_mode=None):
     response = llm(request.prompt)
     return response
 
+@app.post("/v2/chat/completions")
+async def chat(request: ChatCompletionRequestV2):
+    tokens = llm.tokenize([message.content for message in request.messages])
+
+    try:
+        chat_chunks = llm.generate(tokens, max_tokens=request.max_tokens)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+    def format_response(chat_chunks) -> Dict[str, Any]:
+        response = {
+            'choices': []
+        }
+        for chat_chunk in chat_chunks:
+            response['choices'].append({
+                'message': {
+                    'role': 'system',
+                    'content': llm.detokenize(chat_chunk)
+                },
+                'finish_reason': 'stop' if llm.detokenize(chat_chunk) == "[DONE]" else 'unknown'
+            })
+        return response
+
+    return format_response(chat_chunks)
+
 @app.post("/v1/chat/completions")
 async def chat(request: ChatCompletionRequest, response_mode=None):
     tokens = llm.tokenize(request.prompt)
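
For context, the new /v2/chat/completions route accepts an OpenAI-style body (a list of role/content messages plus max_tokens) and returns a dict with a "choices" list built by format_response. A minimal client sketch in Python follows; the base URL http://localhost:7860 is an assumption (adjust to wherever main.py is served), not part of this commit.

import requests

# Payload shape follows the ChatCompletionRequestV2 model added in this commit.
payload = {
    "messages": [
        {"role": "user", "content": "Write a Python function that reverses a string."}
    ],
    "max_tokens": 100,
}

# Assumed base URL/port; change to match your deployment.
resp = requests.post("http://localhost:7860/v2/chat/completions", json=payload)
resp.raise_for_status()

# Each choice carries a 'message' dict (role/content) and a 'finish_reason'.
for choice in resp.json()["choices"]:
    print(choice["message"]["content"])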