import json
import time
import uuid
from threading import Thread
from typing import Any, Dict

import torch
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
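

# Custom handler in the Hugging Face Inference Endpoints style: the endpoint
# constructs EndpointHandler(path) once at startup and then calls it with each
# parsed request body. Responses are shaped to mimic the OpenAI Chat Completions
# API, both as a single JSON object and as "chat.completion.chunk" server-sent
# events, so OpenAI-compatible clients can point at the endpoint directly.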
class EndpointHandler:
    def __init__(self, path: str = "openai/gpt-oss-20b"):
        # Load the tokenizer and model weights once when the endpoint starts.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.model.eval()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def openai_id(self, prefix: str) -> str:
        # OpenAI-style identifier, e.g. "chatcmpl-<24 hex chars>".
        return f"{prefix}-{uuid.uuid4().hex[:24]}"

    def format_non_stream(self, model: str, text: str, prompt_length: int, completion_length: int, total_tokens: int) -> Dict[str, Any]:
        # Full (non-streaming) response body in the OpenAI chat.completion format.
        return {
            "id": self.openai_id("chatcmpl"),
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_length,
                "completion_tokens": completion_length,
                "total_tokens": total_tokens
            }
        }

    def format_stream(self, model: str, token: str, usage) -> bytes:
        # One server-sent-event line in the OpenAI chat.completion.chunk format.
        # Content deltas pass usage=None; the final chunk carries the token counts.
        payload = {
            "id": self.openai_id("chatcmpl"),
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {
                    "content": token,
                    "function_call": None,
                    "refusal": None,
                    "role": None,
                    "tool_calls": None
                },
                "finish_reason": None,
                "logprobs": None
            }],
            "usage": usage
        }

        return f"data: {json.dumps(payload)}\n\n".encode("utf-8")

    def generate(self, messages, model: str) -> Dict[str, Any]:
        # Non-streaming path. `messages` is expected in the OpenAI chat format
        # (a list of {"role", "content"} dicts), so it is rendered through the
        # model's chat template rather than passed to the tokenizer directly.
        model_inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)
        full_output = self.model.generate(**model_inputs, max_new_tokens=2048)

        # Keep only the newly generated tokens (drop the echoed prompt).
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, full_output)
        ]
        # Decode the completion without special tokens, as the streaming path does.
        text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        input_length = model_inputs.input_ids.shape[1]
        output_length = full_output.shape[1]
        completion_tokens = output_length - input_length

        return self.format_non_stream(model, text, input_length, completion_tokens, output_length)

    def stream(self, messages, model):
        # Streaming path: generation runs on a background thread while decoded
        # text chunks are read from the TextIteratorStreamer and re-emitted as
        # SSE chunks.
        model_inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)
        input_len = model_inputs.input_ids.shape[1]
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )

        generation_kwargs = dict(
            **model_inputs,
            streamer=streamer,
            max_new_tokens=2048
        )

        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        completion_tokens = 0
        for token in streamer:
            # Re-encoding the decoded chunk approximates its token count; the
            # streamer yields text, not the underlying token ids.
            completion_tokens += len(self.tokenizer.encode(token, add_special_tokens=False))
            yield self.format_stream(model, token, None)

        # Final chunk carries the usage totals, then the stream is closed with
        # the OpenAI-style [DONE] sentinel.
        yield self.format_stream(model, "", {
            "prompt_tokens": input_len,
            "completion_tokens": completion_tokens,
            "total_tokens": input_len + completion_tokens
        })
        yield b"data: [DONE]\n\n"

    def __call__(self, data: Dict[str, Any]):
        # Entry point: dispatch on the "stream" flag of the request payload.
        messages = data.get("messages")
        model = data.get("model")
        stream = data.get("stream", False)

        if not stream:
            return self.generate(messages, model)

        return StreamingResponse(
            self.stream(messages, model),
            media_type="text/event-stream"
        )
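

# A minimal local smoke test, not part of the endpoint contract: it assumes
# enough memory to load openai/gpt-oss-20b and simply exercises the
# non-streaming path with an OpenAI-style payload. In production the handler
# is invoked through __call__ with the parsed request body instead.
if __name__ == "__main__":
    handler = EndpointHandler()
    response = handler({
        "model": "openai/gpt-oss-20b",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "stream": False,
    })
    print(json.dumps(response, indent=2))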