File size: 11,375 Bytes
a9beef1
 
 
 
 
 
 
 
 
 
a55bf24
a9beef1
 
 
 
70e9cbc
44bb151
70e9cbc
 
a9beef1
a55bf24
 
 
 
a9beef1
 
a55bf24
a9beef1
 
 
 
 
 
 
 
 
 
 
fea5d40
 
a9beef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a55bf24
a9beef1
 
a55bf24
 
a9beef1
 
 
91f531c
 
a9beef1
 
 
 
 
 
 
 
 
a55bf24
 
 
a9beef1
 
 
a55bf24
 
a9beef1
a55bf24
a9beef1
a55bf24
a9beef1
 
 
 
 
 
 
 
 
a55bf24
a9beef1
 
 
 
 
 
 
a55bf24
a9beef1
 
 
 
 
 
 
 
a55bf24
 
a9beef1
 
 
 
 
 
a55bf24
 
a9beef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a55bf24
a9beef1
a55bf24
a9beef1
 
 
a55bf24
 
 
a9beef1
 
a55bf24
a9beef1
a55bf24
a9beef1
a55bf24
a9beef1
 
a55bf24
 
 
a9beef1
 
a55bf24
 
a9beef1
 
 
 
 
 
 
 
a55bf24
a9beef1
 
a55bf24
 
a9beef1
 
a55bf24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99b28cc
a9beef1
a55bf24
a9beef1
 
 
a55bf24
 
a9beef1
 
f5121bb
 
a9beef1
 
 
 
f5121bb
a9beef1
 
 
 
 
 
 
 
a55bf24
a9beef1
 
 
 
 
 
a55bf24
a9beef1
 
 
a55bf24
a9beef1
 
a55bf24
 
 
26ec32c
a9beef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from inference_utils import PrimingData, construct_alphabet_list, convert_offsets_to_absolute_coords, encode_text, get_alphabet_map, load_priming_data

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory and file names of the packaged model artifacts loaded by `lifespan`.
MODEL_DIR = Path("./packaged_models")
# NOTE(review): only referenced from commented-out code in `lifespan` (CPU quantized variant).
QUANTIZED_MODEL_NAME = "model.scripted.quantized.pt"
SCRIPTED_MODEL_NAME = "model.scripted.pt"  # TorchScript model, loaded with torch.jit.load
METADATA_MODEL_NAME = "model.pt"  # checkpoint loaded with torch.load; must hold a `config_full` dict

# Module-level state assigned by `lifespan` at startup; each value is None until
# `lifespan` sets it.
SCRIPTED_MODEL: Optional[torch.jit.ScriptModule] = None  # reset to None on load failure and at shutdown
MODEL_METADATA: Optional[dict] = None  # contents of METADATA_MODEL_NAME, loaded onto the CPU
DEVICE: Optional[torch.device] = None  # "cuda" when available, otherwise "cpu"
ALPHABET_MAP: Optional[dict[str, int]] = None  # character -> index map consumed by text_to_tensor
ALPHABET_LIST: Optional[list[str]] = None  # built from the dataset `alphabet_string`
ALPHABET_SIZE: Optional[int] = None  # len(ALPHABET_LIST)
MAX_TEXT_LEN: Optional[int] = None  # dataset `max_text_len`; passed as max_text_length for request text
# Hyper-parameters copied from `config_full['model_params']`. They are not read
# elsewhere in this file. NOTE(review): presumably kept for reference; confirm.
output_mixture_components: Optional[int] = None  # presumably the GMM mixture count used when sampling
lstm_size: Optional[int] = None
attention_mixture_components: Optional[int] = None

# NOTE(review): neither constant below is referenced by generate_strokes (or
# anything else) in this file, despite the earlier "early stopping" note. They
# may be used by the model code or be leftovers; confirm before relying on them.
PATIENCE_PEN_UP_EOS = 15
MIN_MOVEMENT_THRESHOLD = 0.02


class HandwritingRequest(BaseModel):
    # Request body for POST /generate. (Kept as `#` comments: a class docstring
    # would become the model's description in the generated OpenAPI schema.)
    # NOTE(review): the 40-char cap on `text` is independent of MAX_TEXT_LEN read
    # from the model metadata; confirm the two agree.
    text: str = Field(..., min_length=1, max_length=40, description="Text to generate handwriting for")
    # Forwarded to SCRIPTED_MODEL.sample(max_length=...) via generate_strokes.
    max_length: int = Field(default=1000, ge=50, le=1200, description="Maximum number of stroke points")
    # Forwarded to SCRIPTED_MODEL.sample(bias=...) via generate_strokes.
    bias: float = Field(default=0.75, ge=0.1, le=10.0, description="Sampling bias for generation")
# Response body for POST /generate.
class HandwritingResponse(BaseModel):
    # False (with empty `strokes`) when sampling produced no points.
    success: bool = True
    input_text: str
    # Wall-clock milliseconds measured by the endpoint, including encoding and sampling.
    generation_time_ms: float
    # Equal to len(strokes).
    num_points: int
    # Absolute coordinates produced by convert_offsets_to_absolute_coords.
    strokes: list[list[float]]
    message: str = "Successfully generated handwriting."

class HealthResponse(BaseModel):
    # Response body for GET /health.
    status: str  # "healthy" or "unhealthy"
    model_loaded: bool
    device: str  # str(DEVICE), or "unknown" before startup has selected one
    model_metadata_keys: Optional[list[str]] = None  # None while metadata is not loaded

def _require_dict(container: dict, key: str, where: str) -> dict:
    """Return ``container[key]`` when it is a non-empty dict.

    Raises:
        ValueError: if ``key`` is missing, falsy, or not a dict. The message
            names the key and ``where`` the lookup was performed.
    """
    value = container.get(key)
    if not value or not isinstance(value, dict):
        raise ValueError(f"Key `{key}` not found or not a dict in {where}")
    return value


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events.

    Startup (before ``yield``):
      * selects ``cuda`` when available, otherwise ``cpu``, as ``DEVICE``;
      * loads the TorchScript model from ``MODEL_DIR / SCRIPTED_MODEL_NAME``
        onto ``DEVICE`` and switches it to eval mode;
      * loads the metadata checkpoint from ``MODEL_DIR / METADATA_MODEL_NAME``
        onto the CPU, then reads ``config_full['dataset']`` and
        ``config_full['model_params']`` to populate the alphabet, ``MAX_TEXT_LEN``
        and model hyper-parameter globals.

    On any startup failure, the error is logged, ``SCRIPTED_MODEL`` and
    ``MODEL_METADATA`` are reset to ``None``, and the exception is re-raised so
    the application refuses to start.

    Shutdown (after ``yield``) drops the references to the model and its
    metadata. It runs in a ``finally`` so it also happens when the application
    exits with an error.

    Raises:
        FileNotFoundError: if either model artifact is missing.
        ValueError: if the metadata is empty, is not a dict, or lacks a required section.
        KeyError: if a required key is missing inside ``dataset`` or ``model_params``.
    """
    global SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST, output_mixture_components, lstm_size, attention_mixture_components, ALPHABET_SIZE
    logger.info("Attempting to load model resources during startup")
    try:
        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {DEVICE}")

        scripted_model_path = MODEL_DIR / SCRIPTED_MODEL_NAME
        metadata_model_path = MODEL_DIR / METADATA_MODEL_NAME
        # if DEVICE.type == "cpu":
        #     scripted_model_path = MODEL_DIR / QUANTIZED_MODEL_NAME

        if not scripted_model_path.exists():
            logger.error(f"Traced model not found at {scripted_model_path}")
            raise FileNotFoundError(f"Traced model not found at {scripted_model_path}")
        if not metadata_model_path.exists():
            logger.error(f"Metadata model file not found at {metadata_model_path}")
            raise FileNotFoundError(f"Metadata model file not found at {metadata_model_path}")

        # torch.jit.load either returns a module or raises, so no None check is needed.
        SCRIPTED_MODEL = torch.jit.load(scripted_model_path, map_location=DEVICE)
        SCRIPTED_MODEL.eval()
        logger.info(f"Traced model loaded successfully from {scripted_model_path}")

        # The metadata is only read for configuration, so keep it on the CPU.
        MODEL_METADATA = torch.load(metadata_model_path, map_location='cpu')
        if not MODEL_METADATA or not isinstance(MODEL_METADATA, dict):
            raise ValueError("Failed to load content from metadata file")
        logger.info(f"Model metadata loaded successfully from {metadata_model_path}")
        logger.info(f"Model metadata keys: {list(MODEL_METADATA.keys())}")

        # Validate each nested section explicitly so a malformed checkpoint
        # produces a descriptive ValueError instead of a bare KeyError.
        config_full = _require_dict(MODEL_METADATA, "config_full", "model metadata")
        dataset_config = _require_dict(config_full, "dataset", "config_full")
        model_params = _require_dict(config_full, "model_params", "config_full")

        alphabet_str = dataset_config['alphabet_string']
        MAX_TEXT_LEN = dataset_config['max_text_len']

        output_mixture_components = model_params['output_mixture_components']
        lstm_size = model_params['lstm_size']
        attention_mixture_components = model_params['attention_mixture_components']

        ALPHABET_LIST = construct_alphabet_list(alphabet_str)
        ALPHABET_SIZE = len(ALPHABET_LIST)
        ALPHABET_MAP = get_alphabet_map(ALPHABET_LIST)

        logger.info(f"Alphabet created. Size: {ALPHABET_SIZE}")
        logger.info("Model resources are loaded and ready")

    except Exception as e:
        logger.error(f"Error loading model resources: {e}", exc_info=True)
        SCRIPTED_MODEL = None
        MODEL_METADATA = None
        raise

    try:
        yield
    finally:
        # Cleanup on shutdown, including when the app exits because of an error.
        logger.info("Shutting down API and cleaning up resources")
        SCRIPTED_MODEL = None
        MODEL_METADATA = None

# The ASGI application; model loading and unloading is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="Scriptify API",
    version="0.1.0",
    description="API to generate handwriting from text using a PyTorch model.",
)

# Browser origins allowed to call the API. Both point at a local frontend on
# port 5173 (presumably a development server; confirm before deploying).
_ALLOWED_FRONTEND_ORIGINS = ["http://localhost:5173", "http://127.0.0.1:5173"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_FRONTEND_ORIGINS,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
    allow_credentials=True,
)

@app.get("/", tags=["General"])
async def read_root():
    # Static greeting for the API root. Kept as a comment rather than a docstring
    # so the route's OpenAPI description is unchanged.
    greeting = "Welcome to the Scriptify Handwriting Generation API!"
    return {"message": greeting}

@app.get("/health", response_model=HealthResponse, tags=["General"])
async def health_check():
    # Readiness report. The service is "healthy" only when every resource that
    # `lifespan` populates is truthy. The globals are only read here, so no
    # `global` declaration is needed.
    startup_resources = (SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST)
    ready = all(startup_resources)

    if ready:
        health_label = "healthy"
    else:
        health_label = "unhealthy"

    device_label = str(DEVICE) if DEVICE else "unknown"
    metadata_keys = list(MODEL_METADATA.keys()) if MODEL_METADATA else None

    return HealthResponse(
        status=health_label,
        model_loaded=bool(SCRIPTED_MODEL),
        device=device_label,
        model_metadata_keys=metadata_keys,
    )

def text_to_tensor(text: str, max_text_length: int, add_eos: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode ``text`` with the loaded alphabet and place the result on ``DEVICE``.

    Args:
        text: String to encode.
        max_text_length: Forwarded to ``encode_text`` as ``max_length``.
        add_eos: Forwarded to ``encode_text`` unchanged.

    Returns:
        ``(char_seq, char_len)``: the padded encoding from ``encode_text`` as a
        ``torch.long`` tensor, and a one-element ``torch.long`` tensor holding
        the true (unpadded) length it reported.

    Raises:
        ValueError: if ``ALPHABET_MAP`` was not initialised during startup.
    """
    alphabet_index = ALPHABET_MAP
    if alphabet_index is None:
        raise ValueError("Alphabet map not initialized during api startup")

    encoded_array, unpadded_length = encode_text(
        text=text,
        char_to_index_map=alphabet_index,
        max_length=max_text_length,
        add_eos=add_eos,
    )

    sequence = torch.from_numpy(encoded_array).to(device=DEVICE, dtype=torch.long)
    lengths = torch.tensor([unpadded_length], device=DEVICE, dtype=torch.long)
    return sequence, lengths

def _build_priming_data(style: int) -> "PrimingData":
    """Load the priming sample for ``style`` and wrap it for the model's ``prime`` argument.

    The priming text is encoded at its own length with ``add_eos=False``. The
    priming strokes become a float32 tensor on ``DEVICE`` with a leading
    dimension of size 1 added.
    """
    prime_text, prime_strokes = load_priming_data(style)

    prime_chars, prime_char_len = text_to_tensor(
        prime_text, max_text_length=len(prime_text), add_eos=False
    )
    prime_stroke_batch = torch.tensor(
        prime_strokes, dtype=torch.float32, device=DEVICE
    ).unsqueeze(0)

    return PrimingData(
        prime_stroke_batch,
        char_seq_tensors=prime_chars,
        char_seq_lengths=prime_char_len,
    )


def generate_strokes(
    char_seq: torch.Tensor,
    char_lengths: torch.Tensor,
    max_gen_len: int,
    api_bias: float,
    style: Optional[int] = None
) -> list[list[float]]:
    """Sample stroke offsets from the loaded TorchScript model.

    Args:
        char_seq: Encoded text, as produced by ``text_to_tensor``.
        char_lengths: Matching length tensor from ``text_to_tensor``.
        max_gen_len: Passed to ``SCRIPTED_MODEL.sample`` as ``max_length``.
        api_bias: Passed to ``SCRIPTED_MODEL.sample`` as ``bias``.
        style: Optional priming-style id. When given, the priming sample is
            loaded and passed as ``prime``.

    Returns:
        The first (and only expected) batch element converted to nested Python
        lists. Returns ``[]`` when the model yields a batch whose length is not
        1, or when sampling raises (the error is logged).

    Raises:
        ValueError: if the scripted model is not initialised. Errors raised while
            preparing priming data propagate.
    """
    # Read-only access to the module global, so no `global` statement is needed.
    if SCRIPTED_MODEL is None:
        raise ValueError("Scripted model not initialized.")

    prime = None if style is None else _build_priming_data(style)

    with torch.inference_mode():
        try:
            sampled_batch = SCRIPTED_MODEL.sample(
                char_seq,
                char_lengths,
                max_length=max_gen_len,
                bias=api_bias,
                prime=prime,
            )

            # Requests are encoded one at a time, so exactly one element is expected.
            if len(sampled_batch) != 1:
                logger.warning(f"Expected single batch, but got {len(sampled_batch)}")
                return []
            return sampled_batch[0].cpu().numpy().tolist()

        except Exception as e:
            # Best-effort: the caller treats an empty result as "no strokes generated".
            logger.error(f"Error in model sampling: {e}", exc_info=True)
            return []
            
@app.post("/generate", response_model=HandwritingResponse, tags=["Generation"])
async def generate_handwriting_endpoint(request: HandwritingRequest):
    """Generate handwriting stroke coordinates for the requested text.

    Returns absolute stroke coordinates. When the model produces no strokes,
    ``success`` is false and ``strokes`` is empty.

    Errors: 503 if the model is not loaded, 400 for invalid input reported as a
    ``ValueError`` during encoding or generation, and 500 for anything else.
    """
    # The readiness check covers DEVICE too, so no separate assert is needed
    # (asserts are stripped under `python -O` anyway).
    if not all([SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN]):
        logger.error("API not fully initialized. Check /health endpoint.")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model or required resources not loaded."
        )

    # perf_counter is monotonic, so the measured duration is immune to wall-clock jumps.
    start_time = time.perf_counter()

    try:
        char_seq_tensor, char_lengths_tensor = text_to_tensor(request.text, max_text_length=MAX_TEXT_LEN)  # type: ignore[arg-type]

        # Model sampling is synchronous and compute-bound. Run it in a worker thread
        # so the event loop (and /health) stays responsive during generation.
        # NOTE(review): concurrent requests now sample in parallel threads; add a
        # lock if SCRIPTED_MODEL.sample turns out not to be thread-safe.
        # TODO: pass style=<id> to enable priming (disabled while hosted on CPU).
        relative_stroke_offsets = await asyncio.to_thread(
            generate_strokes,
            char_seq_tensor,
            char_lengths_tensor,
            request.max_length,
            request.bias,
        )

        if not relative_stroke_offsets:
            return HandwritingResponse(
                success=False,
                input_text=request.text,
                strokes=[],
                num_points=0,
                generation_time_ms=(time.perf_counter() - start_time) * 1000,
                message="No strokes generated."
            )

        absolute_stroke_coords = convert_offsets_to_absolute_coords(relative_stroke_offsets)
        generation_time_ms = (time.perf_counter() - start_time) * 1000

        return HandwritingResponse(
            input_text=request.text,
            strokes=absolute_stroke_coords,
            num_points=len(absolute_stroke_coords),
            generation_time_ms=generation_time_ms
        )
    except ValueError as ve:
        logger.error(f"ValueError during generation for '{request.text}': {ve}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) from ve
    except Exception as e:
        logger.error(f"Unexpected error for '{request.text}': {e}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.") from e

if __name__ == "__main__":
    # Local entry point: serve `main:app` on all interfaces, port 8000, with auto-reload.
    import uvicorn

    logger.info("Starting Uvicorn server for Scriptify API...")
    uvicorn.run(
        "main:app",
        app_dir=".",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )