# Inference / main.py
# Author: NOT-OMEGA — "Update main.py" (commit 66be360, verified)
# main.py - SLM Inference Server
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import subprocess
import tiktoken
import os
import time
# FastAPI application instance serving the SLM inference API.
app = FastAPI()
# NOTE(review): wildcard CORS is fine for a demo/Space; restrict
# allow_origins before exposing this to anything sensitive.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
class GenerateRequest(BaseModel):
    """Request body for POST /generate.

    The sampling parameters are forwarded verbatim to the C++
    inference binary as command-line arguments.
    """
    prompt: str               # raw text to tokenize and feed to the model
    max_tokens: int = 100     # maximum number of tokens to generate
    temperature: float = 0.8  # sampling temperature passed to the binary
    top_k: int = 40           # top-k sampling cutoff passed to the binary
# Load the GPT-2 BPE tokenizer once at import time. On failure the server
# still starts, but /generate returns 500 until `enc` is available.
try:
    enc = tiktoken.get_encoding("gpt2")
    print("Tokenizer loaded successfully.")
except Exception as e:
    # Fix: the original message claimed "tiktoken not found", but a missing
    # package would already have crashed at the top-level import. What can
    # fail here is loading the encoding itself (e.g. no cached BPE files,
    # no network access to fetch them).
    print(f"Warning: failed to load GPT-2 tokenizer. Error: {e}")
    enc = None
@app.get("/")
async def root():
    """Serve the bundled index.html that lives next to this file."""
    here = os.path.dirname(os.path.abspath(__file__))
    return FileResponse(os.path.join(here, "index.html"))
@app.get("/health")
async def health_check():
    """Report whether the inference binary and model weights are present.

    Always returns status "ok"; the two *_found flags tell the caller
    whether /generate can actually run.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    binary = os.path.join(here, "inference")
    weights = os.path.join(here, "model.bin")
    return {
        "status": "ok",
        "inference_exe_found": os.path.exists(binary),
        "model_bin_found": os.path.exists(weights),
        "working_directory": here,
    }
# Upper bound on a single inference run. Fix: the original subprocess.run had
# no timeout, so a hung C++ process would tie up the request forever.
_INFERENCE_TIMEOUT_S = 300


def _resolve_inference_paths():
    """Return (exe_path, model_path, current_dir) for the C++ binary and weights.

    Raises HTTPException(500) if either file is missing.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    exe_path = os.path.join(current_dir, "inference")
    model_path = os.path.join(current_dir, "model.bin")
    if not os.path.exists(exe_path):
        raise HTTPException(status_code=500, detail=f"inference binary not found: {exe_path}")
    if not os.path.exists(model_path):
        raise HTTPException(status_code=500, detail=f"model.bin not found: {model_path}")
    return exe_path, model_path, current_dir


def _parse_token_ids(output_str):
    """Parse whitespace-separated token ids from the binary's stdout.

    Non-integer tokens (e.g. stray debug prints) are silently skipped,
    matching the original best-effort behavior.
    """
    ids = []
    for tok in output_str.split():
        try:
            ids.append(int(tok))
        except ValueError:
            pass
    return ids


@app.post("/generate")
async def generate_text(req: GenerateRequest):
    """Tokenize the prompt, run the C++ inference binary, decode its output.

    Returns a dict with the generated text plus token counts and latency
    metrics. Raises HTTPException(500) when the tokenizer, binary, or model
    is unavailable, or when the subprocess fails, times out, or produces
    undecodable output.
    """
    if enc is None:
        raise HTTPException(status_code=500, detail="Tokenizer not loaded.")

    input_tokens = enc.encode(req.prompt)
    token_str = ",".join(map(str, input_tokens))
    exe_path, _model_path, current_dir = _resolve_inference_paths()

    try:
        start_time = time.perf_counter()
        # Argument list (shell=False): the prompt content never touches a shell.
        process = subprocess.run(
            [exe_path, token_str, str(req.max_tokens), str(req.temperature), str(req.top_k)],
            capture_output=True,
            text=True,
            cwd=current_dir,
            timeout=_INFERENCE_TIMEOUT_S,
        )
        elapsed_ms = (time.perf_counter() - start_time) * 1000
    except subprocess.TimeoutExpired:
        raise HTTPException(
            status_code=500,
            detail=f"Inference timed out after {_INFERENCE_TIMEOUT_S}s",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Execution failed: {str(e)}")

    # Best effort: a nonzero exit code is tolerated as long as the binary
    # produced some stdout; only fail when there is nothing to decode.
    if process.returncode != 0 and not process.stdout.strip():
        stdout_msg = process.stdout.strip() if process.stdout else ""
        stderr_msg = process.stderr.strip() if process.stderr else ""
        raise HTTPException(
            status_code=500,
            detail=f"C++ Error | stdout: '{stdout_msg}' | stderr: '{stderr_msg}'",
        )

    try:
        generated_ids = _parse_token_ids(process.stdout.strip())
        generated_text = enc.decode(generated_ids) if generated_ids else ""
        tokens_out = len(generated_ids)
        tokens_per_sec = round(tokens_out / (elapsed_ms / 1000), 2) if elapsed_ms > 0 else 0
        return {
            "prompt": req.prompt,
            "generated_text": generated_text,
            "tokens_in": len(input_tokens),
            "tokens_out": tokens_out,
            "latency_ms": round(elapsed_ms, 2),
            "tokens_per_sec": tokens_per_sec,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Decoding error: {str(e)}")