File size: 10,646 Bytes
4eb5b6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Simplified DiffSketcher implementation for Hugging Face Inference API.
This version doesn't rely on cloning the repository at runtime.
"""

import os
import io
import base64
import torch
import numpy as np
from PIL import Image
import cairosvg
import random
from pathlib import Path

class SimplifiedDiffSketcher:
    """Template-based stand-in for DiffSketcher.

    Rather than optimising strokes, this class draws a parametric car sketch
    whose colour, size and body style are derived from a feature vector for the
    prompt: a normalised CLIP text embedding when CLIP is usable, otherwise
    random Gaussian features.
    """

    # OpenAI CLIP checkpoint name. The `clip` package spells it with a slash
    # ("ViT-B/32"); the open_clip-style "ViT-B-32" is rejected by clip.load.
    CLIP_MODEL_NAME = "ViT-B/32"

    def __init__(self, model_dir):
        """Initialize the simplified DiffSketcher model.

        Args:
            model_dir: Model directory; stored on the instance but not otherwise
                used by this simplified implementation.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing simplified DiffSketcher on device: {self.device}")

        # Load CLIP model if available. Any failure (package missing, download
        # error, bad model name) leaves clip_available False so generation
        # falls back to random features.
        try:
            import clip
            self.clip_model, _ = clip.load(self.CLIP_MODEL_NAME, device=self.device)
            self.clip_available = True
            print("CLIP model loaded successfully")
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            self.clip_available = False

    def generate_svg(self, prompt, num_paths=20, width=512, height=512):
        """Generate an SVG from a text prompt.

        Args:
            prompt: Text prompt; encoded with CLIP when available and also
                rendered as a caption in the SVG.
            num_paths: Forwarded to ``_generate_car_svg`` (unused there).
            width: Canvas width in pixels.
            height: Canvas height in pixels.

        Returns:
            The SVG document as a string.
        """
        print(f"Generating SVG for prompt: {prompt}")

        text_features = None
        if self.clip_available:
            try:
                import clip
                with torch.no_grad():
                    # truncate=True: without it clip.tokenize raises for prompts
                    # longer than CLIP's 77-token context window.
                    text = clip.tokenize([prompt], truncate=True).to(self.device)
                    encoded = self.clip_model.encode_text(text)
                # The embedding may be float16 (e.g. CLIP on CUDA); cast up before
                # computing the norm to avoid precision loss/overflow.
                text_features = encoded.float().cpu().numpy()[0]
                norm = np.linalg.norm(text_features)
                if norm > 0:  # guard against NaNs from a zero vector
                    text_features = text_features / norm
            except Exception as e:
                print(f"Error encoding prompt with CLIP: {e}")
                text_features = None

        if text_features is None:
            # Random features as fallback (CLIP unavailable or encoding failed)
            text_features = np.random.randn(512)

        # Generate a car-like SVG based on the prompt
        svg_content = self._generate_car_svg(prompt, text_features, num_paths, width, height)

        return svg_content

    def _generate_car_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate a car-like SVG based on the prompt and features.

        Args:
            prompt: Caption drawn at the bottom of the canvas (XML-escaped).
            features: 1-D numeric sequence with at least 3 entries. features[0]
                selects the hue, features[1] the size, features[2] the body
                style; the first (up to) 10 entries place decoration dots.
            num_paths: Accepted for interface compatibility; not used.
            width: Canvas width in pixels.
            height: Canvas height in pixels.

        Returns:
            The SVG document as a string.
        """
        from xml.sax.saxutils import escape as xml_escape

        # Work in float64 so reduced-precision inputs don't leak into coordinates.
        features = np.asarray(features, dtype=np.float64)

        # Start SVG
        svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
            <rect width="100%" height="100%" fill="#f8f8f8"/>
        """

        # Use the features to determine car properties
        car_color_hue = int((features[0] + 1) * 180) % 360  # Map to 0-360 hue
        # Clamp so car_size stays in [0.4, 0.8]; unnormalised fallback features
        # could otherwise make it negative and produce invalid negative sizes.
        size_feature = min(max(features[1], -1.0), 1.0)
        car_size = 0.6 + 0.2 * size_feature  # Size variation
        # NOTE(review): for a unit-normalised 512-d embedding the components are
        # typically small, so this likely selects style 0 most of the time.
        car_style = int(abs(features[2] * 3)) % 3  # 0: sedan, 1: SUV, 2: sports car

        # Calculate car dimensions
        car_width = int(width * 0.7 * car_size)
        car_height = int(height * 0.3 * car_size)
        car_x = (width - car_width) // 2
        car_y = height // 2

        # Generate car body based on style
        if car_style == 0:  # Sedan
            # Car body (rounded rectangle)
            svg_content += f"""<rect x="{car_x}" y="{car_y}" width="{car_width}" height="{car_height}" 
                rx="20" ry="20" fill="hsl({car_color_hue}, 80%, 50%)" stroke="black" stroke-width="2" />"""

            # Windshield
            windshield_width = car_width * 0.7
            windshield_height = car_height * 0.5
            windshield_x = car_x + (car_width - windshield_width) // 2
            windshield_y = car_y - windshield_height * 0.3
            svg_content += f"""<rect x="{windshield_x}" y="{windshield_y}" width="{windshield_width}" height="{windshield_height}" 
                rx="10" ry="10" fill="#a8d8ff" stroke="black" stroke-width="1" />"""

            # Wheels
            wheel_radius = car_height * 0.4
            wheel_y = car_y + car_height * 0.8
            svg_content += f"""<circle cx="{car_x + car_width * 0.2}" cy="{wheel_y}" r="{wheel_radius}" fill="black" />"""
            svg_content += f"""<circle cx="{car_x + car_width * 0.8}" cy="{wheel_y}" r="{wheel_radius}" fill="black" />"""
            svg_content += f"""<circle cx="{car_x + car_width * 0.2}" cy="{wheel_y}" r="{wheel_radius * 0.6}" fill="#444" />"""
            svg_content += f"""<circle cx="{car_x + car_width * 0.8}" cy="{wheel_y}" r="{wheel_radius * 0.6}" fill="#444" />"""

        elif car_style == 1:  # SUV
            # Car body (taller rectangle)
            svg_content += f"""<rect x="{car_x}" y="{car_y - car_height * 0.3}" width="{car_width}" height="{car_height * 1.3}" 
                rx="15" ry="15" fill="hsl({car_color_hue}, 80%, 50%)" stroke="black" stroke-width="2" />"""

            # Windshield
            windshield_width = car_width * 0.6
            windshield_height = car_height * 0.6
            windshield_x = car_x + (car_width - windshield_width) // 2
            windshield_y = car_y - car_height * 0.2
            svg_content += f"""<rect x="{windshield_x}" y="{windshield_y}" width="{windshield_width}" height="{windshield_height}" 
                rx="8" ry="8" fill="#a8d8ff" stroke="black" stroke-width="1" />"""

            # Wheels (larger)
            wheel_radius = car_height * 0.45
            wheel_y = car_y + car_height * 0.7
            svg_content += f"""<circle cx="{car_x + car_width * 0.2}" cy="{wheel_y}" r="{wheel_radius}" fill="black" />"""
            svg_content += f"""<circle cx="{car_x + car_width * 0.8}" cy="{wheel_y}" r="{wheel_radius}" fill="black" />"""
            svg_content += f"""<circle cx="{car_x + car_width * 0.2}" cy="{wheel_y}" r="{wheel_radius * 0.6}" fill="#444" />"""
            svg_content += f"""<circle cx="{car_x + car_width * 0.8}" cy="{wheel_y}" r="{wheel_radius * 0.6}" fill="#444" />"""

        else:  # Sports car
            # Car body (low, sleek shape)
            svg_content += f"""<path d="M {car_x} {car_y + car_height * 0.5} 
                C {car_x + car_width * 0.1} {car_y - car_height * 0.2}, 
                {car_x + car_width * 0.3} {car_y - car_height * 0.3}, 
                {car_x + car_width * 0.5} {car_y - car_height * 0.2} 
                S {car_x + car_width * 0.9} {car_y}, 
                {car_x + car_width} {car_y + car_height * 0.3} 
                L {car_x + car_width} {car_y + car_height * 0.7} 
                C {car_x + car_width * 0.9} {car_y + car_height}, 
                {car_x + car_width * 0.1} {car_y + car_height}, 
                {car_x} {car_y + car_height * 0.7} Z" 
                fill="hsl({car_color_hue}, 90%, 45%)" stroke="black" stroke-width="2" />"""

            # Windshield
            windshield_width = car_width * 0.4
            windshield_x = car_x + car_width * 0.3
            windshield_y = car_y - car_height * 0.1
            svg_content += f"""<path d="M {windshield_x} {windshield_y} 
                C {windshield_x + windshield_width * 0.1} {windshield_y - car_height * 0.15}, 
                {windshield_x + windshield_width * 0.9} {windshield_y - car_height * 0.15}, 
                {windshield_x + windshield_width} {windshield_y} Z" 
                fill="#a8d8ff" stroke="black" stroke-width="1" />"""

            # Wheels (low profile)
            wheel_radius = car_height * 0.35
            wheel_y = car_y + car_height * 0.7
            svg_content += f"""<ellipse cx="{car_x + car_width * 0.2}" cy="{wheel_y}" rx="{wheel_radius * 1.2}" ry="{wheel_radius}" fill="black" />"""
            svg_content += f"""<ellipse cx="{car_x + car_width * 0.8}" cy="{wheel_y}" rx="{wheel_radius * 1.2}" ry="{wheel_radius}" fill="black" />"""
            svg_content += f"""<ellipse cx="{car_x + car_width * 0.2}" cy="{wheel_y}" rx="{wheel_radius * 0.7}" ry="{wheel_radius * 0.6}" fill="#444" />"""
            svg_content += f"""<ellipse cx="{car_x + car_width * 0.8}" cy="{wheel_y}" rx="{wheel_radius * 0.7}" ry="{wheel_radius * 0.6}" fill="#444" />"""

        # Add headlights
        headlight_radius = car_width * 0.05
        headlight_y = car_y + car_height * 0.3
        svg_content += f"""<circle cx="{car_x + car_width * 0.1}" cy="{headlight_y}" r="{headlight_radius}" fill="yellow" stroke="black" stroke-width="1" />"""
        svg_content += f"""<circle cx="{car_x + car_width * 0.9}" cy="{headlight_y}" r="{headlight_radius}" fill="yellow" stroke="black" stroke-width="1" />"""

        # Add details based on features: one translucent dot per leading feature
        for i in range(min(10, len(features))):
            feature_val = features[i]
            x = car_x + car_width * ((i / 10) * 0.8 + 0.1)
            y = car_y + car_height * ((feature_val + 1) / 4)
            size = car_width * 0.03 * abs(feature_val)
            svg_content += f"""<circle cx="{x}" cy="{y}" r="{size}" fill="rgba(0,0,0,0.2)" />"""

        # Add prompt as text; escape it so '&', '<', '>' cannot break the XML
        # (or inject markup into the SVG).
        caption = xml_escape(str(prompt))
        svg_content += f"""<text x="{width/2}" y="{height - 20}" font-family="Arial" font-size="12" text-anchor="middle">{caption}</text>"""

        # Close SVG
        svg_content += "</svg>"

        return svg_content

    def svg_to_png(self, svg_content):
        """Convert SVG content to PNG bytes.

        On a cairosvg failure, returns a 512x512 red PNG containing the error
        message instead of raising.

        Args:
            svg_content: SVG document as a string.

        Returns:
            PNG-encoded image as bytes.
        """
        try:
            png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
            return png_data
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")
            # Create a simple error image
            image = Image.new("RGB", (512, 512), color="#ff0000")
            from PIL import ImageDraw
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")

            # Convert PIL Image to PNG data
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG from a text prompt and convert it to PNG.

        Args:
            prompt: Text prompt.

        Returns:
            dict with keys "svg" (SVG string), "svg_base64" and "png_base64"
            (base64-encoded ASCII strings), and "image" (a PIL Image opened
            from the PNG bytes).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)

        # Create a PIL Image from the PNG data
        image = Image.open(io.BytesIO(png_data))

        # Create the response
        response = {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image
        }

        return response