"""
|
|
OpenLLM Inference Server
|
|
|
|
This script implements the REST API server for OpenLLM model inference
|
|
as specified in Step 6 of the training pipeline.
|
|
|
|
Features:
|
|
- FastAPI-based REST API
|
|
- Support for multiple model formats (PyTorch, Hugging Face, ONNX)
|
|
- Text generation with configurable parameters
|
|
- Health checks and metrics
|
|
- Production-ready deployment
|
|
|
|
Usage:
|
|
python core/src/inference_server.py \
|
|
--model_path exports/huggingface/ \
|
|
--host 0.0.0.0 \
|
|
--port 8000 \
|
|
--max_length 512
|
|
|
|
API Endpoints:
|
|
POST /generate - Generate text from prompt
|
|
GET /health - Health check
|
|
GET /info - Model information
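
Example request (illustrative; assumes the server is running locally on
port 8000 and that the `requests` package is available):

    import requests

    response = requests.post(
        "http://localhost:8000/generate",
        json={"prompt": "Once upon a time", "max_length": 64, "temperature": 0.7},
    )
    print(response.json()["generated_text"])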

Author: Louis Chua Bean Chong
License: GPLv3
"""

import argparse
import json
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import sentencepiece as smp
import torch
import uvicorn

try:
    from fastapi import BackgroundTasks, FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    from pydantic import BaseModel, Field
except ImportError:
    raise ImportError("Install FastAPI: pip install fastapi uvicorn[standard]")

# Make the sibling `model` module importable when this script is run directly.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from model import create_model


class TextGenerationConfig(BaseModel):
    """Configuration for text generation parameters."""

    max_new_tokens: int = Field(
        256, description="Maximum number of tokens to generate", ge=1, le=2048
    )
    temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
    top_k: Optional[int] = Field(40, description="Top-k sampling parameter", ge=1, le=1000)
    top_p: Optional[float] = Field(0.9, description="Nucleus sampling parameter", ge=0.1, le=1.0)
    num_return_sequences: int = Field(1, description="Number of sequences to generate", ge=1, le=5)
    stop_sequences: Optional[List[str]] = Field(
        None, description="Stop generation at these sequences"
    )


class GenerationRequest(BaseModel):
    """Request model for text generation."""

    prompt: str = Field(..., description="Input text prompt")
    max_length: int = Field(256, description="Maximum generation length", ge=1, le=2048)
    temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)
    top_k: Optional[int] = Field(40, description="Top-k sampling parameter", ge=1, le=1000)
    top_p: Optional[float] = Field(0.9, description="Nucleus sampling parameter", ge=0.1, le=1.0)
    num_return_sequences: int = Field(1, description="Number of sequences to generate", ge=1, le=5)
    stop_sequences: Optional[List[str]] = Field(
        None, description="Stop generation at these sequences"
    )


class GenerationResponse(BaseModel):
    """Response model for text generation."""

    generated_text: List[str] = Field(..., description="Generated text sequences")
    prompt: str = Field(..., description="Original prompt")
    generation_time: float = Field(..., description="Generation time in seconds")
    parameters: Dict[str, Any] = Field(..., description="Generation parameters used")


class ModelInfo(BaseModel):
    """Model information response."""

    model_name: str
    model_size: str
    parameters: int
    vocab_size: int
    max_length: int
    format: str
    loaded_at: str


class HealthResponse(BaseModel):
    """Health check response."""

    status: str
    model_loaded: bool
    uptime_seconds: float
    total_requests: int


class OpenLLMInference:
    """
    OpenLLM model inference engine.

    Supports multiple model formats and provides text generation capabilities.
    """

    def __init__(self, model_path: str, model_format: str = "auto"):
        """
        Initialize inference engine.

        Args:
            model_path: Path to exported model directory
            model_format: Model format (pytorch, huggingface, onnx, auto)
        """
        self.model_path = Path(model_path)
        self.model_format = model_format
        self.model = None
        self.tokenizer = None
        self.config = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self._load_model()

        self.loaded_at = time.time()
        self.total_requests = 0

        print("🚀 OpenLLM Inference Engine initialized")
        print(f"   Model: {self.config.get('model_name', 'Unknown')}")
        print(f"   Format: {self.detected_format}")
        print(f"   Device: {self.device}")

    def _detect_format(self) -> str:
        """Auto-detect model format from directory contents."""
        if (self.model_path / "model.pt").exists():
            return "pytorch"
        elif (self.model_path / "pytorch_model.bin").exists():
            return "huggingface"
        elif (self.model_path / "model.onnx").exists():
            return "onnx"
        else:
            raise ValueError(f"Could not detect model format in {self.model_path}")
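
    # Expected export layouts, inferred from the loaders below:
    #   pytorch:     model.pt + config.json + tokenizer.model
    #   huggingface: pytorch_model.bin + config.json + tokenizer.model
    #   onnx:        model.onnx + metadata.json + tokenizer.model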

    def _load_model(self):
        """Load model based on detected format."""
        if self.model_format == "auto":
            self.detected_format = self._detect_format()
        else:
            self.detected_format = self.model_format

        print(f"📂 Loading {self.detected_format} model from {self.model_path}")

        if self.detected_format == "pytorch":
            self._load_pytorch_model()
        elif self.detected_format == "huggingface":
            self._load_huggingface_model()
        elif self.detected_format == "onnx":
            self._load_onnx_model()
        else:
            raise ValueError(f"Unsupported format: {self.detected_format}")

        self._load_tokenizer()

        print("✅ Model loaded successfully")

    def _load_pytorch_model(self):
        """Load PyTorch format model."""
        with open(self.model_path / "config.json", "r") as f:
            config_data = json.load(f)
        self.config = config_data["model_config"]

        checkpoint = torch.load(self.model_path / "model.pt", map_location=self.device)

        # Pick a model-size preset from the layer count so create_model()
        # builds an architecture matching the checkpoint.
        n_layer = self.config.get("n_layer", 12)
        if n_layer <= 6:
            model_size = "small"
        elif n_layer <= 12:
            model_size = "medium"
        else:
            model_size = "large"

        self.model = create_model(model_size)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.to(self.device)
        self.model.eval()
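
    # For reference (based on the loader above): the PyTorch export is expected
    # to ship config.json with a top-level "model_config" object and model.pt
    # containing a dict with a "model_state_dict" entry.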

    def _load_huggingface_model(self):
        """Load Hugging Face format model."""
        with open(self.model_path / "config.json", "r") as f:
            self.config = json.load(f)

        state_dict = torch.load(self.model_path / "pytorch_model.bin", map_location=self.device)

        n_layer = self.config.get("n_layer", 12)
        if n_layer <= 6:
            model_size = "small"
        elif n_layer <= 12:
            model_size = "medium"
        else:
            model_size = "large"

        self.model = create_model(model_size)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def _load_onnx_model(self):
        """Load ONNX format model."""
        try:
            import onnxruntime as ort
        except ImportError:
            raise ImportError("ONNX inference requires: pip install onnxruntime")

        model_file = (self.model_path / "model.onnx").resolve()
        if not model_file.exists():
            raise FileNotFoundError(f"ONNX model not found: {model_file}")

        # Sanity check: the resolved model file must live inside the model directory.
        if not str(model_file).startswith(str(self.model_path.resolve())):
            raise ValueError(f"Invalid model path: {model_file}")

        metadata_file = self.model_path / "metadata.json"
        if not metadata_file.exists():
            raise FileNotFoundError(f"ONNX metadata not found: {metadata_file}")

        with open(metadata_file, "r") as f:
            metadata = json.load(f)
        self.config = metadata["model_config"]

        # Prefer the CUDA execution provider when a GPU is available.
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if torch.cuda.is_available()
            else ["CPUExecutionProvider"]
        )

        # Conservative session options: basic graph optimization, no memory arena.
        session_options = ort.SessionOptions()
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
        session_options.enable_mem_pattern = False
        session_options.enable_cpu_mem_arena = False

        self.onnx_session = ort.InferenceSession(
            str(model_file), providers=providers, sess_options=session_options
        )

        self.device = "onnx"

    def _load_tokenizer(self):
        """Load tokenizer."""
        tokenizer_path = self.model_path / "tokenizer.model"
        if not tokenizer_path.exists():
            raise FileNotFoundError(f"Tokenizer not found: {tokenizer_path}")

        self.tokenizer = smp.SentencePieceProcessor()
        self.tokenizer.load(str(tokenizer_path))

    def generate(
        self,
        prompt: str,
        max_length: int = 256,
        temperature: float = 0.7,
        top_k: Optional[int] = 40,
        top_p: Optional[float] = 0.9,
        num_return_sequences: int = 1,
        stop_sequences: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Generate text from prompt.

        Args:
            prompt: Input text prompt
            max_length: Maximum generation length
            temperature: Sampling temperature
            top_k: Top-k sampling parameter
            top_p: Nucleus sampling parameter
            num_return_sequences: Number of sequences to generate
            stop_sequences: Stop generation at these sequences

        Returns:
            List of generated text sequences
        """
        self.total_requests += 1

        if self.detected_format == "onnx":
            # Note: the ONNX path currently applies top-k only; top_p is ignored.
            return self._generate_onnx(
                prompt, max_length, temperature, top_k, num_return_sequences, stop_sequences
            )
        else:
            return self._generate_pytorch(
                prompt, max_length, temperature, top_k, top_p, num_return_sequences, stop_sequences
            )

    def _generate_pytorch(
        self,
        prompt: str,
        max_length: int,
        temperature: float,
        top_k: Optional[int],
        top_p: Optional[float],
        num_return_sequences: int,
        stop_sequences: Optional[List[str]],
    ) -> List[str]:
        """Generate using PyTorch model."""
        input_ids = self.tokenizer.encode(prompt)
        input_tensor = torch.tensor(
            [input_ids] * num_return_sequences, dtype=torch.long, device=self.device
        )

        with torch.no_grad():
            outputs = []
            for _ in range(num_return_sequences):
                if hasattr(self.model, "generate"):
                    # Use the model's own generate() when available.
                    output = self.model.generate(
                        input_tensor[:1],
                        max_new_tokens=max_length,
                        temperature=temperature,
                        top_k=top_k,
                    )
                    generated_ids = output[0].tolist()
                    generated_text = self.tokenizer.decode(generated_ids[len(input_ids) :])
                else:
                    # Fall back to a simple token-by-token sampling loop.
                    generated_text = self._simple_generate(
                        input_tensor[:1], max_length, temperature
                    )

                # Truncate at the first stop sequence, if any.
                if stop_sequences:
                    for stop_seq in stop_sequences:
                        if stop_seq in generated_text:
                            generated_text = generated_text.split(stop_seq)[0]
                            break

                outputs.append(generated_text)

        return outputs

    def _generate_onnx(
        self,
        prompt: str,
        max_length: int,
        temperature: float,
        top_k: Optional[int],
        num_return_sequences: int,
        stop_sequences: Optional[List[str]],
    ) -> List[str]:
        """Generate using ONNX model."""
        outputs = []

        for _ in range(num_return_sequences):
            tokens = self.tokenizer.encode(prompt)
            generated = tokens.copy()

            for _ in range(max_length):
                # Hard cap on total sequence length for the ONNX path.
                if len(generated) >= 512:
                    break

                # Feed only the last 64 tokens as context (fixed window).
                current_input = np.array([generated[-64:]], dtype=np.int64)

                logits = self.onnx_session.run(None, {"input_ids": current_input})[0]
                next_token_logits = logits[0, -1, :]

                if temperature > 0:
                    next_token_logits = next_token_logits / temperature
                    # Numerically stable softmax.
                    next_token_logits = next_token_logits - np.max(next_token_logits)
                    probs = np.exp(next_token_logits) / np.sum(np.exp(next_token_logits))

                    if top_k:
                        # Keep only the top-k most likely tokens and renormalize.
                        top_indices = np.argpartition(probs, -top_k)[-top_k:]
                        probs_filtered = np.zeros_like(probs)
                        probs_filtered[top_indices] = probs[top_indices]
                        probs = probs_filtered / np.sum(probs_filtered)

                    next_token = np.random.choice(len(probs), p=probs)
                else:
                    # Greedy decoding when temperature is 0.
                    next_token = np.argmax(next_token_logits)

                generated.append(int(next_token))

            generated_text = self.tokenizer.decode(generated[len(tokens) :])

            # Truncate at the first stop sequence, if any.
            if stop_sequences:
                for stop_seq in stop_sequences:
                    if stop_seq in generated_text:
                        generated_text = generated_text.split(stop_seq)[0]
                        break

            outputs.append(generated_text)

        return outputs

    def _simple_generate(
        self, input_tensor: torch.Tensor, max_length: int, temperature: float
    ) -> str:
        """Simple fallback generation method."""
        generated = input_tensor[0].tolist()

        # Guard against division by zero when temperature is 0.
        temperature = max(temperature, 1e-6)

        for _ in range(max_length):
            if len(generated) >= self.config.get("block_size", 1024):
                break

            current_input = torch.tensor([generated], dtype=torch.long, device=self.device)
            with torch.no_grad():
                logits, _ = self.model(current_input)

            next_token_logits = logits[0, -1, :] / temperature
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

            generated.append(next_token)

        original_length = input_tensor.size(1)
        generated_tokens = generated[original_length:]
        return self.tokenizer.decode(generated_tokens)

    def get_info(self) -> Dict[str, Any]:
        """Get model information."""
        return {
            "model_name": self.config.get("model_name", "OpenLLM"),
            "model_size": self.config.get("model_size", "unknown"),
            # Rough proxy, not an exact parameter count.
            "parameters": self.config.get("n_embd", 0) * self.config.get("n_layer", 0),
            "vocab_size": self.config.get("vocab_size", self.tokenizer.vocab_size()),
            "max_length": self.config.get("block_size", 1024),
            "format": self.detected_format,
            "loaded_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.loaded_at)),
        }

    def get_health(self) -> Dict[str, Any]:
        """Get health status."""
        return {
            "status": "healthy",
            "model_loaded": self.model is not None,
            "uptime_seconds": time.time() - self.loaded_at,
            "total_requests": self.total_requests,
        }
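

# Illustrative sketch (not wired into the server): top_p is accepted by the API,
# but the sampling loops above only apply top-k filtering. Nucleus (top-p)
# filtering over a probability vector could look roughly like this.
def _nucleus_filter_sketch(probs: np.ndarray, top_p: float) -> np.ndarray:
    """Keep the smallest set of tokens whose cumulative probability reaches top_p."""
    order = np.argsort(probs)[::-1]  # token indices, most likely first
    cumulative = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1  # always keep at least one token
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()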


# Global inference engine instance; set by main() or load_model_for_testing().
inference_engine: Optional[OpenLLMInference] = None


app = FastAPI(
    title="OpenLLM Inference API",
    description="REST API for OpenLLM text generation",
    version="0.1.0",
    docs_url="/docs",
    redoc_url="/redoc",
)
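
# Note: CORS is wide open here (any origin, method, and header are allowed);
# tighten these settings for production deployments.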

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Initialize inference engine on startup."""
    print("🚀 Starting OpenLLM Inference Server...")

    global inference_engine
    if inference_engine is None:
        print("⚠️ No model loaded - server will return 503 for generation requests")
        print("   Use main() function to load a real model")
        print("   For testing, use load_model_for_testing() function")
@app.post("/generate", response_model=GenerationResponse)
|
|
async def generate_text(request: GenerationRequest, background_tasks: BackgroundTasks):
|
|
"""Generate text from prompt."""
|
|
if inference_engine is None:
|
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
|
|
|
start_time = time.time()
|
|
|
|
try:
|
|
|
|
generated_texts = inference_engine.generate(
|
|
prompt=request.prompt,
|
|
max_length=request.max_length,
|
|
temperature=request.temperature,
|
|
top_k=request.top_k,
|
|
top_p=request.top_p,
|
|
num_return_sequences=request.num_return_sequences,
|
|
stop_sequences=request.stop_sequences,
|
|
)
|
|
|
|
generation_time = time.time() - start_time
|
|
|
|
return GenerationResponse(
|
|
generated_text=generated_texts,
|
|
prompt=request.prompt,
|
|
generation_time=generation_time,
|
|
parameters={
|
|
"max_length": request.max_length,
|
|
"temperature": request.temperature,
|
|
"top_k": request.top_k,
|
|
"top_p": request.top_p,
|
|
"num_return_sequences": request.num_return_sequences,
|
|
},
|
|
)
|
|
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
|
|
|
|
|
@app.post("/generate/stream")
|
|
async def generate_text_stream(request: GenerationRequest):
|
|
"""Generate text with streaming response."""
|
|
if inference_engine is None:
|
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
|
|
|
try:
|
|
|
|
|
|
generated_texts = inference_engine.generate(
|
|
prompt=request.prompt,
|
|
max_length=request.max_length,
|
|
temperature=request.temperature,
|
|
top_k=request.top_k,
|
|
top_p=request.top_p,
|
|
num_return_sequences=request.num_return_sequences,
|
|
stop_sequences=request.stop_sequences,
|
|
)
|
|
|
|
|
|
return {
|
|
"generated_text": generated_texts,
|
|
"prompt": request.prompt,
|
|
"streaming": True,
|
|
}
|
|
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
|
|
|
|
|
@app.get("/info", response_model=ModelInfo)
|
|
async def get_model_info():
|
|
"""Get model information."""
|
|
if inference_engine is None:
|
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
|
|
|
info = inference_engine.get_info()
|
|
return ModelInfo(**info)
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
|
|
async def health_check():
|
|
"""Health check endpoint."""
|
|
if inference_engine is None:
|
|
return HealthResponse(
|
|
status="unhealthy", model_loaded=False, uptime_seconds=0.0, total_requests=0
|
|
)
|
|
|
|
health = inference_engine.get_health()
|
|
return HealthResponse(**health)
|
|
|
|
|
|
@app.get("/")
|
|
async def root():
|
|
"""Root endpoint."""
|
|
return {
|
|
"message": "OpenLLM Inference API",
|
|
"version": "0.1.0",
|
|
"docs": "/docs",
|
|
"health": "/health",
|
|
"info": "/info",
|
|
"endpoints": ["/generate", "/generate/stream", "/health", "/info"],
|
|
}
|
|
|
|
|
|
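

# The app can also be served directly with the uvicorn CLI (e.g. something like
# `uvicorn inference_server:app --port 8000`, depending on how the package is laid
# out), but note that no model is loaded in that case: inference_engine stays None
# until main() or load_model_for_testing() runs.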


def main():
    """Main server function."""
    parser = argparse.ArgumentParser(
        description="OpenLLM Inference Server",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Start server with Hugging Face model
  python core/src/inference_server.py \\
    --model_path exports/huggingface/ \\
    --host 0.0.0.0 \\
    --port 8000

  # Start server with ONNX model
  python core/src/inference_server.py \\
    --model_path exports/onnx/ \\
    --format onnx \\
    --port 8001
        """,
    )

    parser.add_argument(
        "--model_path",
        required=True,
        help="Path to exported model directory",
    )
    parser.add_argument(
        "--format",
        choices=["pytorch", "huggingface", "onnx", "auto"],
        default="auto",
        help="Model format (default: auto-detect)",
    )
    parser.add_argument(
        "--host",
        default="127.0.0.1",
        help="Host to bind to (default: 127.0.0.1)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to bind to (default: 8000)",
    )
    # Note: --max_length is parsed but not currently applied; generation length
    # comes from each request's max_length field.
    parser.add_argument(
        "--max_length",
        type=int,
        default=512,
        help="Maximum generation length (default: 512)",
    )

    args = parser.parse_args()

    global inference_engine
    inference_engine = OpenLLMInference(args.model_path, args.format)

    print(f"🚀 Starting server on {args.host}:{args.port}")
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
    )


def load_model(model_path: str, model_format: str = "auto"):
    """
    Load model for testing purposes.

    This function is used by tests to load models without starting the full server.

    Args:
        model_path: Path to exported model directory
        model_format: Model format (pytorch, huggingface, onnx, auto)

    Returns:
        OpenLLMInference: Initialized inference engine
    """
    return OpenLLMInference(model_path, model_format)


def load_model_for_testing(
    model_path: str = "exports/huggingface", model_format: str = "huggingface"
):
    """
    Load a real model for testing purposes.

    This function loads the actual trained model for testing and falls back to a
    lightweight test model if loading fails.

    Args:
        model_path: Path to the model directory (default: exports/huggingface)
        model_format: Model format (default: huggingface)

    Returns:
        OpenLLMInference: Inference engine with the loaded model
    """
    global inference_engine
    try:
        inference_engine = OpenLLMInference(model_path, model_format)
        print(f"✅ Real model loaded for testing from {model_path}")
        return inference_engine
    except Exception as e:
        print(f"❌ Failed to load real model: {e}")
        return create_test_model()
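

# Illustrative test usage (a sketch, assuming pytest-style tests and that a model
# export exists at the default path):
#
#     from fastapi.testclient import TestClient
#     import inference_server
#
#     inference_server.load_model_for_testing()
#     client = TestClient(inference_server.app)
#     response = client.post("/generate", json={"prompt": "Hello", "max_length": 16})
#     assert response.status_code == 200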


def create_test_model():
    """
    Create a lightweight test model for testing purposes.

    This creates a real model with minimal parameters for testing, without
    requiring large model files to be downloaded.

    Returns:
        OpenLLMInference: Lightweight inference engine
    """
    try:
        from model import GPTConfig, GPTModel

        # Minimal configuration so the model is fast to build and run.
        config = GPTConfig.small()
        config.n_embd = 128
        config.n_layer = 2
        config.vocab_size = 1000
        config.block_size = 64

        model = GPTModel(config)
        model.eval()

        class MinimalTokenizer:
            """Tiny stand-in tokenizer with a SentencePiece-like interface."""

            def encode(self, text):
                return [ord(c) % 1000 for c in text[:50]]

            def decode(self, tokens):
                return "".join([chr(t % 256) for t in tokens if t < 256])

            def vocab_size(self):
                return 1000

        class LightweightInferenceEngine:
            """Minimal engine exposing the same interface as OpenLLMInference."""

            def __init__(self):
                self.model = model
                self.tokenizer = MinimalTokenizer()
                self.config = {
                    "model_name": "openllm-small-test",
                    "model_size": "small",
                    "n_embd": config.n_embd,
                    "n_layer": config.n_layer,
                    "vocab_size": config.vocab_size,
                    "block_size": config.block_size,
                }
                self.detected_format = "pytorch"
                self.device = "cpu"
                self.loaded_at = time.time()
                self.total_requests = 0

            def generate(self, prompt, max_length=10, temperature=0.7, **kwargs):
                """Text generation with the lightweight model."""
                self.total_requests += 1

                input_ids = self.tokenizer.encode(prompt)
                if len(input_ids) == 0:
                    input_ids = [1]

                generated = input_ids.copy()
                for _ in range(max_length):
                    if len(generated) >= self.config["block_size"]:
                        break

                    input_tensor = torch.tensor([generated], dtype=torch.long)

                    with torch.no_grad():
                        logits, _ = self.model(input_tensor)

                    next_token_logits = logits[0, -1, :] / temperature
                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1).item()

                    generated.append(next_token)

                generated_text = self.tokenizer.decode(generated[len(input_ids) :])
                return [generated_text]

            def get_info(self):
                """Get model information."""
                return {
                    "model_name": "openllm-small-test",
                    "model_size": "small",
                    # Rough placeholder, not an exact parameter count.
                    "parameters": config.n_embd * config.n_layer * 1000,
                    "vocab_size": config.vocab_size,
                    "max_length": config.block_size,
                    "format": "pytorch",
                    "loaded_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.loaded_at)),
                }

            def get_health(self):
                """Get health status."""
                return {
                    "status": "healthy",
                    "model_loaded": True,
                    "uptime_seconds": time.time() - self.loaded_at,
                    "total_requests": self.total_requests,
                }

        return LightweightInferenceEngine()

    except Exception as e:
        print(f"⚠️ Failed to create lightweight model: {e}")

        class SimpleMockInferenceEngine:
            """Last-resort mock used when even the lightweight model cannot be built."""

            def __init__(self):
                self.model = "simple_mock"
                self.tokenizer = "simple_mock"
                self.config = {"model_name": "fallback-model"}
                self.detected_format = "pytorch"
                self.device = "cpu"
                self.loaded_at = time.time()
                self.total_requests = 0

            def generate(self, prompt, **kwargs):
                self.total_requests += 1
                return [f"Generated: {prompt[:10]}..."]

            def get_info(self):
                return {
                    "model_name": "fallback-model",
                    "model_size": "small",
                    "parameters": 1000,
                    "vocab_size": 1000,
                    "max_length": 100,
                    "format": "pytorch",
                    "loaded_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.loaded_at)),
                }

            def get_health(self):
                return {
                    "status": "healthy",
                    "model_loaded": True,
                    "uptime_seconds": time.time() - self.loaded_at,
                    "total_requests": self.total_requests,
                }

        return SimpleMockInferenceEngine()


if __name__ == "__main__":
    main()