|
|
|
|
|
|
|
""" |
|
Simplified DiffSketcher model for text-to-SVG generation. |
|
""" |
|
|
|
import os |
|
import io |
|
import base64 |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
import clip |
|
import torch.nn.functional as F |
|
import xml.etree.ElementTree as ET |
|
import cairosvg |
|
|
|
class DiffSketcherModel:
    """Simplified text-to-SVG model driven by CLIP text features.

    The model encodes a prompt with CLIP and uses the leading components of
    the normalised text embedding to place and colour circles on a
    placeholder SVG canvas. The SVG can then be rasterised to PNG.
    """

    # Name under which the OpenAI ``clip`` package registers this variant.
    # The checkpoint *file* is called "ViT-B-32.pt", but ``clip.load`` only
    # accepts the registered name (with a slash) or an existing file path.
    CLIP_MODEL_NAME = "ViT-B/32"

    def __init__(self, model_dir):
        """Load the CLIP model, preferring a local checkpoint in ``model_dir``.

        Args:
            model_dir: Directory searched for a ``ViT-B-32.pt`` checkpoint.
                When the file is absent the model is fetched by name via
                ``clip.load``.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.clip_model_path = os.path.join(model_dir, "ViT-B-32.pt")
        if os.path.exists(self.clip_model_path):
            print(f"Loading CLIP model from {self.clip_model_path}")
            self.clip_model, _ = clip.load(self.clip_model_path, device=self.device)
        else:
            print(f"CLIP model not found at {self.clip_model_path}, downloading...")
            # "ViT-B-32" is neither a registered model name nor a file, so
            # passing it to clip.load() raises; use the registered name.
            self.clip_model, _ = clip.load(self.CLIP_MODEL_NAME, device=self.device)

        # Inference only: switch off training-time behaviour.
        self.clip_model.eval()

        print(f"DiffSketcher model initialized on device: {self.device}")

    def generate_svg(self, prompt, num_paths=10, width=512, height=512):
        """Generate an SVG document from a text prompt.

        Args:
            prompt: Text to encode with CLIP and to print on the canvas.
            num_paths: Maximum number of circles to draw; capped by the
                width of the CLIP text embedding.
            width: Canvas width written to the ``<svg>`` element.
            height: Canvas height written to the ``<svg>`` element.

        Returns:
            The SVG markup as a string.
        """
        print(f"Generating SVG for prompt: {prompt}")

        with torch.no_grad():
            tokens = clip.tokenize([prompt]).to(self.device)
            text_features = self.clip_model.encode_text(tokens)
            # Unit-normalise so every component lies in [-1, 1].
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # One circle per embedding component, never more than requested and
        # never a negative count (which would mis-slice the tensor).
        count = max(0, min(num_paths, text_features.shape[1]))
        feature_values = text_features[0, :count].tolist()

        return self._render_svg(prompt, feature_values, num_paths, width, height)

    @staticmethod
    def _render_svg(prompt, feature_values, num_paths, width, height):
        """Build the SVG markup for ``prompt`` and a sequence of feature values.

        Args:
            prompt: Caption text; XML-escaped before being embedded.
            feature_values: Floats (expected in [-1, 1]) that drive each
                circle's x position, radius and hue.
            num_paths: Divisor used to spread circles vertically; must be
                non-zero whenever ``feature_values`` is non-empty.
            width: Canvas width.
            height: Canvas height.

        Returns:
            The SVG markup as a string.
        """
        from xml.sax.saxutils import escape

        parts = [
            f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">\n',
            '<rect width="100%" height="100%" fill="#f0f0f0"/>\n',
            '<text x="50%" y="10%" font-family="Arial" font-size="20" text-anchor="middle">Generated by DiffSketcher</text>\n',
            # The prompt is user input: escape &, < and > so it cannot break
            # or inject markup into the document.
            f'<text x="50%" y="50%" font-family="Arial" font-size="24" text-anchor="middle" font-weight="bold">{escape(prompt)}</text>\n',
        ]

        for index, feature_val in enumerate(feature_values):
            # Map [-1, 1] onto the canvas width.
            x = (feature_val + 1) * width / 2
            # Spread circles over the middle 80% of the canvas height.
            y = ((index / num_paths) * 0.8 + 0.1) * height
            radius = abs(feature_val) * 50 + 10
            # Map [-1, 1] onto the 0-360 degree hue wheel.
            hue = (feature_val + 1) * 180
            parts.append(
                f'<circle cx="{x}" cy="{y}" r="{radius}" fill="hsl({hue}, 70%, 60%)" opacity="0.7" />'
            )

        parts.append("</svg>")
        return "".join(parts)

    def svg_to_png(self, svg_content):
        """Rasterise SVG markup to PNG bytes.

        Args:
            svg_content: SVG markup as a string.

        Returns:
            PNG-encoded bytes. If conversion fails, a 512x512 red image
            carrying the error message is returned instead of raising.
        """
        try:
            return cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
        except Exception as e:
            # Deliberate best-effort: callers always get a decodable PNG.
            print(f"Error converting SVG to PNG: {e}")

            from PIL import ImageDraw

            image = Image.new("RGB", (512, 512), color="#ff0000")
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")

            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG for ``prompt`` and rasterise it to PNG.

        Args:
            prompt: Text prompt passed to ``generate_svg``.

        Returns:
            A dict with keys ``"svg"`` (markup string), ``"svg_base64"`` and
            ``"png_base64"`` (base64-encoded strings) and ``"image"`` (the
            PNG opened as a PIL image).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)

        image = Image.open(io.BytesIO(png_data))

        return {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }