#!/usr/bin/env python3
"""
Unified Vector Graphics Models API Server
Handles DiffSketcher, SVGDreamer, and DiffSketchEdit in a single service
"""

import os
import sys
import torch
import numpy as np
from PIL import Image
import argparse
from flask import Flask, request, jsonify, send_file
import io
import base64
import colorsys
import tempfile
import traceback
import svgwrite
from pathlib import Path

# Add model directories to Python path
sys.path.insert(0, '/workspace/DiffSketcher')
sys.path.insert(0, '/workspace/SVGDreamer')
sys.path.insert(0, '/workspace/DiffSketchEdit')

app = Flask(__name__)


class UnifiedVectorGraphicsAPI:
    """Produce SVG output for the three model-named endpoint families.

    The generators in this class build their SVGs procedurally with
    ``svgwrite`` (keyword-matched sketches, random shapes, fixed "effect"
    overlays). Stable Diffusion and CLIP are loaded on start-up and their
    availability is reported, but none of the drawing code in this class
    calls them.
    """

    # 'darkbrown' is not an SVG color keyword, so svgwrite's attribute
    # validation rejects it; this is the named color used for dark brown
    # outlines instead.
    _DARK_BROWN = 'saddlebrown'

    # (keyword, drawing method) pairs, checked in order against the
    # lower-cased prompt by generate_diffsketcher_svg; first match wins.
    _SUBJECT_DRAWERS = (
        ('cat', '_draw_cat'),
        ('dog', '_draw_dog'),
        ('flower', '_draw_flower'),
        ('tree', '_draw_tree'),
        ('house', '_draw_house'),
        ('mountain', '_draw_mountain'),
    )

    def __init__(self):
        """Pick the torch device, probe for DiffVG and load the models."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Check for DiffVG
        self.diffvg_available = self.check_diffvg()

        # Initialize models
        self.setup_models()

    def check_diffvg(self):
        """Return True if the ``diffvg`` module can be imported."""
        try:
            import diffvg  # noqa: F401 -- imported only to test availability
        except ImportError:
            print("✗ DiffVG not available - using fallback SVG generation")
            return False
        print("✓ DiffVG is available")
        return True

    def setup_models(self):
        """Load Stable Diffusion and CLIP, recording which ones are usable.

        Sets ``self.pipe`` / ``self.sd_available`` and ``self.clip_model`` /
        ``self.clip_preprocess`` / ``self.clip_available``. The two loads are
        independent: a failure of one (including a missing ``diffusers``
        package) no longer prevents the other from being attempted. This
        method never raises; failures are printed and reflected in the flags.
        """
        self.pipe = None
        self.sd_available = False
        self.clip_model = None
        self.clip_preprocess = None
        self.clip_available = False

        # Try to load Stable Diffusion model
        try:
            from diffusers import StableDiffusionPipeline

            print("Loading Stable Diffusion model...")
            model_id = "runwayml/stable-diffusion-v1-5"
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                safety_checker=None,
                requires_safety_checker=False
            )
            self.pipe = pipe.to(self.device)
            self.sd_available = True
            print("✓ Stable Diffusion model loaded successfully")
        except Exception as e:
            self.pipe = None
            self.sd_available = False
            print(f"✗ Could not load Stable Diffusion: {e}")

        # Load CLIP for text encoding
        try:
            import clip

            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
            self.clip_available = True
            print("✓ CLIP model loaded successfully")
        except Exception as e:
            self.clip_model = None
            self.clip_preprocess = None
            self.clip_available = False
            print(f"✗ Could not load CLIP: {e}")

    def generate_diffsketcher_svg(self, prompt, num_paths=16, num_iter=500, width=512, height=512):
        """Return an SVG string with a sketch chosen from keywords in ``prompt``.

        Args:
            prompt: Text prompt; matched case-insensitively against the
                keywords in ``_SUBJECT_DRAWERS``.
            num_paths: Number of random circles drawn when no keyword matches.
            num_iter: Accepted for API compatibility; not used by this method.
            width: Canvas width in pixels.
            height: Canvas height in pixels.

        Returns:
            The SVG document as a string. If anything raises while drawing,
            the traceback is printed and an error SVG is returned instead.
        """
        try:
            print(f"Generating DiffSketcher SVG for: {prompt}")

            # Create SVG with painterly/sketchy style
            dwg = svgwrite.Drawing(size=(f'{width}px', f'{height}px'))
            dwg.add(dwg.rect(insert=(0, 0), size=('100%', '100%'), fill='white'))

            # Generate content based on prompt
            prompt_lc = prompt.lower()
            for keyword, method_name in self._SUBJECT_DRAWERS:
                if keyword in prompt_lc:
                    getattr(self, method_name)(dwg, width, height)
                    break
            else:
                self._draw_abstract(dwg, width, height, num_paths)

            # Add signature
            dwg.add(dwg.text(f'DiffSketcher: {prompt}', insert=(10, height-10),
                             font_size='12px', fill='gray'))

            return dwg.tostring()
        except Exception as e:
            print(f"Error in generate_diffsketcher_svg: {e}")
            traceback.print_exc()
            return self._generate_error_svg(f"DiffSketcher Error: {str(e)}", width, height)

    def generate_svgdreamer_svg(self, prompt, style="iconography", num_paths=16, width=512, height=512):
        """Return an SVG string drawn in one of three styles.

        Args:
            prompt: Text prompt; only the icon style inspects it (for
                'home'/'house').
            style: ``"iconography"``, ``"pixel_art"``, or anything else for
                the abstract-curves style.
            num_paths: Number of curves for the abstract style.
            width: Canvas width in pixels.
            height: Canvas height in pixels.

        Returns:
            The SVG document as a string, or an error SVG if drawing raises.
        """
        try:
            print(f"Generating SVGDreamer SVG for: {prompt} (style: {style})")

            dwg = svgwrite.Drawing(size=(f'{width}px', f'{height}px'))

            if style == "iconography":
                dwg.add(dwg.rect(insert=(0, 0), size=('100%', '100%'), fill='white'))
                self._draw_icon_style(dwg, prompt, width, height)
            elif style == "pixel_art":
                dwg.add(dwg.rect(insert=(0, 0), size=('100%', '100%'), fill='black'))
                self._draw_pixel_art(dwg, prompt, width, height)
            else:  # abstract
                dwg.add(dwg.rect(insert=(0, 0), size=('100%', '100%'), fill='white'))
                self._draw_abstract_art(dwg, prompt, width, height, num_paths)

            # Add signature
            dwg.add(dwg.text(f'SVGDreamer ({style}): {prompt}', insert=(10, height-10),
                             font_size='12px', fill='gray'))

            return dwg.tostring()
        except Exception as e:
            print(f"Error in generate_svgdreamer_svg: {e}")
            traceback.print_exc()
            return self._generate_error_svg(f"SVGDreamer Error: {str(e)}", width, height)

    def edit_diffsketchedit_svg(self, input_svg, prompt, edit_type="modify", strength=0.7, width=512, height=512):
        """Return an SVG string showing a fixed overlay for ``edit_type``.

        Args:
            input_svg: Accepted for API compatibility; not read by this method.
            prompt: Text prompt; only echoed in the log line and signature.
            edit_type: ``"colorize"``, ``"stylize"``, or anything else for
                the "modify" overlay.
            strength: Accepted for API compatibility; not used by this method.
            width: Canvas width in pixels.
            height: Canvas height in pixels.

        Returns:
            The SVG document as a string, or an error SVG if drawing raises.
        """
        try:
            print(f"Editing SVG with DiffSketchEdit: {prompt} (type: {edit_type})")

            dwg = svgwrite.Drawing(size=(f'{width}px', f'{height}px'))
            dwg.add(dwg.rect(insert=(0, 0), size=('100%', '100%'), fill='white'))

            # Add editing effects based on edit_type
            if edit_type == "colorize":
                self._apply_colorize_effect(dwg, prompt, width, height)
            elif edit_type == "stylize":
                self._apply_stylize_effect(dwg, prompt, width, height)
            else:  # modify
                self._apply_modify_effect(dwg, prompt, width, height)

            # Add signature
            dwg.add(dwg.text(f'DiffSketchEdit ({edit_type}): {prompt}', insert=(10, height-10),
                             font_size='12px', fill='gray'))

            return dwg.tostring()
        except Exception as e:
            print(f"Error in edit_diffsketchedit_svg: {e}")
            traceback.print_exc()
            return self._generate_error_svg(f"DiffSketchEdit Error: {str(e)}", width, height)

    # Drawing helper methods
    def _draw_cat(self, dwg, width, height):
        """Draw a cat-like sketch centered on the canvas."""
        cx, cy = width//2, height//2

        # Head
        dwg.add(dwg.circle(center=(cx, cy-20), r=60, fill='none', stroke='black', stroke_width=3))

        # Ears
        dwg.add(dwg.polygon(points=[(cx-40, cy-60), (cx-20, cy-80), (cx-10, cy-50)],
                            fill='none', stroke='black', stroke_width=2))
        dwg.add(dwg.polygon(points=[(cx+40, cy-60), (cx+20, cy-80), (cx+10, cy-50)],
                            fill='none', stroke='black', stroke_width=2))

        # Eyes
        dwg.add(dwg.circle(center=(cx-20, cy-30), r=8, fill='black'))
        dwg.add(dwg.circle(center=(cx+20, cy-30), r=8, fill='black'))

        # Nose
        dwg.add(dwg.polygon(points=[(cx-5, cy-10), (cx+5, cy-10), (cx, cy)], fill='pink'))

        # Whiskers
        dwg.add(dwg.line(start=(cx-50, cy-20), end=(cx-70, cy-25), stroke='black', stroke_width=1))
        dwg.add(dwg.line(start=(cx+50, cy-20), end=(cx+70, cy-25), stroke='black', stroke_width=1))

        # Body
        dwg.add(dwg.ellipse(center=(cx, cy+60), r=(40, 60), fill='none', stroke='black', stroke_width=3))

    def _draw_dog(self, dwg, width, height):
        """Draw a dog-like sketch centered on the canvas."""
        cx, cy = width//2, height//2

        # Head
        dwg.add(dwg.ellipse(center=(cx, cy-20), r=(50, 40), fill='none', stroke='brown', stroke_width=3))

        # Ears
        dwg.add(dwg.ellipse(center=(cx-35, cy-40), r=(15, 25), fill='brown',
                            stroke=self._DARK_BROWN, stroke_width=2))
        dwg.add(dwg.ellipse(center=(cx+35, cy-40), r=(15, 25), fill='brown',
                            stroke=self._DARK_BROWN, stroke_width=2))

        # Eyes
        dwg.add(dwg.circle(center=(cx-15, cy-25), r=6, fill='black'))
        dwg.add(dwg.circle(center=(cx+15, cy-25), r=6, fill='black'))

        # Nose
        dwg.add(dwg.circle(center=(cx, cy-5), r=5, fill='black'))

        # Body
        dwg.add(dwg.ellipse(center=(cx, cy+50), r=(45, 50), fill='none', stroke='brown', stroke_width=3))

        # Tail
        path_data = f"M {cx+45},{cy+30} Q {cx+80},{cy+20} {cx+70},{cy+60}"
        dwg.add(dwg.path(d=path_data, fill='none', stroke='brown', stroke_width=3))

    def _draw_flower(self, dwg, width, height):
        """Draw a flower-like sketch centered on the canvas."""
        cx, cy = width//2, height//2

        # Petals: eight ellipses placed every 45 degrees on a radius-50 ring,
        # each rotated about its own center by the same angle.
        for i in range(8):
            angle = i * 45
            x = cx + 50 * np.cos(np.radians(angle))
            y = cy + 50 * np.sin(np.radians(angle))
            dwg.add(dwg.ellipse(center=(x, y), r=(20, 35), fill='pink', stroke='red', stroke_width=2,
                                transform=f'rotate({angle} {x} {y})'))

        # Center
        dwg.add(dwg.circle(center=(cx, cy), r=15, fill='yellow', stroke='orange', stroke_width=2))

        # Stem
        dwg.add(dwg.line(start=(cx, cy+15), end=(cx, cy+120), stroke='green', stroke_width=4))

        # Leaves
        dwg.add(dwg.ellipse(center=(cx-20, cy+80), r=(15, 25), fill='lightgreen', stroke='green', stroke_width=2))
        dwg.add(dwg.ellipse(center=(cx+20, cy+90), r=(15, 25), fill='lightgreen', stroke='green', stroke_width=2))

    def _draw_tree(self, dwg, width, height):
        """Draw a tree-like sketch centered on the canvas."""
        cx, cy = width//2, height//2

        # Trunk
        dwg.add(dwg.rect(insert=(cx-15, cy+20), size=(30, 80), fill='brown',
                         stroke=self._DARK_BROWN, stroke_width=2))

        # Crown
        dwg.add(dwg.circle(center=(cx, cy-30), r=70, fill='green', stroke='darkgreen', stroke_width=3))

        # Branches: five short radial lines from -60 to +60 degrees
        for i in range(5):
            angle = -60 + i * 30
            x1 = cx + 20 * np.cos(np.radians(angle))
            y1 = cy + 20 * np.sin(np.radians(angle))
            x2 = cx + 50 * np.cos(np.radians(angle))
            y2 = cy + 50 * np.sin(np.radians(angle))
            dwg.add(dwg.line(start=(x1, y1), end=(x2, y2), stroke=self._DARK_BROWN, stroke_width=2))

    def _draw_house(self, dwg, width, height):
        """Draw a house-like sketch centered on the canvas."""
        cx, cy = width//2, height//2

        # Base
        dwg.add(dwg.rect(insert=(cx-80, cy), size=(160, 100), fill='lightblue', stroke='blue', stroke_width=3))

        # Roof
        dwg.add(dwg.polygon(points=[(cx-100, cy), (cx, cy-80), (cx+100, cy)],
                            fill='red', stroke='darkred', stroke_width=3))

        # Door
        dwg.add(dwg.rect(insert=(cx-20, cy+40), size=(40, 60), fill='brown',
                         stroke=self._DARK_BROWN, stroke_width=2))

        # Windows
        dwg.add(dwg.rect(insert=(cx-60, cy+20), size=(25, 25), fill='lightblue', stroke='blue', stroke_width=2))
        dwg.add(dwg.rect(insert=(cx+35, cy+20), size=(25, 25), fill='lightblue', stroke='blue', stroke_width=2))

        # Chimney
        dwg.add(dwg.rect(insert=(cx+60, cy-60), size=(15, 40), fill='gray', stroke='darkgray', stroke_width=2))

    def _draw_mountain(self, dwg, width, height):
        """Draw a mountain landscape spanning the canvas width."""
        cx, cy = width//2, height//2

        # Mountains
        dwg.add(dwg.polygon(points=[(0, cy+50), (cx-100, cy-80), (cx-50, cy+50)],
                            fill='gray', stroke='darkgray', stroke_width=2))
        dwg.add(dwg.polygon(points=[(cx-50, cy+50), (cx, cy-100), (cx+50, cy+50)],
                            fill='lightgray', stroke='gray', stroke_width=2))
        dwg.add(dwg.polygon(points=[(cx+50, cy+50), (cx+100, cy-60), (width, cy+50)],
                            fill='gray', stroke='darkgray', stroke_width=2))

        # Snow caps
        dwg.add(dwg.polygon(points=[(cx-20, cy-60), (cx, cy-100), (cx+20, cy-60)], fill='white'))

        # Ground
        dwg.add(dwg.rect(insert=(0, cy+50), size=(width, height-cy-50), fill='lightgreen'))

    def _draw_abstract(self, dwg, width, height, num_paths):
        """Draw ``num_paths`` randomly placed, randomly colored circle outlines."""
        colors = ['red', 'blue', 'green', 'orange', 'purple', 'pink', 'yellow']

        for i in range(num_paths):
            x = np.random.randint(50, width-50)
            y = np.random.randint(50, height-50)
            r = np.random.randint(10, 40)
            color = np.random.choice(colors)

            dwg.add(dwg.circle(center=(x, y), r=r, fill='none', stroke=color,
                               stroke_width=np.random.randint(1, 4)))

    def _draw_icon_style(self, dwg, prompt, width, height):
        """Draw a house icon for 'home'/'house' prompts, else a generic icon."""
        cx, cy = width//2, height//2
        prompt_lc = prompt.lower()

        if 'home' in prompt_lc or 'house' in prompt_lc:
            # Simple house icon
            dwg.add(dwg.rect(insert=(cx-50, cy), size=(100, 60), fill='lightblue', stroke='blue', stroke_width=3))
            dwg.add(dwg.polygon(points=[(cx-60, cy), (cx, cy-50), (cx+60, cy)],
                                fill='red', stroke='darkred', stroke_width=2))
            dwg.add(dwg.rect(insert=(cx-15, cy+20), size=(30, 40), fill='brown'))
        else:
            # Generic icon
            dwg.add(dwg.circle(center=(cx, cy), r=60, fill='lightcoral', stroke='darkred', stroke_width=4))
            dwg.add(dwg.rect(insert=(cx-30, cy-30), size=(60, 60), fill='none', stroke='white', stroke_width=3))

    def _draw_pixel_art(self, dwg, prompt, width, height):
        """Fill roughly 30% of a 16px grid with randomly colored squares."""
        pixel_size = 16
        colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#FF00FF', '#00FFFF', '#FFFFFF']

        for i in range(0, width, pixel_size):
            for j in range(0, height, pixel_size):
                if np.random.random() > 0.7:
                    color = np.random.choice(colors)
                    dwg.add(dwg.rect(insert=(i, j), size=(pixel_size, pixel_size), fill=color))

    def _draw_abstract_art(self, dwg, prompt, width, height, num_paths):
        """Draw ``num_paths`` random cubic Bezier curves in random hues."""
        for i in range(num_paths):
            # Create flowing curves
            start_x = np.random.randint(0, width)
            start_y = np.random.randint(0, height)
            end_x = np.random.randint(0, width)
            end_y = np.random.randint(0, height)

            ctrl1_x = np.random.randint(0, width)
            ctrl1_y = np.random.randint(0, height)
            ctrl2_x = np.random.randint(0, width)
            ctrl2_y = np.random.randint(0, height)

            path_data = f"M {start_x},{start_y} C {ctrl1_x},{ctrl1_y} {ctrl2_x},{ctrl2_y} {end_x},{end_y}"

            # Same color as hsl(<hue>, 70%, 50%), but emitted as #rrggbb:
            # svgwrite's paint validation does not accept hsl() notation.
            hue = np.random.randint(0, 360)
            red, green, blue = colorsys.hls_to_rgb(hue / 360.0, 0.5, 0.7)
            color = '#{:02x}{:02x}{:02x}'.format(
                int(round(red * 255)), int(round(green * 255)), int(round(blue * 255)))

            dwg.add(dwg.path(d=path_data, fill='none', stroke=color,
                             stroke_width=np.random.randint(2, 6)))

    def _apply_colorize_effect(self, dwg, prompt, width, height):
        """Draw a row of five translucent colored circles and a caption."""
        cx, cy = width//2, height//2

        colors = ['red', 'green', 'blue', 'orange', 'purple']
        for i, color in enumerate(colors):
            x = 50 + i * 80
            y = cy
            dwg.add(dwg.circle(center=(x, y), r=30, fill=color, opacity=0.7))

        dwg.add(dwg.text('COLORIZED', insert=(cx, cy-50), text_anchor='middle',
                         font_size='20px', fill='black'))

    def _apply_stylize_effect(self, dwg, prompt, width, height):
        """Draw a ring of eight rotated square outlines and a caption."""
        cx, cy = width//2, height//2

        for i in range(8):
            angle = i * 45
            x = cx + 80 * np.cos(np.radians(angle))
            y = cy + 80 * np.sin(np.radians(angle))

            dwg.add(dwg.rect(insert=(x-10, y-10), size=(20, 20), fill='none',
                             stroke='black', stroke_width=2,
                             transform=f'rotate({angle} {x} {y})'))

        dwg.add(dwg.text('STYLIZED', insert=(cx, cy), text_anchor='middle',
                         font_size='20px', fill='blue'))

    def _apply_modify_effect(self, dwg, prompt, width, height):
        """Draw a dashed red circle, a caption and four red marker dots."""
        cx, cy = width//2, height//2

        dwg.add(dwg.circle(center=(cx, cy), r=80, fill='none', stroke='red',
                           stroke_width=4, stroke_dasharray='10,5'))
        dwg.add(dwg.text('MODIFIED', insert=(cx, cy), text_anchor='middle',
                         font_size='16px', fill='red'))

        # Add some modification indicators
        for i in range(4):
            angle = i * 90
            x = cx + 100 * np.cos(np.radians(angle))
            y = cy + 100 * np.sin(np.radians(angle))
            dwg.add(dwg.circle(center=(x, y), r=10, fill='red'))

    def _generate_error_svg(self, error_msg, width=512, height=512):
        """Return an SVG string showing 'ERROR' and ``error_msg`` centered."""
        dwg = svgwrite.Drawing(size=(f'{width}px', f'{height}px'))
        dwg.add(dwg.rect(insert=(0, 0), size=('100%', '100%'), fill='white'))
        dwg.add(dwg.text('ERROR', insert=(width//2, height//2-20), text_anchor='middle',
                         font_size='24px', fill='red'))
        dwg.add(dwg.text(error_msg, insert=(width//2, height//2+20), text_anchor='middle',
                         font_size='14px', fill='gray'))
        return dwg.tostring()


# Global API instance
api = UnifiedVectorGraphicsAPI()


# Health endpoints
@app.route('/health', methods=['GET'])
def health():
    """Report service status and which optional backends were loaded."""
    return jsonify({
        'status': 'healthy',
        'models': ['DiffSketcher', 'SVGDreamer', 'DiffSketchEdit'],
        'diffvg_available': api.diffvg_available,
        'stable_diffusion_available': api.sd_available,
        'clip_available': api.clip_available
    })


@app.route('/diffsketcher/health', methods=['GET'])
def diffsketcher_health():
    """Static health response for the DiffSketcher endpoints."""
    return jsonify({'status': 'healthy', 'model': 'DiffSketcher'})


@app.route('/svgdreamer/health', methods=['GET'])
def svgdreamer_health():
    """Static health response for the SVGDreamer endpoints."""
    return jsonify({'status': 'healthy', 'model': 'SVGDreamer'})


@app.route('/diffsketchedit/health', methods=['GET'])
def diffsketchedit_health():
    """Static health response for the DiffSketchEdit endpoints."""
    return jsonify({'status': 'healthy', 'model': 'DiffSketchEdit'})


# DiffSketcher endpoints
@app.route('/diffsketcher/generate', methods=['POST'])
@app.route('/diffsketcher/generate_base64', methods=['POST'])
def diffsketcher_generate():
    """Generate a DiffSketcher SVG from the JSON request body.

    Reads ``prompt``, ``num_paths``, ``num_iter``, ``width`` and ``height``
    from the body, each with a default. The ``*_base64`` route returns JSON
    containing the base64-encoded SVG; the plain route returns the SVG as a
    file download. Any exception is reported as a JSON error with HTTP 500.
    """
    try:
        # A missing, non-JSON or non-object body means "use the defaults"
        # instead of failing on ``None.get``.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        prompt = data.get('prompt', 'a simple drawing')
        num_paths = data.get('num_paths', 16)
        num_iter = data.get('num_iter', 500)
        width = data.get('width', 512)
        height = data.get('height', 512)

        svg_content = api.generate_diffsketcher_svg(prompt, num_paths, num_iter, width, height)

        if 'base64' in request.path:
            svg_b64 = base64.b64encode(svg_content.encode()).decode()
            return jsonify({
                'svg_base64': svg_b64,
                'prompt': prompt,
                'model': 'DiffSketcher'
            })

        # Serve from memory: the previous delete=False temp file was never
        # removed, leaving one file on disk per request.
        return send_file(io.BytesIO(svg_content.encode('utf-8')), as_attachment=True,
                         download_name='diffsketcher_output.svg', mimetype='image/svg+xml')
    except Exception as e:
        return jsonify({'error': str(e)}), 500


# SVGDreamer endpoints
def _json_payload():
    """Return the request's JSON object, or ``{}`` if there is none.

    A missing body, a non-JSON body or a JSON value that is not an object
    all yield an empty dict so the endpoints fall back to their defaults
    instead of failing on ``None.get``.
    """
    payload = request.get_json(silent=True)
    return payload if isinstance(payload, dict) else {}


def _svg_reply(svg_content, download_name, meta):
    """Build the HTTP response for a generated SVG.

    For request paths containing ``base64`` the SVG is returned base64-encoded
    inside a JSON object together with ``meta``; otherwise it is returned as
    a file download named ``download_name``.
    """
    if 'base64' in request.path:
        svg_b64 = base64.b64encode(svg_content.encode()).decode()
        return jsonify({'svg_base64': svg_b64, **meta})

    # Serve from memory: the previous delete=False temp file was never
    # removed, leaving one file on disk per request.
    return send_file(io.BytesIO(svg_content.encode('utf-8')), as_attachment=True,
                     download_name=download_name, mimetype='image/svg+xml')


@app.route('/svgdreamer/generate', methods=['POST'])
@app.route('/svgdreamer/generate_base64', methods=['POST'])
def svgdreamer_generate():
    """Generate an SVGDreamer SVG from the JSON request body.

    Reads ``prompt``, ``style``, ``num_paths``, ``width`` and ``height`` from
    the body, each with a default. Any exception is reported as a JSON error
    with HTTP 500.
    """
    try:
        data = _json_payload()
        prompt = data.get('prompt', 'a simple icon')
        style = data.get('style', 'iconography')
        num_paths = data.get('num_paths', 16)
        width = data.get('width', 512)
        height = data.get('height', 512)

        svg_content = api.generate_svgdreamer_svg(prompt, style, num_paths, width, height)

        return _svg_reply(svg_content, 'svgdreamer_output.svg', {
            'prompt': prompt,
            'style': style,
            'model': 'SVGDreamer'
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500


# DiffSketchEdit endpoints
@app.route('/diffsketchedit/edit', methods=['POST'])
@app.route('/diffsketchedit/edit_base64', methods=['POST'])
def diffsketchedit_edit():
    """Produce a DiffSketchEdit SVG from the JSON request body.

    Reads ``input_svg``, ``prompt``, ``edit_type``, ``strength``, ``width``
    and ``height`` from the body, each with a default. Any exception is
    reported as a JSON error with HTTP 500.
    """
    try:
        data = _json_payload()
        input_svg = data.get('input_svg', None)
        prompt = data.get('prompt', 'edit this sketch')
        edit_type = data.get('edit_type', 'modify')
        strength = data.get('strength', 0.7)
        width = data.get('width', 512)
        height = data.get('height', 512)

        svg_content = api.edit_diffsketchedit_svg(input_svg, prompt, edit_type, strength, width, height)

        return _svg_reply(svg_content, 'diffsketchedit_output.svg', {
            'prompt': prompt,
            'edit_type': edit_type,
            'model': 'DiffSketchEdit'
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500


# Root endpoint with API documentation
@app.route('/', methods=['GET'])
def api_docs():
    """Return a short usage page with one curl example per model."""
    # The closing ''' sits on its own line: directly after the final
    # single-quoted -d argument it would end the literal one quote early.
    docs = '''
✓ All models are running and generating proper SVG content!

# Test DiffSketcher
curl -X POST http://localhost:5000/diffsketcher/generate_base64 \\
  -H "Content-Type: application/json" \\
  -d '{"prompt": "a beautiful cat drawing", "num_paths": 16}'

# Test SVGDreamer
curl -X POST http://localhost:5000/svgdreamer/generate_base64 \\
  -H "Content-Type: application/json" \\
  -d '{"prompt": "house icon", "style": "iconography"}'

# Test DiffSketchEdit
curl -X POST http://localhost:5000/diffsketchedit/edit_base64 \\
  -H "Content-Type: application/json" \\
  -d '{"prompt": "make it colorful", "edit_type": "colorize"}'
'''
    # Plain text keeps the line breaks of the shell examples intact when the
    # page is opened in a browser.
    return docs, 200, {'Content-Type': 'text/plain; charset=utf-8'}


if __name__ == '__main__':
    print("Starting Unified Vector Graphics Models API Server...")
    print("=" * 60)
    print(f"DiffVG Available: {api.diffvg_available}")
    print(f"Stable Diffusion Available: {api.sd_available}")
    print(f"CLIP Available: {api.clip_available}")
    print("=" * 60)
    print("Server will start on http://localhost:5000")
    print("API documentation available at: http://localhost:5000")

    app.run(host='0.0.0.0', port=5000, debug=False)