import torch
import numpy as np
import json
import io
import random
import math
from typing import Dict, Any, List, Union

from PIL import Image
import svgwrite
from diffusers import StableDiffusionPipeline, DDIMScheduler
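
# Assumed runtime dependencies (versions are not pinned in this file): torch,
# diffusers, svgwrite, Pillow, and numpy. cairosvg is optional and is only
# needed to rasterize the generated SVG in svg_to_pil_image().
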
class EndpointHandler:
    def __init__(self, path=""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_id = "runwayml/stable-diffusion-v1-5"

        try:
            # Load Stable Diffusion v1.5: fp16 on GPU, fp32 on CPU. The safety
            # checker is disabled because this handler only produces line sketches.
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                safety_checker=None,
                requires_safety_checker=False
            ).to(self.device)

            # Swap in a DDIM scheduler for deterministic sampling.
            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)

            # Reuse the pipeline's CLIP text encoder and tokenizer for prompt embeddings.
            self.clip_model = self.pipe.text_encoder
            self.clip_tokenizer = self.pipe.tokenizer

            print("DiffSketcher handler initialized successfully.")
        except Exception as e:
            # Degrade gracefully: the handler still works without the diffusion
            # model, falling back to purely procedural sketch generation.
            print(f"Warning: could not load diffusion model: {e}")
            self.pipe = None
            self.clip_model = None
            self.clip_tokenizer = None

    def __call__(self, inputs: Union[str, Dict[str, Any]]) -> Image.Image:
        """
        Generate an SVG sketch from a text prompt using a DiffSketcher-style approach.

        Accepts either a bare prompt string or a dict of the form
        {"inputs": "<prompt>", "parameters": {...}}.
        """
        try:
            # Normalize the two accepted input shapes.
            if isinstance(inputs, str):
                prompt = inputs
                parameters = {}
            else:
                prompt = inputs.get("inputs", inputs.get("prompt", "a simple sketch"))
                parameters = inputs.get("parameters", {})

            # Generation parameters with defaults.
            num_paths = parameters.get("num_paths", 64)
            num_iter = parameters.get("num_iter", 500)
            width = parameters.get("width", 224)
            height = parameters.get("height", 224)
            guidance_scale = parameters.get("guidance_scale", 7.5)
            seed = parameters.get("seed", None)

            # Seed all RNGs for reproducible output.
            if seed is not None:
                torch.manual_seed(seed)
                np.random.seed(seed)
                random.seed(seed)

            print(f"Generating sketch for: '{prompt}' with {num_paths} paths")

            svg_content, metadata = self.generate_diffsketcher_svg(
                prompt, width, height, num_paths, num_iter, guidance_scale
            )

            # Rasterize the SVG so the endpoint can return a standard image.
            pil_image = self.svg_to_pil_image(svg_content, width, height)

            # Attach the vector source and generation metadata to the image.
            pil_image.info['svg_content'] = svg_content
            pil_image.info['prompt'] = prompt
            pil_image.info['parameters'] = json.dumps(parameters)
            pil_image.info['num_paths'] = str(num_paths)
            pil_image.info['method'] = 'diffsketcher'

            return pil_image

        except Exception as e:
            print(f"Error in DiffSketcher handler: {e}")
            # Return a minimal placeholder image instead of raising.
            fallback_svg = self.create_fallback_svg(prompt if 'prompt' in locals() else "error", 224, 224)
            fallback_image = self.svg_to_pil_image(fallback_svg, 224, 224)
            fallback_image.info['error'] = str(e)
            return fallback_image

    def generate_diffsketcher_svg(self, prompt: str, width: int, height: int,
                                  num_paths: int, num_iter: int, guidance_scale: float):
        """
        Generate an SVG using a DiffSketcher-inspired approach with diffusion guidance.
        """
        # Encode the prompt with CLIP (unconditional + conditional embeddings).
        text_embeddings = self.get_text_embeddings(prompt)

        # Start from random cubic Bezier strokes on the canvas.
        paths = self.initialize_paths(num_paths, width, height)

        # Iteratively refine the strokes toward the prompt semantics.
        optimized_paths = self.optimize_paths_with_diffusion(
            paths, text_embeddings, prompt, width, height, num_iter, guidance_scale
        )

        # Serialize the strokes to SVG markup.
        svg_content = self.paths_to_svg(optimized_paths, width, height)

        metadata = {
            "method": "diffsketcher",
            "prompt": prompt,
            "num_paths": num_paths,
            "num_iter": num_iter,
            "guidance_scale": guidance_scale,
            "width": width,
            "height": height
        }

        return svg_content, metadata

    def get_text_embeddings(self, prompt: str):
        """Get CLIP text embeddings for the prompt."""
        if self.clip_model is None or self.clip_tokenizer is None:
            # No text encoder available; return zero embeddings of the expected
            # shape (uncond + cond, 77 tokens, 768 dims for SD v1.5).
            return torch.zeros((2, 77, 768), device=self.device)

        try:
            with torch.no_grad():
                # Conditional embeddings for the prompt.
                text_inputs = self.clip_tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=self.clip_tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt"
                ).to(self.device)
                text_embeddings = self.clip_model(text_inputs.input_ids)[0]

                # Unconditional embeddings (empty prompt) for classifier-free guidance.
                uncond_inputs = self.clip_tokenizer(
                    "",
                    padding="max_length",
                    max_length=self.clip_tokenizer.model_max_length,
                    return_tensors="pt"
                ).to(self.device)
                uncond_embeddings = self.clip_model(uncond_inputs.input_ids)[0]

                # Stack as [uncond, cond], the layout expected by guided sampling.
                return torch.cat([uncond_embeddings, text_embeddings])
        except Exception as e:
            print(f"Error getting text embeddings: {e}")
            return torch.zeros((2, 77, 768), device=self.device)

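    # For reference only (this simplified handler never runs the UNet): in a
    # full diffusion loop, the stacked [uncond, cond] embeddings would drive
    # classifier-free guidance roughly as follows, with latents duplicated
    # along the batch dimension to match:
    #   noise_uncond, noise_text = unet(latents, t, encoder_hidden_states=emb).sample.chunk(2)
    #   noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
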
    def initialize_paths(self, num_paths: int, width: int, height: int):
        """Initialize random cubic Bezier paths."""
        paths = []

        for _ in range(num_paths):
            # Start point, kept away from the canvas borders.
            start_x = random.uniform(0.1 * width, 0.9 * width)
            start_y = random.uniform(0.1 * height, 0.9 * height)

            # Control points jittered around the start point.
            cp1_x = start_x + random.uniform(-width * 0.2, width * 0.2)
            cp1_y = start_y + random.uniform(-height * 0.2, height * 0.2)
            cp2_x = start_x + random.uniform(-width * 0.2, width * 0.2)
            cp2_y = start_y + random.uniform(-height * 0.2, height * 0.2)

            # End point within a slightly larger radius of the start.
            end_x = start_x + random.uniform(-width * 0.3, width * 0.3)
            end_y = start_y + random.uniform(-height * 0.3, height * 0.3)

            # Clamp all points to the canvas.
            cp1_x = max(0, min(width, cp1_x))
            cp1_y = max(0, min(height, cp1_y))
            cp2_x = max(0, min(width, cp2_x))
            cp2_y = max(0, min(height, cp2_y))
            end_x = max(0, min(width, end_x))
            end_y = max(0, min(height, end_y))

            # Grayscale stroke color of random intensity.
            color_intensity = random.uniform(0.1, 0.7)
            color = (
                int(color_intensity * 255),
                int(color_intensity * 255),
                int(color_intensity * 255)
            )

            stroke_width = random.uniform(0.5, 3.0)

            path = {
                'start': (start_x, start_y),
                'cp1': (cp1_x, cp1_y),
                'cp2': (cp2_x, cp2_y),
                'end': (end_x, end_y),
                'color': color,
                'stroke_width': stroke_width,
                'opacity': random.uniform(0.3, 0.8)
            }
            paths.append(path)

        return paths

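    # Illustrative helper, not called by the handler: shows the geometry the
    # path dicts above encode.
    @staticmethod
    def _bezier_point(path: Dict, t: float):
        """
        Evaluate the cubic Bezier curve described by a path dict at t in [0, 1]:
        B(t) = (1-t)^3 P0 + 3(1-t)^2 t P1 + 3(1-t) t^2 P2 + t^3 P3.
        """
        (x0, y0), (x1, y1) = path['start'], path['cp1']
        (x2, y2), (x3, y3) = path['cp2'], path['end']
        u = 1.0 - t
        x = u**3 * x0 + 3 * u**2 * t * x1 + 3 * u * t**2 * x2 + t**3 * x3
        y = u**3 * y0 + 3 * u**2 * t * y1 + 3 * u * t**2 * y2 + t**3 * y3
        return x, y
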
    def optimize_paths_with_diffusion(self, paths: List[Dict], text_embeddings: torch.Tensor,
                                      prompt: str, width: int, height: int,
                                      num_iter: int, guidance_scale: float):
        """
        Optimize paths using diffusion model guidance (simplified approach).

        Note: this simplified variant steers the strokes with keyword-derived
        semantic features rather than backpropagating a diffusion loss, so
        text_embeddings and guidance_scale are currently unused.
        """
        semantic_features = self.extract_semantic_features(prompt)

        # Run a reduced number of refinement iterations (at most 50).
        for iteration in range(min(num_iter // 10, 50)):
            paths = self.apply_semantic_guidance(paths, semantic_features, width, height)

            # Periodically rebalance layering and opacity.
            if iteration % 5 == 0:
                paths = self.apply_aesthetic_refinement(paths, width, height)

        return paths

    def extract_semantic_features(self, prompt: str):
        """
        Extract keyword-based semantic features from the prompt to guide path
        generation. For example, "a detailed sketch of a tree" yields
        complexity='high', style='sketch', organic=True.
        """
        features = {
            'complexity': 'medium',
            'style': 'sketch',
            'density': 'medium',
            'organic': False,
            'geometric': False,
            'detailed': False
        }

        prompt_lower = prompt.lower()

        # Complexity cues.
        complex_words = ['detailed', 'intricate', 'complex', 'elaborate']
        simple_words = ['simple', 'minimal', 'basic', 'clean']

        if any(word in prompt_lower for word in complex_words):
            features['complexity'] = 'high'
            features['detailed'] = True
        elif any(word in prompt_lower for word in simple_words):
            features['complexity'] = 'low'

        # Style cues.
        if any(word in prompt_lower for word in ['sketch', 'drawing', 'pencil', 'charcoal']):
            features['style'] = 'sketch'
        elif any(word in prompt_lower for word in ['painting', 'artistic', 'painted']):
            features['style'] = 'artistic'

        # Subject-matter cues: organic vs. geometric shapes.
        organic_words = ['tree', 'flower', 'animal', 'person', 'face', 'natural', 'organic']
        geometric_words = ['building', 'house', 'geometric', 'square', 'circle', 'triangle']

        if any(word in prompt_lower for word in organic_words):
            features['organic'] = True
        if any(word in prompt_lower for word in geometric_words):
            features['geometric'] = True

        return features

    def apply_semantic_guidance(self, paths: List[Dict], features: Dict, width: int, height: int):
        """Apply semantic guidance to modify paths."""
        modified_paths = []

        for path in paths:
            new_path = path.copy()

            if features['complexity'] == 'high':
                # Add curvature variation for intricate subjects.
                variation = 0.15
                new_path['cp1'] = (
                    new_path['cp1'][0] + random.uniform(-width * variation, width * variation),
                    new_path['cp1'][1] + random.uniform(-height * variation, height * variation)
                )
                new_path['cp2'] = (
                    new_path['cp2'][0] + random.uniform(-width * variation, width * variation),
                    new_path['cp2'][1] + random.uniform(-height * variation, height * variation)
                )
            elif features['complexity'] == 'low':
                # Straighten the curve: place both control points on the chord
                # between start and end.
                start_x, start_y = new_path['start']
                end_x, end_y = new_path['end']
                new_path['cp1'] = (
                    start_x + (end_x - start_x) * 0.33,
                    start_y + (end_y - start_y) * 0.33
                )
                new_path['cp2'] = (
                    start_x + (end_x - start_x) * 0.66,
                    start_y + (end_y - start_y) * 0.66
                )

            if features['organic']:
                # Vary stroke weight and opacity slightly for a hand-drawn feel.
                new_path['stroke_width'] *= random.uniform(0.8, 1.2)
                new_path['opacity'] *= random.uniform(0.9, 1.1)
            elif features['geometric']:
                # Snap all points to a coarse grid for a constructed look.
                grid_size = 20
                for key in ['start', 'cp1', 'cp2', 'end']:
                    x, y = new_path[key]
                    new_path[key] = (
                        round(x / grid_size) * grid_size,
                        round(y / grid_size) * grid_size
                    )

            # Clamp all points back onto the canvas.
            for key in ['start', 'cp1', 'cp2', 'end']:
                x, y = new_path[key]
                new_path[key] = (
                    max(0, min(width, x)),
                    max(0, min(height, y))
                )

            modified_paths.append(new_path)

        return modified_paths

    def apply_aesthetic_refinement(self, paths: List[Dict], width: int, height: int):
        """Apply aesthetic refinements to improve visual quality."""
        center_x, center_y = width / 2, height / 2

        def distance_from_center(path):
            start_x, start_y = path['start']
            return math.sqrt((start_x - center_x) ** 2 + (start_y - center_y) ** 2)

        # Sort outermost strokes first so central strokes are drawn on top.
        paths.sort(key=distance_from_center, reverse=True)

        # Fade opacity progressively across layers (capped at 0.9).
        for i, path in enumerate(paths):
            layer_factor = 1.0 - (i / len(paths)) * 0.3
            path['opacity'] = min(0.9, path['opacity'] * layer_factor)

        return paths

    def paths_to_svg(self, paths: List[Dict], width: int, height: int):
        """Convert optimized paths to SVG markup."""
        dwg = svgwrite.Drawing(size=(width, height))
        # White background rectangle.
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        for path in paths:
            start_x, start_y = path['start']
            cp1_x, cp1_y = path['cp1']
            cp2_x, cp2_y = path['cp2']
            end_x, end_y = path['end']

            # Cubic Bezier path: M start, C cp1 cp2 end.
            path_data = f"M {start_x},{start_y} C {cp1_x},{cp1_y} {cp2_x},{cp2_y} {end_x},{end_y}"

            color = path['color']
            stroke_color = f"rgb({color[0]},{color[1]},{color[2]})"

            dwg.add(dwg.path(
                d=path_data,
                stroke=stroke_color,
                stroke_width=path['stroke_width'],
                stroke_opacity=path['opacity'],
                fill='none',
                stroke_linecap='round',
                stroke_linejoin='round'
            ))

        return dwg.tostring()

    def svg_to_pil_image(self, svg_content: str, width: int, height: int):
        """Convert SVG content to a PIL Image."""
        try:
            # cairosvg is an optional dependency; import lazily.
            import cairosvg

            png_bytes = cairosvg.svg2png(
                bytestring=svg_content.encode('utf-8'),
                output_width=width,
                output_height=height
            )
            return Image.open(io.BytesIO(png_bytes)).convert('RGB')

        except ImportError:
            print("cairosvg not available, returning a blank placeholder image")
            return Image.new('RGB', (width, height), 'white')
        except Exception as e:
            print(f"Error converting SVG to image: {e}")
            return Image.new('RGB', (width, height), 'white')

    def create_fallback_svg(self, prompt: str, width: int, height: int):
        """Create a simple fallback SVG."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        # Centered label; SVG <text> does not honor newlines, so keep one line.
        dwg.add(dwg.text(
            f"DiffSketcher: {prompt[:30]}...",
            insert=(width / 2, height / 2),
            text_anchor="middle",
            font_size="12px",
            fill="black"
        ))

        return dwg.tostring()
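

# Minimal local smoke test, a sketch only (assumption: this mirrors how the
# Hugging Face Inference Endpoints runtime invokes the handler; model weights
# are downloaded on first run, and the filename "sketch.png" is arbitrary).
if __name__ == "__main__":
    handler = EndpointHandler()
    image = handler({
        "inputs": "a simple sketch of a cat",
        "parameters": {"num_paths": 32, "seed": 42}
    })
    image.save("sketch.png")
    print(f"Saved sketch.png ({image.size[0]}x{image.size[1]}), "
          f"method={image.info.get('method')}")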