import os
import random
import tempfile

import spaces
import torch
from diffusers import AutoPipelineForText2Image
from huggingface_hub import InferenceClient
from PIL import Image

import config


class DiffusionInference:
def __init__(self, api_key=None):
"""
Initialize the inference client with the Hugging Face API token.
"""
self.api_key = api_key or config.HF_TOKEN
self.client = InferenceClient(
provider="hf-inference",
api_key=self.api_key,
)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def text_to_image(self, prompt, model_name=None, negative_prompt=None, seed=None, **kwargs):
"""
Generate an image from a text prompt.
Args:
prompt (str): The text prompt to guide image generation
model_name (str, optional): The model to use for inference
            negative_prompt (str, optional): What not to include in the image
            seed (int, optional): Seed for reproducible generation; a random seed is used if omitted or invalid
**kwargs: Additional parameters to pass to the model
Returns:
PIL.Image: The generated image
"""
model = model_name or config.DEFAULT_TEXT2IMG_MODEL
# Create parameters dictionary for all keyword arguments
params = {
"prompt": prompt,
}
        # The seed is handled inside run_text_to_image_pipeline
        # Add the negative prompt if provided
if negative_prompt is not None:
params["negative_prompt"] = negative_prompt
# Add any other parameters
for k, v in kwargs.items():
if k not in ["prompt", "model", "negative_prompt"]:
params[k] = v
try:
# Call the API with all parameters as kwargs
image = self.run_text_to_image_pipeline(model, seed, **params)
return image
except Exception as e:
print(f"Error generating image: {e}")
print(f"Model: {model}")
print(f"Prompt: {prompt}")
            raise

    def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
"""
Generate a new image from an input image and optional prompt.
Args:
image (PIL.Image or str): Input image or path to image
prompt (str, optional): Text prompt to guide the transformation
model_name (str, optional): The model to use for inference
negative_prompt (str, optional): What not to include in the image
**kwargs: Additional parameters to pass to the model
Returns:
PIL.Image: The generated image
"""
model = model_name or config.DEFAULT_IMG2IMG_MODEL
# Create a temporary file for the image if it's a PIL Image
temp_file = None
try:
# Handle different image input types
if isinstance(image, str):
# If it's already a file path, use it directly
image_path = image
elif isinstance(image, Image.Image):
# If it's a PIL Image, save it to a temporary file
temp_dir = tempfile.gettempdir()
temp_file = os.path.join(temp_dir, "temp_image.png")
image.save(temp_file, format="PNG")
image_path = temp_file
else:
# If it's something else, try to convert it to a PIL Image first
try:
pil_image = Image.fromarray(image)
temp_dir = tempfile.gettempdir()
temp_file = os.path.join(temp_dir, "temp_image.png")
pil_image.save(temp_file, format="PNG")
image_path = temp_file
except Exception as e:
raise ValueError(f"Unsupported image type: {type(image)}. Error: {e}")
# Create a NEW InferenceClient for this call to avoid any potential state issues
client = InferenceClient(
provider="hf-inference",
api_key=self.api_key,
)
# Create the parameter dict with only the non-None values
params = {}
# Only add parameters that are not None
if model is not None:
params["model"] = model
if prompt is not None:
params["prompt"] = prompt
if negative_prompt is not None:
params["negative_prompt"] = negative_prompt
# Add additional kwargs, but filter out any that might create conflicts
for k, v in kwargs.items():
if v is not None and k not in ["image", "prompt", "model", "negative_prompt"]:
params[k] = v
# Debug the parameters we're sending
print(f"DEBUG: Calling image_to_image with:")
print(f"- Image path: {image_path}")
print(f"- Parameters: {params}")
# Make the API call
result = client.image_to_image(image_path, **params)
return result
except Exception as e:
print(f"Error transforming image: {e}")
print(f"Image type: {type(image)}")
print(f"Model: {model}")
print(f"Prompt: {prompt}")
raise
finally:
# Clean up the temporary file if it was created
if temp_file and os.path.exists(temp_file):
try:
os.remove(temp_file)
except Exception as e:
print(f"Warning: Could not delete temporary file {temp_file}: {e}")

    @spaces.GPU
    def run_text_to_image_pipeline(self, model_name, seed, **kwargs):
        if seed is not None:
            try:
                # Convert the seed to an integer and seed the generator
                generator = torch.Generator(device=self.device).manual_seed(int(seed))
            except (ValueError, TypeError):
                # Fall back to a random seed if conversion fails
                random_seed = random.randint(0, 2**32 - 1)  # Max unsigned 32-bit integer
                generator = torch.Generator(device=self.device).manual_seed(random_seed)
                print(f"Warning: Invalid seed value: {seed}, using random seed {random_seed} instead")
        else:
            # Generate a random seed when none is provided
            random_seed = random.randint(0, 2**32 - 1)  # Max unsigned 32-bit integer
            generator = torch.Generator(device=self.device).manual_seed(random_seed)
            print(f"Using random seed: {random_seed}")
        pipeline = AutoPipelineForText2Image.from_pretrained(model_name, torch_dtype=torch.float16).to(self.device)
        # Pass the generator to the pipeline call itself; from_pretrained does not use it
        image = pipeline(generator=generator, **kwargs).images[0]
        return image
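

# A minimal usage sketch, assuming config.py provides HF_TOKEN,
# DEFAULT_TEXT2IMG_MODEL, and DEFAULT_IMG2IMG_MODEL, and that a GPU is
# available for the @spaces.GPU-decorated pipeline. Prompts, filenames,
# and the extra parameters below are illustrative only.
if __name__ == "__main__":
    inference = DiffusionInference()

    # Text-to-image with a fixed seed for reproducibility; extra kwargs
    # such as num_inference_steps are forwarded to the diffusers pipeline.
    image = inference.text_to_image(
        prompt="a watercolor painting of a lighthouse at dusk",
        negative_prompt="blurry, low quality",
        seed=42,
        num_inference_steps=30,
    )
    image.save("text2img_result.png")

    # Image-to-image, reusing the generated image as the input; the method
    # accepts a PIL.Image or a file path.
    transformed = inference.image_to_image(
        image=image,
        prompt="the same scene at night under a starry sky",
    )
    transformed.save("img2img_result.png")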