# generation_function.py
import spaces  # Import spaces first so ZeroGPU can patch CUDA initialization.
import torch
import numpy as np
import time
import imageio
import gc
import gradio as gr  # Needed for gr.Error raised in Inpainting mode below.
from PIL import Image

# Import the shared components from the supporting modules.
from model_loader import controlnet_pipe, inpaint_pipe, api
from preprocessor import Preprocessor
from style_utils import apply_style, style_list
from utils import randomize_seed_fn
from config import API_KEY

# Instantiate the preprocessor once at import time; its detector model is
# loaded lazily inside generate_interior_design().
preprocessor = Preprocessor()
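

# The three pipeline/preprocessor device checks below all share the same
# shape, so they are factored into this small helper. It is a refactor of the
# repeated inline blocks, not new behavior: move `obj` to `device` if it
# reports being elsewhere, and log (rather than raise) if the move fails.
def _ensure_on_device(obj, device: str, label: str) -> None:
    if hasattr(obj, "device") and obj.device.type != device:
        print(f"Moving {label} to {device}")
        try:
            obj.to(device)
        except Exception as e:
            print(f"Error moving {label} to {device}: {e}")
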
@spaces.GPU(duration=12)
@torch.inference_mode()
def generate_interior_design(
    image_np: np.ndarray,
    mask_np: np.ndarray | None,  # Mask input; may be None outside Inpainting mode.
    mode: str,  # Either "ControlNet" or "Inpainting".
    style_selection: str,
    prompt: str,
    a_prompt: str,
    n_prompt: str,
    num_images: int,  # Note: the pipelines below currently generate only one image.
    image_resolution: int,
    preprocess_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
    randomize_seed: bool,
):
    # Convert numpy arrays to PIL images.
    image = Image.fromarray(image_np.astype(np.uint8)).convert("RGB")
    # Keep only the first channel of the mask and convert it to grayscale.
    mask = (
        Image.fromarray(mask_np[:, :, 0].astype(np.uint8), "L")
        if mask_np is not None
        else None
    )
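    # Assuming inpaint_pipe is a standard diffusers inpainting pipeline, the
    # mask convention is: white (255) pixels are repainted, black (0) pixels
    # are preserved from the input image.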
    # Apply seed randomization. torch.cuda.manual_seed() returns None, so an
    # explicit torch.Generator is built for the active device instead of
    # relying on its return value.
    current_seed = randomize_seed_fn(seed, randomize_seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=device).manual_seed(current_seed)
    print(f"Using processed seed: {current_seed}")
    # Construct the full prompt (shared by both pipelines).
    style_prompt_text = apply_style(style_selection)
    prompt_parts = []
    if prompt:
        prompt_parts.append(f"Photo from Pinterest of {prompt}")
    else:
        prompt_parts.append("Photo from Pinterest of interior space")
    if style_prompt_text:
        prompt_parts.append(style_prompt_text)
    if a_prompt:
        prompt_parts.append(a_prompt)
    full_prompt = ", ".join(filter(None, prompt_parts))
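    # Example: prompt="scandinavian living room" with empty style and a_prompt
    # yields full_prompt == "Photo from Pinterest of scandinavian living room".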
    negative_prompt = str(n_prompt)
    print(f"Using prompt: {full_prompt}")
    print(f"Using negative prompt: {negative_prompt}")
    print(f"Selected mode: {mode}")
    initial_result = None
    if mode == "ControlNet":
        # Lazily load the NormalBae detector the first time this mode runs.
        if preprocessor.name != "NormalBae":
            preprocessor.load("NormalBae")
        # Ensure the detector model is on the active device.
        _ensure_on_device(preprocessor.model, device, "preprocessor model")
        control_image = preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )
        _ensure_on_device(controlnet_pipe, device, "controlnet pipe")
        # @torch.inference_mode() on the function already disables autograd.
        initial_result = controlnet_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            num_inference_steps=num_steps,
            generator=generator,
            image=control_image,
        ).images[0]
    elif mode == "Inpainting":
        if mask is None:
            # Surface the problem in the Gradio UI instead of failing deeper in the pipeline.
            raise gr.Error("Inpainting mode requires a mask.")
        _ensure_on_device(inpaint_pipe, device, "inpaint pipe")
        initial_result = inpaint_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            image=image,  # Original image to repaint.
            mask_image=mask,  # Grayscale mask selecting the region to repaint.
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
            generator=generator,
        ).images[0]
    # Save the result locally and optionally upload it to the Hub.
    try:
        if initial_result is not None:  # Only save/upload if a result was generated.
            timestamp = int(time.time())
            results_path = f"{timestamp}_output.jpg"
            # imageio.imsave() is deprecated; imwrite() with an array is the
            # supported spelling.
            imageio.imwrite(results_path, np.asarray(initial_result))
            if API_KEY:
                print(f"Uploading result image to broyang/interior-ai-outputs/{results_path}")
                try:
                    api.upload_file(
                        path_or_fileobj=results_path,
                        path_in_repo=results_path,
                        repo_id="broyang/interior-ai-outputs",
                        repo_type="dataset",
                        token=API_KEY,
                        run_as_future=True,  # Upload in the background without blocking.
                    )
                except Exception as e:
                    print(f"Error uploading file to Hugging Face Hub: {e}")
            else:
                print("Hugging Face API key not found, skipping file upload.")
    except Exception as e:
        print(f"Error saving or uploading image: {e}")
    # Clean up CUDA memory and report this run's usage.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
        # max_memory_allocated() reports the peak allocation, not the current one.
        print(f"Peak CUDA memory allocated: {torch.cuda.max_memory_allocated(device='cuda') / 1e9:.2f} GB")
    if initial_result is None:
        # No pipeline ran (e.g. an unrecognized mode was selected); return a
        # blank white image so the output component still receives a value.
        print("No result generated for the selected mode.")
        return Image.new("RGB", (512, 512), (255, 255, 255))
    return initial_result
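

if __name__ == "__main__":
    # Minimal local smoke test: a sketch, not part of the Space itself. It
    # assumes the sibling modules (model_loader, preprocessor, style_utils,
    # utils, config) import cleanly outside the Gradio app and that the model
    # weights are available; @spaces.GPU should be a no-op off ZeroGPU
    # hardware. The flat gray image merely stands in for a room photo, and
    # style_list is assumed to hold dicts with a "name" key, as in common
    # diffusers demo templates; adjust if yours differs.
    dummy_room = np.full((512, 512, 3), 128, dtype=np.uint8)
    result = generate_interior_design(
        image_np=dummy_room,
        mask_np=None,
        mode="ControlNet",
        style_selection=style_list[0]["name"] if style_list else "",
        prompt="scandinavian living room with large windows",
        a_prompt="best quality, extremely detailed",
        n_prompt="lowres, worst quality, blurry",
        num_images=1,
        image_resolution=512,
        preprocess_resolution=512,
        num_steps=20,
        guidance_scale=9.0,
        seed=0,
        randomize_seed=True,
    )
    result.save("smoke_test_output.jpg")
    print("Smoke test image written to smoke_test_output.jpg")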