File size: 1,294 Bytes
5edea0f
13ea232
 
 
 
 
e942bd1
13ea232
e942bd1
13ea232
 
e942bd1
13ea232
 
 
 
 
 
e942bd1
13ea232
 
5edea0f
 
13ea232
e942bd1
13ea232
 
 
 
e942bd1
5edea0f
13ea232
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

import base64
import html
import io
from typing import Dict, Any, List, Union

import torch
from PIL import Image

class Pipeline:
    """Placeholder diffsketcher inference pipeline.

    No model is run: each call returns a fixed-size SVG whose only content is
    a text label echoing the prompt, plus a solid-colour PNG encoded as base64.
    """

    # Side length in pixels of both the placeholder SVG and the PNG.
    IMAGE_SIZE = 512
    # RGB fill colour of the placeholder PNG.
    BACKGROUND_RGB = (100, 100, 100)

    def __init__(self):
        """Pick CUDA when available (otherwise CPU) and report the choice."""
        # NOTE(review): self.device is not used by __call__ in this class;
        # presumably reserved for a real model — confirm.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing diffsketcher pipeline on {self.device}")

    @staticmethod
    def _extract_prompt(inputs: Dict[str, Any]) -> str:
        """Return the prompt text: ``inputs["prompt"]``, else the first of ``inputs["prompts"]``."""
        prompt = inputs.get("prompt", "")
        if not prompt and "prompts" in inputs:
            prompts = inputs.get("prompts", [""])
            if isinstance(prompts, str):
                # A bare string would otherwise be indexed to its first character.
                prompt = prompts
            elif prompts:
                prompt = prompts[0]
            else:
                prompt = ""
        # Avoid rendering a literal "None" label; coerce anything else to text.
        return "" if prompt is None else str(prompt)

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Produce placeholder outputs for one request.

        Args:
            inputs: Request payload. Reads ``"prompt"``; if that is missing or
                falsy, falls back to the first entry of ``"prompts"`` (or the
                whole value when ``"prompts"`` is a plain string).

        Returns:
            ``{"svg": <SVG markup>, "image": <base64-encoded PNG>}``.
        """
        prompt = self._extract_prompt(inputs)
        size = self.IMAGE_SIZE

        # The prompt is caller-supplied text: escape it so "<", ">" or "&"
        # cannot break the SVG document or inject extra markup.
        label = html.escape(f"diffsketcher: {prompt}")
        svg = (
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}" '
            f'viewBox="0 0 {size} {size}"><text x="50%" y="50%" '
            'dominant-baseline="middle" text-anchor="middle" font-size="20">'
            f"{label}</text></svg>"
        )

        # Solid-colour placeholder raster, serialised to PNG in memory.
        image = Image.new("RGB", (size, size), color=self.BACKGROUND_RGB)
        with io.BytesIO() as buffered:
            image.save(buffered, format="PNG")
            png_bytes = buffered.getvalue()
        # base64 output is pure ASCII, so decode with that codec explicitly.
        img_str = base64.b64encode(png_bytes).decode("ascii")

        return {
            "svg": svg,
            "image": img_str,
        }