File size: 3,060 Bytes
fdbaec8
 
 
4039872
fdbaec8
4039872
 
fdbaec8
 
 
 
4039872
fdbaec8
 
 
 
 
 
4039872
fdbaec8
 
 
 
 
 
 
4039872
fdbaec8
 
 
 
 
 
 
 
 
 
 
4039872
fdbaec8
 
4039872
fdbaec8
 
4039872
 
fdbaec8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4039872
 
fdbaec8
 
 
 
 
 
 
 
 
 
 
 
4039872
 
fdbaec8
 
 
4039872
fdbaec8
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys
import torch
import numpy as np
from PIL import Image
import io
import base64
from handler_template import BaseHandler

# Make the DiffSketcher sources importable: Handler.initialize() imports
# `models.clip_text_encoder` / `models.sketch_generator`, presumably located
# under /app/model (container layout — confirm). The directory is *appended*,
# so any same-named module earlier on sys.path takes precedence.
sys.path.append("/app/model")

class Handler(BaseHandler):
    """Serving handler that turns a text prompt into an SVG sketch.

    Pipeline: ``preprocess`` encodes the prompt with a CLIP text encoder,
    ``inference`` runs the DiffSketcher generator on that embedding, and
    ``postprocess`` packages the SVG (plus a base64 form) for the response.
    ``self.device`` is read but not set here — presumably provided by
    ``BaseHandler`` (confirm).
    """

    def initialize(self):
        """Load the CLIP text encoder and the DiffSketcher sketch generator.

        Both models are moved to ``self.device`` and switched to eval mode;
        on success ``self.initialized`` is set to True.

        Raises:
            FileNotFoundError: if the generator weights file does not exist.
            Exception: any other error while importing/building/loading the
                models is printed and re-raised unchanged.
        """
        try:
            from models.clip_text_encoder import CLIPTextEncoder
            from models.sketch_generator import SketchGenerator

            # Check for the weights first so a missing file fails fast,
            # before any model object is constructed.
            weights_path = os.path.join("/app/model/weights", "diffsketcher_model.pth")
            if not os.path.exists(weights_path):
                raise FileNotFoundError(f"Model weights not found at {weights_path}")

            # Load text encoder
            self.text_encoder = CLIPTextEncoder()
            self.text_encoder.to(self.device)
            self.text_encoder.eval()

            # Load sketch generator
            self.model = SketchGenerator()
            # NOTE(security): torch.load unpickles the file and can execute
            # arbitrary code; only ship weights from a trusted source. On
            # PyTorch versions that support it, consider weights_only=True.
            state_dict = torch.load(weights_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()

            self.initialized = True
            print("DiffSketcher model initialized successfully")
        except Exception as e:
            print(f"Error initializing DiffSketcher model: {str(e)}")
            raise

    def preprocess(self, data):
        """Validate the request and encode its prompt with CLIP.

        Args:
            data: request payload; must support ``.get("prompt")``
                (presumably a dict — confirm against BaseHandler's caller).

        Returns:
            dict with ``"text_embedding"`` (output of
            ``text_encoder.encode_text``) and the original ``"prompt"``.

        Raises:
            ValueError: if the prompt is missing, empty, or whitespace-only.
        """
        try:
            # Extract prompt from the request
            prompt = data.get("prompt", "")
            # Whitespace-only strings are treated as "no prompt": they carry
            # no content for the text encoder.
            if not prompt or (isinstance(prompt, str) and not prompt.strip()):
                raise ValueError("No prompt provided in the request")

            # Encode text with CLIP; no autograd graph is needed for serving.
            with torch.no_grad():
                text_embedding = self.text_encoder.encode_text(prompt)

            return {
                "text_embedding": text_embedding,
                "prompt": prompt,
            }
        except Exception as e:
            print(f"Error in preprocessing: {str(e)}")
            raise

    def inference(self, inputs):
        """Run the sketch generator on the encoded prompt.

        Args:
            inputs: the dict returned by ``preprocess``; only
                ``"text_embedding"`` is used.

        Returns:
            Whatever ``self.model.generate`` returns — ``postprocess`` expects
            a mapping with an ``"svg_content"`` key or a plain SVG string.
        """
        try:
            text_embedding = inputs["text_embedding"]

            # Inference only: disable gradient tracking.
            with torch.no_grad():
                svg_data = self.model.generate(text_embedding)

            return svg_data
        except Exception as e:
            print(f"Error during inference: {str(e)}")
            raise

    def postprocess(self, inference_output):
        """Package the generated SVG for the response.

        Args:
            inference_output: value returned by ``inference``. A mapping with
                an ``"svg_content"`` key is expected; a plain ``str`` is also
                accepted and treated as the SVG markup itself.

        Returns:
            ``{"svg_content": ..., "svg_base64": ...}`` on success, or
            ``{"error": <message>}`` on any failure. Unlike the other stages,
            errors here are reported in the payload rather than raised.
        """
        try:
            if isinstance(inference_output, str):
                svg_content = inference_output
            else:
                svg_content = inference_output["svg_content"]

            # svg_to_base64 is not defined in this file; presumably it is
            # inherited from BaseHandler — confirm.
            return {
                "svg_content": svg_content,
                "svg_base64": self.svg_to_base64(svg_content),
            }
        except Exception as e:
            # Deliberate best-effort: surface the failure to the client. The
            # exception type is logged because str() of e.g. a KeyError is
            # only the missing key, which is ambiguous on its own.
            print(f"Error in postprocessing: {type(e).__name__}: {e}")
            return {"error": str(e)}