# diffsketcher/diffsketcher_handler.py
# Uploaded to the Hugging Face Hub by jree423 via huggingface_hub
# (revision fdbaec8, 3.06 kB). Header reconstructed from the raw-page
# scrape so the file is valid Python.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import torch
import numpy as np
from PIL import Image
import io
import base64
from handler_template import BaseHandler
# Add DiffSketcher to path
sys.path.append("/app/model")
class Handler(BaseHandler):
    """Inference handler for the DiffSketcher text-to-SVG model.

    Pipeline: ``initialize`` loads the CLIP text encoder and the sketch
    generator onto ``self.device``; ``preprocess`` turns a text prompt into
    a CLIP embedding; ``inference`` generates SVG data from the embedding;
    ``postprocess`` packages the SVG as raw text plus a base64 copy.
    """

    def initialize(self):
        """Load the CLIP text encoder and sketch generator weights.

        Sets ``self.text_encoder``, ``self.model`` (both moved to
        ``self.device`` and put in eval mode) and ``self.initialized``.

        Raises:
            FileNotFoundError: if the weight file is not present.
            Exception: any model-construction/loading error is logged
                and re-raised.
        """
        try:
            # Imported lazily: these modules live under /app/model, which
            # is appended to sys.path at module import time above.
            from models.clip_text_encoder import CLIPTextEncoder
            from models.sketch_generator import SketchGenerator

            # CLIP text encoder used by preprocess() to embed prompts.
            self.text_encoder = CLIPTextEncoder()
            self.text_encoder.to(self.device)
            self.text_encoder.eval()

            # Sketch generator; weights are loaded from the baked-in path.
            self.model = SketchGenerator()
            weights_path = os.path.join("/app/model/weights", "diffsketcher_model.pth")
            if not os.path.exists(weights_path):
                raise FileNotFoundError(f"Model weights not found at {weights_path}")
            # NOTE(review): torch.load unpickles arbitrary objects; this is
            # acceptable only because the checkpoint ships with the image.
            # Consider weights_only=True if the checkpoint is a plain
            # state_dict — confirm before changing.
            state_dict = torch.load(weights_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()

            self.initialized = True
            print("DiffSketcher model initialized successfully")
        except Exception as e:
            print(f"Error initializing DiffSketcher model: {str(e)}")
            raise

    def preprocess(self, data):
        """Extract the prompt from the request and encode it with CLIP.

        Args:
            data: request payload; expected to be a dict carrying a
                non-empty ``"prompt"`` string.

        Returns:
            dict with ``"text_embedding"`` (encoder output tensor) and the
            original ``"prompt"``.

        Raises:
            ValueError: if the prompt is missing or empty.
        """
        try:
            prompt = data.get("prompt", "")
            if not prompt:
                raise ValueError("No prompt provided in the request")
            # Inference only — no gradients needed for text encoding.
            with torch.no_grad():
                text_embedding = self.text_encoder.encode_text(prompt)
            return {
                "text_embedding": text_embedding,
                "prompt": prompt,
            }
        except Exception as e:
            print(f"Error in preprocessing: {str(e)}")
            raise

    def inference(self, inputs):
        """Generate SVG data from the preprocessed text embedding.

        Args:
            inputs: dict produced by :meth:`preprocess`; only
                ``"text_embedding"`` is consumed here.

        Returns:
            The generator's output (consumed by :meth:`postprocess`,
            which expects a mapping with an ``"svg_content"`` key).
        """
        try:
            text_embedding = inputs["text_embedding"]
            with torch.no_grad():
                svg_data = self.model.generate(text_embedding)
            return svg_data
        except Exception as e:
            print(f"Error during inference: {str(e)}")
            raise

    def postprocess(self, inference_output):
        """Format the model output for the response.

        Args:
            inference_output: mapping with an ``"svg_content"`` entry.

        Returns:
            dict with the raw ``"svg_content"`` and a ``"svg_base64"``
            encoding (via the base handler's ``svg_to_base64``), or an
            ``{"error": ...}`` dict if formatting fails — unlike the other
            stages, this one deliberately reports rather than re-raises.
        """
        try:
            svg_content = inference_output["svg_content"]
            return {
                "svg_content": svg_content,
                "svg_base64": self.svg_to_base64(svg_content),
            }
        except Exception as e:
            print(f"Error in postprocessing: {str(e)}")
            return {"error": str(e)}