import asyncio
import json
import time
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM

app = FastAPI(title="OpenAI-compatible API")

# Load the tokenizer and model once at import time so every request reuses them.
# The first run downloads the phi-2 weights from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")

class Message(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    # Mirrors the core fields of OpenAI's /chat/completions request body.
    model: Optional[str] = "mock-gpt-model"
    messages: List[Message]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.1
    stream: Optional[bool] = False

async def _resp_async_generator(text_resp: str, model: str):
    # Re-emit the generated text word by word as Server-Sent Events, mimicking
    # the shape of OpenAI's streaming chat.completion.chunk objects.
    tokens = text_resp.split(" ")

    for i, token in enumerate(tokens):
        chunk = {
            "id": i,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{"delta": {"content": token + " "}}],
        }
        yield f"data: {json.dumps(chunk)}\n\n"
        await asyncio.sleep(0.05)
    yield "data: [DONE]\n\n"

@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided.")

    # Flatten the chat history into a plain-text prompt. System messages are
    # ignored by this simple format.
    prompt = ""
    for msg in request.messages:
        if msg.role == "user":
            prompt += f"User: {msg.content}\n"
        elif msg.role == "assistant":
            prompt += f"Assistant: {msg.content}\n"
    prompt += "Assistant:"

    # Note: model.generate() is blocking and holds up the event loop while the
    # model produces tokens.
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=request.max_tokens,
        temperature=request.temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Strip the echoed prompt so only the newly generated reply remains.
    assistant_reply = generated_text[len(prompt):].strip()

    if request.stream:
        return StreamingResponse(
            _resp_async_generator(assistant_reply, request.model),
            media_type="text/event-stream",
        )

    return {
        "id": "1337",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [{"message": Message(role="assistant", content=assistant_reply)}],
    }
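
# Example client call, as a sketch: it assumes the official `openai` Python
# package (v1+) is installed and the server is running on localhost:8000.
# The api_key is a dummy value, since this server performs no authentication.
#
#     from openai import OpenAI
#
#     client = OpenAI(base_url="http://localhost:8000", api_key="not-needed")
#     resp = client.chat.completions.create(
#         model="mock-gpt-model",
#         messages=[{"role": "user", "content": "Say hello in one sentence."}],
#     )
#     print(resp.choices[0].message.content)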
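
# A minimal way to serve the app locally, assuming uvicorn is installed
# (`pip install uvicorn`). Running `uvicorn main:app --reload` from the shell
# works as well, where "main" stands for whatever this file is named.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)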