File size: 9,106 Bytes
45bde7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
DiffSketcher endpoint implementation for Hugging Face.
"""

import base64
import io
import json
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from xml.sax.saxutils import escape

import cairosvg
import numpy as np
import torch
from PIL import Image

class DiffSketcherEndpoint:
    """Text-to-SVG endpoint built around the DiffSketcher repository.

    Constructing an instance clones the upstream repository into a private
    temporary directory, installs Python dependencies with pip, writes a stub
    ``diffvg`` package and then tries to import the DiffSketcher modules.
    When that import fails, or when a generation run fails, the endpoint
    returns a placeholder SVG that shows the prompt as text.

    Attributes:
        model_dir: Value passed to the constructor; stored but not otherwise
            used by this class.
        device: ``torch.device`` -- ``cuda`` when available, else ``cpu``.
        temp_dir: Temporary directory holding the clone and the stub package;
            removed when the instance is garbage collected.
        temp_model_dir: Location of the cloned repository inside ``temp_dir``.
        model_initialized: ``True`` when the DiffSketcher modules imported
            successfully.
    """

    # Upstream repository cloned at construction time.
    _REPO_URL = "https://github.com/ximinng/DiffSketcher.git"

    # Packages installed with pip before the model is imported.
    _PIP_PACKAGES = (
        "svgwrite", "svgpathtools", "cssutils", "numba", "torch", "torchvision",
        "diffusers", "transformers", "accelerate", "xformers", "omegaconf",
        "einops", "kornia",
    )

    def __init__(self, model_dir):
        """Initialize the DiffSketcher endpoint.

        Args:
            model_dir: Model directory supplied by the caller; only stored.

        Raises:
            subprocess.CalledProcessError: If ``git clone`` exits non-zero.
        """
        self.model_dir = model_dir
        # Set first so the attribute exists even if a later step raises.
        self.model_initialized = False
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing DiffSketcher endpoint on device: {self.device}")

        # Create a temporary directory for the model
        self.temp_dir = tempfile.mkdtemp()
        self.temp_model_dir = Path(self.temp_dir) / "DiffSketcher"

        # Clone the repository if it doesn't exist
        if not self.temp_model_dir.exists():
            print("Cloning DiffSketcher repository...")
            subprocess.run(
                ["git", "clone", self._REPO_URL, str(self.temp_model_dir)],
                check=True,
            )

        # Make ``import DiffSketcher`` resolve to the fresh clone.
        self._add_to_sys_path(self.temp_model_dir.parent)

        # Install dependencies
        self._install_dependencies()

        # Initialize the model
        self._initialize_model()

    @staticmethod
    def _add_to_sys_path(directory):
        """Append ``directory`` to ``sys.path`` unless it is already listed."""
        entry = str(directory)
        if entry not in sys.path:
            sys.path.append(entry)

    def _install_dependencies(self):
        """Install the required dependencies and write a stub ``diffvg``.

        Best effort: any failure is printed and swallowed so construction can
        continue and fall back to placeholder output.
        """
        try:
            # Use the running interpreter's pip so the packages land in the
            # environment that will import them.
            pip_install = [sys.executable, "-m", "pip", "install"]

            # Install the Python package dependencies
            print("Installing diffvg...")
            subprocess.run(pip_install + list(self._PIP_PACKAGES), check=True)

            # Install CLIP
            print("Installing CLIP...")
            subprocess.run(
                pip_install + ["git+https://github.com/openai/CLIP.git"],
                check=True,
            )

            # Create a mock diffvg module. Its render functions return zero
            # tensors and its writers are no-ops, so it only satisfies
            # ``import diffvg``; it does not rasterize anything.
            diffvg_dir = Path(self.temp_dir) / "diffvg"
            diffvg_dir.mkdir(exist_ok=True)
            with open(diffvg_dir / "__init__.py", "w", encoding="utf-8") as f:
                f.write("""
# Mock diffvg module
import torch

def render(scene, width, height, samples=2, seed=None):
    return torch.zeros((height, width, 4), dtype=torch.float32)

def render_wrt_shapes(scene, shapes, width, height, samples=2, seed=None):
    return torch.zeros((height, width, 4), dtype=torch.float32)

def render_wrt_camera(scene, camera, width, height, samples=2, seed=None):
    return torch.zeros((height, width, 4), dtype=torch.float32)

def imwrite(img, filename, gamma=2.2):
    pass

def save_svg(scene, filename):
    pass

def set_use_gpu(use_gpu):
    pass

def set_print_timing(print_timing):
    pass
""")

            # Add the mock diffvg to the Python path
            self._add_to_sys_path(diffvg_dir.parent)

        except Exception as e:
            print(f"Error installing dependencies: {e}")

    def _initialize_model(self):
        """Probe whether the DiffSketcher modules can be imported.

        Sets ``self.model_initialized`` to ``True`` on success and ``False``
        on any exception. No model object is created here.
        """
        try:
            # The imports are the availability check; the names are unused.
            from DiffSketcher.methods.painter.diffsketcher import Painter  # noqa: F401
            from DiffSketcher.methods.diffusers_warp import init_diffusion_pipeline  # noqa: F401

            self.model_initialized = True
            print("DiffSketcher model initialized successfully")
        except Exception as e:
            print(f"Error initializing DiffSketcher model: {e}")
            self.model_initialized = False

    @staticmethod
    def _render_config(prompt, num_paths, image_size):
        """Return the YAML configuration text for one generation run.

        The prompt is written as a double-quoted scalar (JSON string syntax)
        so colons, ``#``, quotes and newlines in it cannot alter the document
        structure.
        """
        quoted_prompt = json.dumps(str(prompt), ensure_ascii=False)
        lines = [
            "task: diffsketcher",
            "model_id: sd15",
            f"prompt: {quoted_prompt}",
            'negative_prompt: ""',
            f"num_paths: {num_paths}",
            "width: 1.5",
            f"image_size: {image_size}",
            "num_iter: 500",
            "lr: 1.0",
            "sds:",
            "  warmup: 0",
            "  grad_scale: 1.0",
            "  t_range: [0.02, 0.98]",
            "  guidance_scale: 7.5",
        ]
        return "\n".join(lines) + "\n"

    def _run_model(self, prompt, num_paths, image_size, config_path, output_dir):
        """Run DiffSketcher once and return the text of a generated SVG.

        Raises:
            FileNotFoundError: If no ``*.svg`` file exists under ``output_dir``
                after the run.
        """
        # Imported lazily: the package only exists after the clone.
        from DiffSketcher.run_painterly_render import main
        from DiffSketcher.libs.engine import merge_and_update_config
        from omegaconf import OmegaConf

        # Create a mock args object
        args = OmegaConf.create({
            "task": "diffsketcher",
            "config": str(config_path),
            "prompt": prompt,
            "negative_prompt": "",
            "num_paths": num_paths,
            "width": 1.5,
            "image_size": image_size,
            "num_iter": 500,
            "lr": 1.0,
            "sds": {
                "warmup": 0,
                "grad_scale": 1.0,
                "t_range": [0.02, 0.98],
                "guidance_scale": 7.5,
            },
            "seed": 42,
            "batch_size": 1,
            "render_batch": False,
            "make_video": False,
            "print_timing": False,
            "download": True,
            "force_download": False,
            "resume_download": False,
        })

        # Run the model
        args = merge_and_update_config(args)
        main(args, None)

        # NOTE(review): none of the arguments above point DiffSketcher's
        # output at ``output_dir`` -- confirm where ``main`` writes results.
        # Sorted so the chosen file does not depend on directory order.
        svg_files = sorted(output_dir.glob("**/*.svg"))
        if not svg_files:
            raise FileNotFoundError("No SVG file generated")
        return svg_files[0].read_text(encoding="utf-8")

    def generate_svg(self, prompt, num_paths=10, width=512, height=512):
        """Generate an SVG from a text prompt.

        Args:
            prompt: Text prompt.
            num_paths: Number of paths requested from DiffSketcher.
            width: Passed to DiffSketcher as ``image_size`` and used as the
                placeholder width.
            height: Used only for the placeholder SVG.

        Returns:
            SVG markup as a string. A placeholder is returned when the model
            is unavailable or the run fails; this method does not raise.
        """
        print(f"Generating SVG for prompt: {prompt}")

        if not getattr(self, "model_initialized", False):
            # Use a placeholder
            return self._generate_placeholder_svg(prompt, width, height)

        try:
            # The scratch directory is removed on exit, success or failure.
            with tempfile.TemporaryDirectory() as scratch:
                output_dir = Path(scratch)

                # Create a config file
                config_path = output_dir / "config.yaml"
                config_path.write_text(
                    self._render_config(prompt, num_paths, width),
                    encoding="utf-8",
                )

                try:
                    return self._run_model(
                        prompt, num_paths, width, config_path, output_dir
                    )
                except Exception as e:
                    print(f"Error running DiffSketcher model: {e}")
                    # Fall back to placeholder
                    return self._generate_placeholder_svg(prompt, width, height)
        except Exception as e:
            print(f"Error generating SVG: {e}")
            return self._generate_placeholder_svg(prompt, width, height)

    def _generate_placeholder_svg(self, prompt, width=512, height=512):
        """Generate a placeholder SVG showing the prompt as centered text.

        The prompt is XML-escaped so ``<``, ``>`` and ``&`` cannot break the
        document or inject markup.
        """
        safe_prompt = escape(str(prompt))
        svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
            <rect width="100%" height="100%" fill="#f0f0f0"/>
            <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle">{safe_prompt}</text>
        </svg>"""
        return svg_content

    def svg_to_png(self, svg_content):
        """Convert SVG content to PNG bytes.

        If the conversion fails, a 512x512 red PNG carrying the error message
        is returned instead of raising.
        """
        try:
            return cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")
            # Create a simple error image
            from PIL import ImageDraw

            image = Image.new("RGB", (512, 512), color="#ff0000")
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")

            # Convert PIL Image to PNG data
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG from a text prompt and convert it to PNG.

        Returns:
            A dict with keys ``svg`` (markup string), ``svg_base64``,
            ``png_base64`` (base64 text) and ``image`` (PIL image opened from
            the PNG bytes).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)

        # Create a PIL Image from the PNG data
        image = Image.open(io.BytesIO(png_data))

        # Create the response
        return {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }

    def __del__(self):
        """Clean up temporary files."""
        # ``temp_dir`` is missing when ``__init__`` failed early or was
        # bypassed; errors are ignored because a finalizer must not raise.
        temp_dir = getattr(self, "temp_dir", None)
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)