# gpt-oss-20b-new / handler.py
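"""Custom handler for Hugging Face Inference Endpoints.

Serves openai/gpt-oss-20b behind an OpenAI-compatible chat-completions
interface, with both blocking responses and server-sent-event streaming.
"""
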
import json
import time
import uuid
from threading import Thread
from typing import Any, Dict

import torch
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


class EndpointHandler:
    def __init__(self, path: str = "openai/gpt-oss-20b"):
        # Load the tokenizer and model; torch_dtype="auto" keeps the
        # checkpoint's native precision instead of upcasting to float32
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path, torch_dtype="auto")
        self.model.eval()
        # Run on GPU when available, otherwise fall back to CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def openai_id(self, prefix: str) -> str:
        # Random identifier in the OpenAI "chatcmpl-..." style
        return f"{prefix}-{uuid.uuid4().hex[:24]}"

    def format_non_stream(self, model: str, text: str, prompt_length: int,
                          completion_length: int, total_tokens: int) -> Dict[str, Any]:
        # Build an OpenAI-compatible chat.completion payload
        return {
            "id": self.openai_id("chatcmpl"),
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_length,
                "completion_tokens": completion_length,
                "total_tokens": total_tokens
            }
        }

    def format_stream(self, model: str, token: str, usage, finish_reason=None) -> bytes:
        # Build one OpenAI-compatible chat.completion.chunk, encoded as an SSE line;
        # finish_reason stays None for content chunks and is set on the final chunk
        payload = {
            "id": self.openai_id("chatcmpl"),
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {
                    "content": token,
                    "function_call": None,
                    "refusal": None,
                    "role": None,
                    "tool_calls": None
                },
                "finish_reason": finish_reason,
                "logprobs": None
            }],
            "usage": usage
        }
        return f"data: {json.dumps(payload)}\n\n".encode("utf-8")

    def generate(self, messages, model: str):
        # Render the chat messages with the model's chat template, then tokenize.
        # (Assumes `messages` is an OpenAI-style list of {"role", "content"} dicts;
        # a plain tokenizer call cannot consume that structure directly.)
        model_inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)
        full_output = self.model.generate(**model_inputs, max_new_tokens=2048)
        # Slice off the prompt so only the completion is decoded
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, full_output)
        ]
        text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
        input_length = model_inputs.input_ids.shape[1]  # Prompt tokens
        output_length = full_output.shape[1]            # Prompt + completion tokens
        completion_tokens = output_length - input_length
        return self.format_non_stream(model, text, input_length, completion_tokens, output_length)

    def stream(self, messages, model):
        # Same chat-template tokenization as generate()
        model_inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)
        input_len = model_inputs.input_ids.shape[1]
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generation_kwargs = dict(
            **model_inputs,
            streamer=streamer,
            max_new_tokens=2048
        )
        # generate() blocks, so run it on a background thread and drain the streamer here
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        completion_tokens = 0
        for token in streamer:
            # Re-encode each decoded chunk to count its tokens
            token_ids = self.tokenizer.encode(token, add_special_tokens=False)
            completion_tokens += len(token_ids)
            yield self.format_stream(model, token, None)
        thread.join()
        # Final chunk carries the stop reason and token counts
        yield self.format_stream(model, "", {
            "prompt_tokens": input_len,
            "completion_tokens": completion_tokens,
            "total_tokens": input_len + completion_tokens
        }, finish_reason="stop")
        # OpenAI-compatible streams terminate with a [DONE] sentinel
        yield b"data: [DONE]\n\n"

    def __call__(self, data: Dict[str, Any]):
        messages = data.get("messages")
        model = data.get("model")
        stream = data.get("stream", False)
        if not stream:
            # One-shot chat.completion payload
            return self.generate(messages, model)
        # Server-sent events carrying chat.completion.chunk payloads
        return StreamingResponse(
            self.stream(messages, model),
            media_type="text/event-stream"
        )
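

# Minimal local smoke test: a sketch assuming the checkpoint fits on this
# machine and that the request mirrors what the endpoint would receive.
# The payload below is illustrative, not part of the handler contract.
if __name__ == "__main__":
    handler = EndpointHandler()
    response = handler({
        "model": "openai/gpt-oss-20b",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "stream": False
    })
    print(json.dumps(response, indent=2))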