#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import base64

import torch

from handler_template import BaseHandler

# Add DiffSketcher to path
sys.path.append("/app/model")
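

# Expected request/response shape (inferred from the handler methods below):
#   request:  {"prompt": "<text description of the sketch>"}
#   response: {"svg_content": "<svg ...>", "svg_base64": "<base64-encoded SVG>"}
# On failure, postprocess() returns {"error": "<message>"} instead.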
class Handler(BaseHandler):
    def initialize(self):
        """Load the DiffSketcher text encoder and sketch generator."""
        try:
            from models.clip_text_encoder import CLIPTextEncoder
            from models.sketch_generator import SketchGenerator

            # Load the CLIP text encoder used to embed prompts
            self.text_encoder = CLIPTextEncoder()
            self.text_encoder.to(self.device)
            self.text_encoder.eval()

            # Load the sketch generator and its pretrained weights
            self.model = SketchGenerator()
            weights_path = os.path.join("/app/model/weights", "diffsketcher_model.pth")
            if os.path.exists(weights_path):
                state_dict = torch.load(weights_path, map_location=self.device)
                self.model.load_state_dict(state_dict)
            else:
                raise FileNotFoundError(f"Model weights not found at {weights_path}")
            self.model.to(self.device)
            self.model.eval()

            self.initialized = True
            print("DiffSketcher model initialized successfully")
        except Exception as e:
            print(f"Error initializing DiffSketcher model: {str(e)}")
            raise
    def preprocess(self, data):
        """Extract the prompt from the request and encode it with CLIP."""
        try:
            # Extract prompt from the request
            prompt = data.get("prompt", "")
            if not prompt:
                raise ValueError("No prompt provided in the request")

            # Encode text with CLIP
            with torch.no_grad():
                text_embedding = self.text_encoder.encode_text(prompt)

            return {
                "text_embedding": text_embedding,
                "prompt": prompt,
            }
        except Exception as e:
            print(f"Error in preprocessing: {str(e)}")
            raise
    def inference(self, inputs):
        """Generate SVG data from the text embedding."""
        try:
            text_embedding = inputs["text_embedding"]

            # Run inference
            with torch.no_grad():
                svg_data = self.model.generate(text_embedding)

            return svg_data
        except Exception as e:
            print(f"Error during inference: {str(e)}")
            raise
    def postprocess(self, inference_output):
        """Format the model output for the response."""
        try:
            # inference() is assumed to return either a dict keyed by
            # "svg_content" or the raw SVG string; normalize to the string
            if isinstance(inference_output, dict):
                svg_content = inference_output["svg_content"]
            else:
                svg_content = inference_output

            # Return both the raw SVG markup and a base64-encoded copy
            return {
                "svg_content": svg_content,
                "svg_base64": self.svg_to_base64(svg_content),
            }
        except Exception as e:
            print(f"Error in postprocessing: {str(e)}")
            return {"error": str(e)}
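
    def svg_to_base64(self, svg_content):
        """Base64-encode SVG markup for transport in a JSON response.

        BaseHandler may already provide this helper; this minimal fallback
        sketch assumes it does not, and that svg_content is a str of SVG
        markup.
        """
        return base64.b64encode(svg_content.encode("utf-8")).decode("ascii")


# Minimal local smoke test (a sketch, not the real entrypoint: in production
# the serving framework drives BaseHandler). Assumes Handler() can be
# constructed without arguments and that BaseHandler sets self.device.
if __name__ == "__main__":
    handler = Handler()
    handler.initialize()
    inputs = handler.preprocess({"prompt": "a sketch of a cat"})
    result = handler.postprocess(handler.inference(inputs))
    print(result.get("svg_content", result))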