Spaces:
Running
Running
from typing import Optional | |
from fastapi import FastAPI, HTTPException, status | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel, Field | |
import torch | |
import torch.nn.functional as F | |
from pathlib import Path | |
import logging | |
import time | |
from contextlib import asynccontextmanager | |
from inference_utils import PrimingData, construct_alphabet_list, convert_offsets_to_absolute_coords, encode_text, get_alphabet_map, load_priming_data | |
# Configure root logging once at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Packaged model artifacts -------------------------------------------------
MODEL_DIR = Path("packaged_models")
# Not referenced by active code in this file (only by a disabled CPU switch).
QUANTIZED_MODEL_NAME = "model.scripted.quantized.pt"
# TorchScript module loaded with torch.jit.load; its `.sample(...)` is called.
SCRIPTED_MODEL_NAME = "model.scripted.pt"
# torch.load-able metadata; expected to contain a `config_full` dict.
METADATA_MODEL_NAME = "model.pt"

# --- Process-wide state, filled in by `lifespan` at startup (None until then) --
SCRIPTED_MODEL: Optional[torch.jit.ScriptModule] = None
MODEL_METADATA: Optional[dict] = None
DEVICE: Optional[torch.device] = None
ALPHABET_MAP: Optional[dict[str, int]] = None
ALPHABET_LIST: Optional[list[str]] = None
ALPHABET_SIZE: Optional[int] = None
MAX_TEXT_LEN: Optional[int] = None
# Hyper-parameters copied from the metadata's `model_params` section.
output_mixture_components: Optional[int] = None  # number of GMM mixtures used for sampling
lstm_size: Optional[int] = None
attention_mixture_components: Optional[int] = None

# Early-stopping knobs intended for stroke generation.
# NOTE(review): neither constant is referenced by the visible generate_strokes;
# confirm whether they are still used anywhere.
PATIENCE_PEN_UP_EOS = 15
MIN_MOVEMENT_THRESHOLD = 0.02
# Request body for handwriting generation; bounds are pydantic field constraints.
# NOTE(review): the 40-character cap is independent of MAX_TEXT_LEN read from
# the model metadata at startup; confirm it never exceeds that value.
class HandwritingRequest(BaseModel):
    # Text to render (1..40 characters).
    text: str = Field(
        ...,
        description="Text to generate handwriting for",
        min_length=1,
        max_length=40,
    )
    # Upper bound on generated stroke points (50..1200).
    max_length: int = Field(
        default=1000,
        description="Maximum number of stroke points",
        ge=50,
        le=1200,
    )
    # Sampling bias forwarded to the model sampler (0.1..10.0).
    bias: float = Field(
        default=0.75,
        description="Sampling bias for generation",
        ge=0.1,
        le=10.0,
    )
# Payload returned by the generation endpoint (field order is the JSON order).
class HandwritingResponse(BaseModel):
    # False when the model produced no strokes.
    success: bool = Field(default=True)
    # Echo of the requested text.
    input_text: str
    # Elapsed time measured by the endpoint, in milliseconds.
    generation_time_ms: float
    # Number of rows in `strokes`.
    num_points: int
    # Absolute stroke coordinates; per-row layout is defined by
    # convert_offsets_to_absolute_coords (not visible here).
    strokes: list[list[float]]
    message: str = Field(default="Successfully generated handwriting.")
# Payload of the health-check endpoint.
class HealthResponse(BaseModel):
    # "healthy" or "unhealthy".
    status: str
    # Whether the TorchScript model object is currently loaded.
    model_loaded: bool
    # torch device string, or "unknown" before startup has set it.
    device: str
    # Top-level keys of the loaded metadata dict, if any.
    model_metadata_keys: Optional[list[str]] = Field(default=None)
def _require_dict(container: dict, key: str, location: str = "") -> dict:
    """Return ``container[key]`` when it is a non-empty dict.

    Uses ``.get`` so a missing key reports the intended ValueError instead of
    a bare KeyError.

    Raises:
        ValueError: if the key is missing, empty, or not a dict. ``location``
            is appended to the message to say where the key was looked up.
    """
    value = container.get(key)
    if not value or not isinstance(value, dict):
        raise ValueError(f"Key `{key}` not found or not a dict{location}")
    return value


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events.

    Startup selects the torch device, loads the TorchScript model and the
    metadata file, and populates the module-level alphabet / hyper-parameter
    globals from ``metadata["config_full"]``. Any failure is logged, the model
    globals are cleared, and the exception is re-raised so the app does not
    start half-initialised. Shutdown clears the model globals.

    Raises:
        FileNotFoundError: a model artifact is missing.
        ValueError: the metadata is empty or lacks the expected config dicts.
    """
    global SCRIPTED_MODEL, MODEL_METADATA, DEVICE
    global ALPHABET_MAP, ALPHABET_LIST, ALPHABET_SIZE, MAX_TEXT_LEN
    global output_mixture_components, lstm_size, attention_mixture_components
    logger.info("Attempting to load model resources during startup")
    try:
        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {DEVICE}")
        scripted_model_path = MODEL_DIR / SCRIPTED_MODEL_NAME
        metadata_model_path = MODEL_DIR / METADATA_MODEL_NAME
        # NOTE: a quantized variant (QUANTIZED_MODEL_NAME) exists, but the switch
        # that selected it on CPU was disabled; the full-precision model is used.

        if not scripted_model_path.exists():
            logger.error(f"Traced model not found at {scripted_model_path}")
            raise FileNotFoundError(f"Traced model not found at {scripted_model_path}")
        if not metadata_model_path.exists():
            logger.error(f"Metadata model file not found at {metadata_model_path}")
            raise FileNotFoundError(f"Metadata model file not found at {metadata_model_path}")

        # torch.jit.load raises on failure, so no truthiness check is needed.
        SCRIPTED_MODEL = torch.jit.load(scripted_model_path, map_location=DEVICE)
        SCRIPTED_MODEL.eval()
        logger.info(f"Traced model loaded successfully from {scripted_model_path}")

        # torch.load unpickles the file; it is a locally packaged artifact,
        # never user-supplied input.
        MODEL_METADATA = torch.load(metadata_model_path, map_location='cpu')
        if not MODEL_METADATA or not isinstance(MODEL_METADATA, dict):
            raise ValueError("Failed to load content from metadata file")
        logger.info(f"Model metadata loaded successfully from {metadata_model_path}")
        logger.info(f"Model metadata keys: {list(MODEL_METADATA.keys())}")

        config_full = _require_dict(MODEL_METADATA, "config_full")
        dataset_config = _require_dict(config_full, "dataset", " in config_full")
        model_params = _require_dict(config_full, "model_params", " in config_full")

        alphabet_str = dataset_config['alphabet_string']
        MAX_TEXT_LEN = dataset_config['max_text_len']
        output_mixture_components = model_params['output_mixture_components']
        lstm_size = model_params['lstm_size']
        attention_mixture_components = model_params['attention_mixture_components']

        ALPHABET_LIST = construct_alphabet_list(alphabet_str)
        ALPHABET_SIZE = len(ALPHABET_LIST)
        ALPHABET_MAP = get_alphabet_map(ALPHABET_LIST)
        logger.info(f"Alphabet created. Size: {ALPHABET_SIZE}")
        logger.info("Model resources are loaded and ready")
    except Exception as e:
        logger.error(f"Error loading model resources: {e}", exc_info=True)
        SCRIPTED_MODEL = None
        MODEL_METADATA = None
        raise

    try:
        yield
    finally:
        # Runs on normal shutdown and also when the app is torn down by an error.
        logger.info("Shutting down API and cleaning up resources")
        SCRIPTED_MODEL = None
        MODEL_METADATA = None
# Application instance; model loading/unloading is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="Scriptify API",
    version="0.1.0",
    description="API to generate handwriting from text using a PyTorch model.",
)

# CORS: browsers may call the API only from the two local front-end origins
# listed here (port 5173), with GET/POST and credentials allowed.
# NOTE(review): a front-end served from any other origin would be rejected.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://127.0.0.1:5173",
    ],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
    allow_credentials=True,
)
async def read_root():
    """Return a static welcome message for the API root."""
    # NOTE(review): no route decorator is visible on this handler in this file;
    # confirm it is registered with the app.
    greeting = "Welcome to the Scriptify Handwriting Generation API!"
    return {"message": greeting}
async def health_check():
    """Report whether startup loaded every model resource.

    ``status`` is "healthy" only when the model, metadata, device, alphabet
    map, max text length and alphabet list globals are all truthy.
    """
    # NOTE(review): no route decorator is visible on this handler in this file;
    # confirm it is registered with the app.
    required = (SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST)
    overall = "healthy" if all(required) else "unhealthy"
    device_label = str(DEVICE) if DEVICE else "unknown"
    metadata_keys = list(MODEL_METADATA.keys()) if MODEL_METADATA else None
    return HealthResponse(
        status=overall,
        model_loaded=bool(SCRIPTED_MODEL),
        device=device_label,
        model_metadata_keys=metadata_keys,
    )
def text_to_tensor(text: str, max_text_length: int, add_eos: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode ``text`` with the startup alphabet and place it on ``DEVICE``.

    Args:
        text: String to encode.
        max_text_length: Passed to ``encode_text`` as ``max_length``
            (presumably the padded length -- confirm in inference_utils).
        add_eos: Forwarded to ``encode_text``.

    Returns:
        ``(char_seq, char_len)``: the encoded array converted to a
        ``torch.long`` tensor, and a one-element ``torch.long`` tensor holding
        the length reported by ``encode_text``.

    Raises:
        ValueError: if the alphabet map has not been built at startup.
    """
    char_map = ALPHABET_MAP
    if char_map is None:
        raise ValueError("Alphabet map not initialized during api startup")
    encoded, unpadded_len = encode_text(
        text=text,
        char_to_index_map=char_map,
        max_length=max_text_length,
        add_eos=add_eos,
    )
    sequence = torch.from_numpy(encoded).to(device=DEVICE, dtype=torch.long)
    lengths = torch.tensor([unpadded_len], dtype=torch.long, device=DEVICE)
    return sequence, lengths
def generate_strokes(
    char_seq: torch.Tensor,
    char_lengths: torch.Tensor,
    max_gen_len: int,
    api_bias: float,
    style: Optional[int] = None
) -> list[list[float]]:
    """Sample stroke offsets via the TorchScript model's ``sample`` method.

    Args:
        char_seq: Encoded text, as returned by ``text_to_tensor``.
        char_lengths: Matching length tensor from ``text_to_tensor``.
        max_gen_len: Forwarded to ``sample`` as ``max_length``.
        api_bias: Forwarded to ``sample`` as ``bias``.
        style: Optional priming style id. When given, ``load_priming_data``
            supplies text and strokes that are wrapped in ``PrimingData`` and
            passed as ``prime``.

    Returns:
        The single result's strokes as nested Python lists. Returns ``[]`` if
        the model returns a number of results other than one, or if sampling
        or conversion raises; the error is logged, not propagated.

    Raises:
        ValueError: if the model has not been loaded.
    """
    model = SCRIPTED_MODEL
    if model is None:
        raise ValueError("Scripted model not initialized.")

    prime = None
    if style is not None:
        prime_text, prime_strokes = load_priming_data(style)
        prime_chars, prime_char_len = text_to_tensor(
            prime_text, max_text_length=len(prime_text), add_eos=False)
        # Add a leading batch dimension of size 1.
        prime_stroke_batch = torch.tensor(
            prime_strokes, dtype=torch.float32, device=DEVICE).unsqueeze(dim=0)
        prime = PrimingData(
            prime_stroke_batch,
            char_seq_tensors=prime_chars,
            char_seq_lengths=prime_char_len,
        )

    with torch.inference_mode():
        try:
            batch = model.sample(
                char_seq,
                char_lengths,
                max_length=max_gen_len,
                bias=api_bias,
                prime=prime,
            )
            # One text per call, so exactly one result is expected.
            if len(batch) != 1:
                logger.warning(f"Expected single batch, but got {len(batch)}")
                return []
            return batch[0].cpu().numpy().tolist()
        except Exception as e:
            logger.error(f"Error in model sampling: {e}", exc_info=True)
            return []
async def generate_handwriting_endpoint(request: HandwritingRequest):
    """Generate handwriting strokes for ``request.text``.

    Encodes the text, samples relative stroke offsets from the model, and
    converts them to absolute coordinates. If the model yields no strokes, a
    response with ``success=False`` and empty ``strokes`` is returned.

    Raises:
        HTTPException(503): startup did not load every required resource.
        HTTPException(400): encoding or generation raised ``ValueError``.
        HTTPException(500): any other unexpected error.
    """
    # NOTE(review): no route decorator is visible on this handler in this file;
    # confirm it is registered with the app (e.g. via @app.post(...)).
    # NOTE(review): sampling below is synchronous inside an async handler, so
    # it blocks the event loop (and other requests) while it runs.
    if not all([SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN]):
        logger.error("API not fully initialized. Check /health endpoint.")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model or required resources not loaded."
        )
    # perf_counter is monotonic, so system clock adjustments cannot skew
    # (or make negative) the reported duration the way time.time() can.
    start_time = time.perf_counter()
    try:
        char_seq_tensor, char_lengths_tensor = text_to_tensor(request.text, max_text_length=MAX_TEXT_LEN)  # type: ignore
        relative_stroke_offsets = generate_strokes(
            char_seq_tensor, char_lengths_tensor,
            request.max_length,
            request.bias,
            # TODO: priming (style=...) is disabled because the hosted version runs on CPU.
        )
        if not relative_stroke_offsets:
            return HandwritingResponse(
                success=False,
                input_text=request.text,
                strokes=[],
                num_points=0,
                generation_time_ms=(time.perf_counter() - start_time) * 1000,
                message="No strokes generated."
            )
        absolute_stroke_coords = convert_offsets_to_absolute_coords(relative_stroke_offsets)
        generation_time_ms = (time.perf_counter() - start_time) * 1000
        return HandwritingResponse(
            input_text=request.text,
            strokes=absolute_stroke_coords,
            num_points=len(absolute_stroke_coords),
            generation_time_ms=generation_time_ms
        )
    except ValueError as ve:
        logger.error(f"ValueError during generation for '{request.text}': {ve}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) from ve
    except Exception as e:
        logger.error(f"Unexpected error for '{request.text}': {e}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.") from e
if __name__ == "__main__":
    # Development entry point: serve `main:app` with uvicorn's auto-reload.
    import uvicorn

    logger.info("Starting Uvicorn server for Scriptify API...")
    uvicorn.run(
        "main:app",
        app_dir=".",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )