|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import io |
|
|
import base64 |
|
|
from handler_template import BaseHandler |
|
|
|
|
|
|
|
|
# Make the project's model code importable: Handler.initialize does
# `from models.clip_text_encoder import ...` / `from models.sketch_generator import ...`,
# presumably resolved from this directory (TODO confirm layout of /app/model).
sys.path.extend(["/app/model"])
|
|
|
|
|
class Handler(BaseHandler): |
|
|
def initialize(self):
    """Build the text encoder and sketch generator and load the generator weights.

    Both modules are moved to ``self.device`` and switched to eval mode. The
    generator's state dict is read from
    ``/app/model/weights/diffsketcher_model.pth``. ``self.initialized`` is set
    to ``True`` only after every step succeeded.

    Raises:
        FileNotFoundError: if the weights file does not exist.
        Any other exception raised while importing, building or loading the
        models is printed and re-raised unchanged.
    """
    try:
        # Project-local modules, imported lazily so the handler module itself
        # can be imported before /app/model is on sys.path.
        from models.clip_text_encoder import CLIPTextEncoder
        from models.sketch_generator import SketchGenerator

        encoder = self.text_encoder = CLIPTextEncoder()
        encoder.to(self.device)
        encoder.eval()

        generator = self.model = SketchGenerator()
        checkpoint_file = os.path.join("/app/model/weights", "diffsketcher_model.pth")
        if not os.path.exists(checkpoint_file):
            raise FileNotFoundError(f"Model weights not found at {checkpoint_file}")

        # NOTE(review): torch.load unpickles arbitrary Python objects; consider
        # weights_only=True once the checkpoint is confirmed to be a plain
        # tensor state dict and the deployed torch version supports it.
        weights = torch.load(checkpoint_file, map_location=self.device)
        generator.load_state_dict(weights)
        generator.to(self.device)
        generator.eval()

        self.initialized = True
        print("DiffSketcher model initialized successfully")
    except Exception as err:
        print(f"Error initializing DiffSketcher model: {err!s}")
        raise
|
|
|
|
|
def preprocess(self, data):
    """Validate the request payload and encode its prompt.

    Args:
        data: Request payload; only its ``"prompt"`` entry is read (via
            ``data.get``).

    Returns:
        dict with ``"text_embedding"`` (the value returned by
        ``self.text_encoder.encode_text(prompt)``) and ``"prompt"`` (the
        prompt exactly as received).

    Raises:
        ValueError: if ``data`` is ``None``, the prompt is missing or falsy,
            or the prompt is a string containing only whitespace.
        Any error raised by the encoder is printed and re-raised.
    """
    try:
        # A None payload would otherwise surface as an AttributeError on
        # `.get`, which hides the real problem from the client.
        if data is None:
            raise ValueError("No prompt provided in the request")

        prompt = data.get("prompt", "")
        # A whitespace-only prompt carries no content; reject it like an empty
        # one instead of sending it to the encoder. Non-string prompts are
        # passed through untouched (the encoder may accept other forms).
        if not prompt or (isinstance(prompt, str) and not prompt.strip()):
            raise ValueError("No prompt provided in the request")

        # Encoding is inference only; no autograd graph is needed.
        with torch.no_grad():
            text_embedding = self.text_encoder.encode_text(prompt)

        return {
            "text_embedding": text_embedding,
            "prompt": prompt,
        }
    except Exception as e:
        print(f"Error in preprocessing: {str(e)}")
        raise
|
|
|
|
|
def inference(self, inputs):
    """Run the sketch generator on a preprocessed text embedding.

    Args:
        inputs: Mapping whose ``"text_embedding"`` entry is passed to
            ``self.model.generate``.

    Returns:
        The value returned by ``self.model.generate``, unchanged (presumably a
        mapping holding the generated SVG — TODO confirm).

    Raises:
        Any error (e.g. ``KeyError`` when the embedding is missing) is printed
        and re-raised.
    """
    try:
        embedding = inputs["text_embedding"]
        # Generation is inference only; skip autograd bookkeeping.
        with torch.no_grad():
            result = self.model.generate(embedding)
    except Exception as err:
        print(f"Error during inference: {err!s}")
        raise
    return result
|
|
|
|
|
def postprocess(self, inference_output):
    """Package the generated SVG together with its base64 encoding.

    Args:
        inference_output: Result of ``inference``: either a mapping with an
            ``"svg_content"`` entry, or the SVG markup itself as a ``str``.

    Returns:
        ``{"svg_content": ..., "svg_base64": ...}`` on success. On any failure
        the error is printed and ``{"error": <message>}`` is returned instead
        of raising.
    """
    try:
        if isinstance(inference_output, str):
            # Accept raw markup directly; indexing a str with "svg_content"
            # would raise TypeError ("string indices must be integers").
            svg_content = inference_output
        else:
            try:
                svg_content = inference_output["svg_content"]
            except KeyError:
                # str(KeyError("svg_content")) is just "'svg_content'", which
                # tells the client nothing; report the actual problem.
                raise ValueError(
                    "Inference output has no 'svg_content' entry"
                ) from None

        return {
            "svg_content": svg_content,
            # svg_to_base64 is not defined in the visible code; presumably it
            # comes from BaseHandler or later in this class — TODO confirm.
            "svg_base64": self.svg_to_base64(svg_content),
        }
    except Exception as e:
        print(f"Error in postprocessing: {str(e)}")
        return {"error": str(e)}