# scriptify-api/main.py
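"""FastAPI service that generates handwriting strokes from text with a
TorchScript handwriting-synthesis model, loaded at startup from
./packaged_models."""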
from typing import Optional
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import torch
import torch.nn.functional as F
from pathlib import Path
import logging
import time
from contextlib import asynccontextmanager
from inference_utils import (
    PrimingData,
    construct_alphabet_list,
    convert_offsets_to_absolute_coords,
    encode_text,
    get_alphabet_map,
    load_priming_data,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
MODEL_DIR = Path("./packaged_models")
QUANTIZED_MODEL_NAME = "model.scripted.quantized.pt"
SCRIPTED_MODEL_NAME = "model.scripted.pt"
METADATA_MODEL_NAME = "model.pt"
SCRIPTED_MODEL: Optional[torch.jit.ScriptModule] = None
MODEL_METADATA: Optional[dict] = None
DEVICE: Optional[torch.device] = None
ALPHABET_MAP: Optional[dict[str, int]] = None
ALPHABET_LIST: Optional[list[str]] = None
ALPHABET_SIZE: Optional[int] = None
MAX_TEXT_LEN: Optional[int] = None
# Model hyperparameters, populated from the metadata file at startup.
output_mixture_components: Optional[int] = None  # number of GMM components for output sampling
lstm_size: Optional[int] = None  # hidden size of the model's LSTM layers
attention_mixture_components: Optional[int] = None  # number of attention window components

# Early-stopping parameters for stroke generation. Not referenced in this module;
# kept for a Python-side sampling loop.
PATIENCE_PEN_UP_EOS = 15
MIN_MOVEMENT_THRESHOLD = 0.02
class HandwritingRequest(BaseModel):
text: str = Field(..., min_length=1, max_length=40, description="Text to generate handwriting for")
max_length: int = Field(default=1000, ge=50, le=1200, description="Maximum number of stroke points")
bias: float = Field(default=0.75, ge=0.1, le=10.0, description="Sampling bias for generation")
class HandwritingResponse(BaseModel):
success: bool = True
input_text: str
generation_time_ms: float
num_points: int
strokes: list[list[float]]
message: str = "Successfully generated handwriting."
class HealthResponse(BaseModel):
status: str
model_loaded: bool
device: str
model_metadata_keys: Optional[list[str]] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events"""
global SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST, output_mixture_components, lstm_size, attention_mixture_components, ALPHABET_SIZE
logger.info("Attempting to load model resources during startup")
try:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {DEVICE}")
scripted_model_path = MODEL_DIR / SCRIPTED_MODEL_NAME
metadata_model_path = MODEL_DIR / METADATA_MODEL_NAME
# if DEVICE.type == "cpu":
# scripted_model_path = MODEL_DIR / QUANTIZED_MODEL_NAME
if not scripted_model_path.exists():
logger.error(f"Traced model not found at {scripted_model_path}")
raise FileNotFoundError(f"Traced model not found at {scripted_model_path}")
        if not metadata_model_path.exists():
logger.error(f"Metadata model file not found at {metadata_model_path}")
raise FileNotFoundError(f"Metadata model file not found at {metadata_model_path}")
# Load the traced model
SCRIPTED_MODEL = torch.jit.load(scripted_model_path, map_location=DEVICE)
if SCRIPTED_MODEL:
SCRIPTED_MODEL.eval()
logger.info(f"Traced model loaded successfully from {scripted_model_path}")
# Load the metadata
MODEL_METADATA = torch.load(metadata_model_path, map_location='cpu')
if MODEL_METADATA:
logger.info(f"Model metadata loaded successfully from {metadata_model_path}")
logger.info(f"Model metadata keys: {list(MODEL_METADATA.keys())}")
            config_full = MODEL_METADATA.get('config_full')
            if not config_full or not isinstance(config_full, dict):
                raise ValueError("Key `config_full` not found or not a dict")
            dataset_config = config_full.get('dataset')
            if not dataset_config or not isinstance(dataset_config, dict):
                raise ValueError("Key `dataset` not found or not a dict in config_full")
            model_params = config_full.get('model_params')
            if not model_params or not isinstance(model_params, dict):
                raise ValueError("Key `model_params` not found or not a dict in config_full")
alphabet_str = dataset_config['alphabet_string']
MAX_TEXT_LEN = dataset_config['max_text_len']
output_mixture_components = model_params['output_mixture_components']
lstm_size = model_params['lstm_size']
attention_mixture_components = model_params['attention_mixture_components']
ALPHABET_LIST = construct_alphabet_list(alphabet_str)
ALPHABET_SIZE = len(ALPHABET_LIST)
ALPHABET_MAP = get_alphabet_map(ALPHABET_LIST)
logger.info(f"Alphabet created. Size: {len(ALPHABET_LIST)}")
logger.info("Model resources are loaded and ready")
else:
raise ValueError(f"Failed to load content frm metadata file")
except Exception as e:
logger.error(f"Error loading model resources: {e}", exc_info=True)
SCRIPTED_MODEL = None
MODEL_METADATA = None
raise
yield
# Cleanup on shutdown
logger.info("Shutting down API and cleaning up resources")
SCRIPTED_MODEL = None
MODEL_METADATA = None
app = FastAPI(
title="Scriptify API",
description="API to generate handwriting from text using a PyTorch model.",
version="0.1.0",
lifespan=lifespan
)
# Allow the local dev frontend (Vite on port 5173) to call the API.
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:5173","http://127.0.0.1:5173"],
allow_credentials=True,
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
@app.get("/", tags=["General"])
async def read_root():
return {"message": "Welcome to the Scriptify Handwriting Generation API!"}
@app.get("/health", response_model=HealthResponse, tags=["General"])
async def health_check():
is_healthy = all([SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST])
return HealthResponse(
status="healthy" if is_healthy else "unhealthy",
model_loaded=bool(SCRIPTED_MODEL),
device=str(DEVICE) if DEVICE else "unknown",
model_metadata_keys=list(MODEL_METADATA.keys()) if MODEL_METADATA else None,
)
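# Illustrative healthy response (metadata keys depend on how the model was packaged):
#   {"status": "healthy", "model_loaded": true, "device": "cpu",
#    "model_metadata_keys": ["config_full", "..."]}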
def text_to_tensor(text: str, max_text_length: int, add_eos: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert text to tensor format expected by the model"""
if ALPHABET_MAP is None:
raise ValueError("Alphabet map not initialized during api startup")
padded_encoded_np, true_length = encode_text(
text=text,
char_to_index_map=ALPHABET_MAP,
max_length=max_text_length,
        add_eos=add_eos,
)
char_seq = torch.from_numpy(padded_encoded_np).to(device=DEVICE, dtype=torch.long)
char_len = torch.tensor([true_length], device=DEVICE, dtype=torch.long)
return char_seq, char_len
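# Example (mirrors the call in the /generate endpoint):
#   char_seq, char_len = text_to_tensor("hello", max_text_length=MAX_TEXT_LEN)
# char_seq holds the padded character ids; char_len holds the unpadded length.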
def generate_strokes(
char_seq: torch.Tensor,
char_lengths: torch.Tensor,
max_gen_len: int,
api_bias: float,
style: Optional[int] = None
) -> list[list[float]]:
"""Generate strokes using the model's built-in sample method"""
if SCRIPTED_MODEL is None:
raise ValueError("Scripted model not initialized.")
    priming_data = None
    if style is not None:
        # Optional style priming: condition the model on a reference text/stroke pair.
        priming_text, priming_strokes = load_priming_data(style)
        priming_text_tensor, priming_text_len_tensor = text_to_tensor(
            priming_text, max_text_length=len(priming_text), add_eos=False)
        priming_stroke_tensor = torch.tensor(priming_strokes,
                                             dtype=torch.float32,
                                             device=DEVICE).unsqueeze(dim=0)
        priming_data = PrimingData(priming_stroke_tensor,
                                   char_seq_tensors=priming_text_tensor,
                                   char_seq_lengths=priming_text_len_tensor)
with torch.inference_mode():
try:
            stroke_tensors = SCRIPTED_MODEL.sample(
                char_seq,
                char_lengths,
                max_length=max_gen_len,
                bias=api_bias,
                prime=priming_data,
            )
            # We always pass a single sequence, so expect exactly one tensor back.
            if len(stroke_tensors) == 1:
                stroke_offsets = stroke_tensors[0].cpu().numpy().tolist()
            else:
                stroke_offsets = []
                logger.warning(f"Expected a single-item batch, got {len(stroke_tensors)}")
return stroke_offsets
except Exception as e:
logger.error(f"Error in model sampling: {e}", exc_info=True)
return []
@app.post("/generate", response_model=HandwritingResponse, tags=["Generation"])
async def generate_handwriting_endpoint(request: HandwritingRequest):
if not all([SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN]):
logger.error("API not fully initialized. Check /health endpoint.")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Model or required resources not loaded."
)
assert DEVICE is not None, "Device is None inside generate_handwriting"
start_time = time.time()
try:
char_seq_tensor, char_lengths_tensor = text_to_tensor(request.text, max_text_length=MAX_TEXT_LEN) # type: ignore
relative_stroke_offsets = generate_strokes(
char_seq_tensor, char_lengths_tensor,
request.max_length,
request.bias,
            # style=1  # TODO: style priming is disabled since the current version is hosted on CPU
)
if not relative_stroke_offsets:
return HandwritingResponse(
success=False,
input_text=request.text,
strokes=[],
num_points=0,
generation_time_ms=(time.time() - start_time) * 1000,
message="No strokes generated."
)
absolute_stroke_coords = convert_offsets_to_absolute_coords(relative_stroke_offsets)
generation_time_ms = (time.time() - start_time) * 1000
return HandwritingResponse(
input_text=request.text,
strokes=absolute_stroke_coords,
num_points=len(absolute_stroke_coords),
generation_time_ms=generation_time_ms
)
except ValueError as ve:
logger.error(f"ValueError during generation for '{request.text}': {ve}", exc_info=True)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve))
except Exception as e:
logger.error(f"Unexpected error for '{request.text}': {e}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
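# Example request (illustrative values; max_length and bias fall back to the
# Field defaults when omitted):
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "hello world", "bias": 0.75}'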
if __name__ == "__main__":
import uvicorn
logger.info("Starting Uvicorn server for Scriptify API...")
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True, app_dir=".")