import spaces
from huggingface_hub import InferenceClient
from PIL import Image
import config
import random

from diffusers import AutoPipelineForText2Image
import torch

class DiffusionInference:
    def __init__(self, api_key=None):
        """
        Initialize the inference client with the Hugging Face API token.
        """
        self.api_key = api_key or config.HF_TOKEN
        self.client = InferenceClient(
            provider="hf-inference",
            api_key=self.api_key,
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def text_to_image(self, prompt, model_name=None, negative_prompt=None, seed=None, **kwargs):
        """
        Generate an image from a text prompt.
        
        Args:
            prompt (str): The text prompt to guide image generation
            model_name (str, optional): The model to use for inference
            negative_prompt (str, optional): What not to include in the image
            seed (int, optional): Seed for reproducible generation; a random
                seed is used when omitted or invalid
            **kwargs: Additional parameters to pass to the model
            
        Returns:
            PIL.Image: The generated image
        """
        model = model_name or config.DEFAULT_TEXT2IMG_MODEL
        
        # Create parameters dictionary for all keyword arguments
        params = {
            "prompt": prompt,
        }
        
        # Seed handling is deferred to run_text_to_image_pipeline, which
        # builds a torch.Generator from it (or from a random fallback seed)
        
        # Add negative prompt if provided
        if negative_prompt is not None:
            params["negative_prompt"] = negative_prompt
        
        # Add any other parameters
        for k, v in kwargs.items():
            if k not in ["prompt", "model", "negative_prompt"]:
                params[k] = v
        
        try:
            # Call the API with all parameters as kwargs
            image = self.run_text_to_image_pipeline(model, seed, **params)
            return image
        except Exception as e:
            print(f"Error generating image: {e}")
            print(f"Model: {model}")
            print(f"Prompt: {prompt}")
            raise

    def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate a new image from an input image and optional prompt.
        
        Args:
            image (PIL.Image or str): Input image or path to image
            prompt (str, optional): Text prompt to guide the transformation
            model_name (str, optional): The model to use for inference
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters to pass to the model
            
        Returns:
            PIL.Image: The generated image
        """
        import tempfile
        import os
        
        model = model_name or config.DEFAULT_IMG2IMG_MODEL
        
        # Create a temporary file for the image if it's a PIL Image
        temp_file = None
        try:
            # Handle different image input types
            if isinstance(image, str):
                # If it's already a file path, use it directly
                image_path = image
            elif isinstance(image, Image.Image):
                # If it's a PIL Image, save it to a temporary file
                temp_dir = tempfile.gettempdir()
                temp_file = os.path.join(temp_dir, "temp_image.png")
                image.save(temp_file, format="PNG")
                image_path = temp_file
            else:
                # If it's something else, try to convert it to a PIL Image first
                try:
                    pil_image = Image.fromarray(image)
                    temp_dir = tempfile.gettempdir()
                    temp_file = os.path.join(temp_dir, "temp_image.png")
                    pil_image.save(temp_file, format="PNG")
                    image_path = temp_file
                except Exception as e:
                    raise ValueError(f"Unsupported image type: {type(image)}. Error: {e}")
            
            # Create a NEW InferenceClient for this call to avoid any potential state issues
            client = InferenceClient(
                provider="hf-inference",
                api_key=self.api_key,
            )
            
            # Create the parameter dict with only the non-None values
            params = {}
            # Only add parameters that are not None
            if model is not None:
                params["model"] = model
            if prompt is not None:
                params["prompt"] = prompt
            if negative_prompt is not None:
                params["negative_prompt"] = negative_prompt
            
            # Add additional kwargs, but filter out any that might create conflicts
            for k, v in kwargs.items():
                if v is not None and k not in ["image", "prompt", "model", "negative_prompt"]:
                    params[k] = v
                    
            # Debug the parameters we're sending
            print(f"DEBUG: Calling image_to_image with:")
            print(f"- Image path: {image_path}")
            print(f"- Parameters: {params}")
            
            # Make the API call
            result = client.image_to_image(image_path, **params)
            return result
            
        except Exception as e:
            print(f"Error transforming image: {e}")
            print(f"Image type: {type(image)}")
            print(f"Model: {model}")
            print(f"Prompt: {prompt}")
            raise
            
        finally:
            # Clean up the temporary file if it was created
            if temp_file and os.path.exists(temp_file):
                try:
                    os.remove(temp_file)
                except Exception as e:
                    print(f"Warning: Could not delete temporary file {temp_file}: {e}")

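    # Runs on a Spaces-allocated GPU via the @spaces.GPU decorator. Note that
    # the pipeline is re-created on every call, which keeps the handler
    # stateless at the cost of reloading model weights each time.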
    @spaces.GPU
    def run_text_to_image_pipeline(self, model_name, seed, **kwargs):
        if seed is not None:
            try:
                # Convert to an integer and build a generator from it
                generator = torch.Generator(device=self.device).manual_seed(int(seed))
            except (ValueError, TypeError):
                # Fall back to a random seed if conversion fails
                random_seed = random.randint(0, 2**32 - 1)  # full unsigned 32-bit range
                generator = torch.Generator(device=self.device).manual_seed(random_seed)
                print(f"Warning: Invalid seed value: {seed}, using random seed {random_seed} instead")
        else:
            # Generate a random seed when none is provided
            random_seed = random.randint(0, 2**32 - 1)  # full unsigned 32-bit range
            generator = torch.Generator(device=self.device).manual_seed(random_seed)
            print(f"Using random seed: {random_seed}")
        # from_pretrained() does not accept a generator; pass it to the pipeline
        # call instead so the seed actually affects sampling
        pipeline = AutoPipelineForText2Image.from_pretrained(
            model_name, torch_dtype=torch.float16
        ).to(self.device)
        image = pipeline(generator=generator, **kwargs).images[0]
        return image
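

# Example usage: a minimal sketch, assuming config.HF_TOKEN holds a valid token
# and the default models in config point at working checkpoints. Illustrative
# only; not part of the class above.
if __name__ == "__main__":
    inference = DiffusionInference()

    # Text-to-image with a fixed seed for reproducibility
    img = inference.text_to_image(
        "a watercolor painting of a lighthouse at dusk",
        negative_prompt="blurry, low quality",
        seed=42,
    )
    img.save("text2img_result.png")

    # Image-to-image: restyle the generated image with a new prompt
    result = inference.image_to_image(img, prompt="the same lighthouse in a snowstorm")
    result.save("img2img_result.png")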