import base64
import html
import io
from typing import Any, Dict, List, Union

import torch
from PIL import Image
class Pipeline:
    """Placeholder "diffsketcher" inference pipeline.

    No model is run: each call returns a fixed-size SVG that shows the prompt
    as centred text, plus a solid grey PNG encoded as base64.
    """

    # Edge length, in pixels, of both the placeholder SVG canvas and the PNG.
    _SIZE = 512
    # RGB fill colour of the placeholder PNG.
    _BACKGROUND = (100, 100, 100)

    def __init__(self):
        """Select CUDA when available, otherwise CPU, and announce the choice."""
        # NOTE(review): self.device is not used by this placeholder; presumably
        # reserved for a real model implementation - confirm.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing diffsketcher pipeline on {self.device}")

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Produce placeholder outputs for the prompt found in *inputs*.

        Args:
            inputs: Request payload. ``inputs["prompt"]`` is used when truthy;
                otherwise the first entry of ``inputs["prompts"]``.

        Returns:
            ``{"svg": <SVG markup>, "image": <base64-encoded PNG>}``.
        """
        prompt = self._extract_prompt(inputs)
        return {
            "svg": self._placeholder_svg(prompt),
            "image": self._placeholder_image_b64(),
        }

    @staticmethod
    def _extract_prompt(inputs: Dict[str, Any]) -> str:
        """Return the prompt text from *inputs*, or ``""`` if there is none.

        A truthy ``inputs["prompt"]`` wins. Otherwise the first element of
        ``inputs["prompts"]`` is used. A bare string under ``"prompts"`` is
        treated as one prompt rather than being indexed to its first
        character. A ``None`` value yields ``""`` instead of the text "None".
        """
        prompt = inputs.get("prompt", "")
        if not prompt and "prompts" in inputs:
            prompts = inputs["prompts"]
            if isinstance(prompts, str):
                prompt = prompts
            else:
                prompt = prompts[0] if prompts else ""
        return "" if prompt is None else str(prompt)

    @classmethod
    def _placeholder_svg(cls, prompt: str) -> str:
        """Return a square SVG that displays ``diffsketcher: <prompt>`` centred.

        The label is XML-escaped so that characters such as ``<`` and ``&`` in
        the prompt cannot produce malformed markup or inject extra elements.
        """
        size = cls._SIZE
        label = html.escape(f"diffsketcher: {prompt}", quote=False)
        return (
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{size}" '
            f'height="{size}" viewBox="0 0 {size} {size}">'
            '<text x="50%" y="50%" dominant-baseline="middle" '
            f'text-anchor="middle" font-size="20">{label}</text></svg>'
        )

    @classmethod
    def _placeholder_image_b64(cls) -> str:
        """Return a solid-colour square PNG, base64-encoded as an ASCII string."""
        image = Image.new("RGB", (cls._SIZE, cls._SIZE), color=cls._BACKGROUND)
        with io.BytesIO() as buffered:
            image.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode("ascii")