Spaces:
Build error
Build error
| # generation_function.py | |
| import spaces # Import spaces first! | |
| import torch | |
| import numpy as np | |
| import time | |
| import imageio | |
| import gc | |
| from PIL import Image | |
| # Import necessary components from your new modules | |
| from model_loader import controlnet_pipe, inpaint_pipe, api | |
| from preprocessor import Preprocessor | |
| from style_utils import apply_style, style_list | |
| from utils import randomize_seed_fn | |
| from config import API_KEY | |
# Module-level Preprocessor instance shared by every call to
# generate_interior_design, which loads the "NormalBae" model into it on
# demand (when its current `name` differs).
# NOTE(review): an earlier comment here said the preprocessor "is loaded in
# app.py" — confirm whether app.py also configures this same instance.
preprocessor = Preprocessor()
| def generate_interior_design( | |
| image_np: np.ndarray, | |
| mask_np: np.ndarray | None, # Add mask input (can be None) | |
| mode: str, # Add mode selection input | |
| style_selection: str, | |
| prompt: str, | |
| a_prompt: str, | |
| n_prompt: str, | |
| num_images: int, # Note: Pipeline currently only generates 1 image | |
| image_resolution: int, | |
| preprocess_resolution: int, | |
| num_steps: int, | |
| guidance_scale: float, | |
| seed: int, | |
| randomize_seed: bool, | |
| ): | |
| # Convert numpy arrays to PIL Images | |
| image = Image.fromarray(image_np.astype(np.uint8)).convert("RGB") | |
| mask = Image.fromarray(mask_np[:, :, 0].astype(np.uint8), 'L') if mask_np is not None else None # Convert mask to grayscale PIL Image | |
| # Apply seed randomization | |
| current_seed = randomize_seed_fn(seed, randomize_seed) | |
| generator = torch.cuda.manual_seed(current_seed) if torch.cuda.is_available() else torch.manual_seed(current_seed) | |
| print(f"Using processed seed: {current_seed}") | |
| # Construct the full prompt (can be used by both pipelines) | |
| style_prompt_text = apply_style(style_selection) | |
| prompt_parts = [] | |
| if prompt: | |
| prompt_parts.append(f"Photo from Pinterest of {prompt}") | |
| else: | |
| prompt_parts.append("Photo from Pinterest of interior space") | |
| if style_prompt_text: | |
| prompt_parts.append(style_prompt_text) | |
| if a_prompt: | |
| prompt_parts.append(a_prompt) | |
| full_prompt = ", ".join(filter(None, prompt_parts)) | |
| negative_prompt = str(n_prompt) | |
| print(f"Using prompt: {full_prompt}") | |
| print(f"Using negative prompt: {negative_prompt}") | |
| print(f"Selected mode: {mode}") | |
| initial_result = None | |
| if mode == "ControlNet": | |
| if preprocessor.name != "NormalBae": | |
| preprocessor.load("NormalBae") | |
| # Ensure preprocessor is on the correct device | |
| preprocessor_device = "cuda" if torch.cuda.is_available() else "cpu" | |
| if hasattr(preprocessor.model, 'device') and preprocessor.model.device.type != preprocessor_device: | |
| print(f"Moving preprocessor model to {preprocessor_device}") | |
| try: | |
| preprocessor.model.to(preprocessor_device) | |
| except Exception as e: | |
| print(f"Error moving preprocessor model to {preprocessor_device}: {e}") | |
| pass | |
| control_image = preprocessor( | |
| image=image, | |
| image_resolution=image_resolution, | |
| detect_resolution=preprocess_resolution, | |
| ) | |
| controlnet_pipe_device = "cuda" if torch.cuda.is_available() else "cpu" | |
| if hasattr(controlnet_pipe, 'device') and controlnet_pipe.device.type != controlnet_pipe_device: | |
| print(f"Moving controlnet pipe to {controlnet_pipe_device}") | |
| try: | |
| controlnet_pipe.to(controlnet_pipe_device) | |
| except Exception as e: | |
| print(f"Error moving controlnet pipe to {controlnet_pipe_device}: {e}") | |
| with torch.no_grad(): | |
| initial_result = controlnet_pipe( | |
| prompt=full_prompt, | |
| negative_prompt=negative_prompt, | |
| guidance_scale=guidance_scale, | |
| num_images_per_prompt=1, | |
| num_inference_steps=num_steps, | |
| generator=generator, | |
| image=control_image, | |
| ).images[0] | |
| elif mode == "Inpainting": | |
| if mask is None: | |
| raise gr.Error("Inpainting mode requires a mask.") # Provide user feedback | |
| inpaint_pipe_device = "cuda" if torch.cuda.is_available() else "cpu" | |
| if hasattr(inpaint_pipe, 'device') and inpaint_pipe.device.type != inpaint_pipe_device: | |
| print(f"Moving inpaint pipe to {inpaint_pipe_device}") | |
| try: | |
| inpaint_pipe.to(inpaint_pipe_device) | |
| except Exception as e: | |
| print(f"Error moving inpaint pipe to {inpaint_pipe_device}: {e}") | |
| with torch.no_grad(): | |
| initial_result = inpaint_pipe( | |
| prompt=full_prompt, | |
| negative_prompt=negative_prompt, | |
| image=image, # Pass original image | |
| mask_image=mask, # Pass the mask image | |
| guidance_scale=guidance_scale, | |
| num_inference_steps=num_steps, | |
| generator=generator, | |
| ).images[0] | |
| # Save and upload results (optional) - This part can remain the same | |
| try: | |
| if initial_result: # Only save/upload if a result was generated | |
| timestamp = int(time.time()) | |
| results_path = f"{timestamp}_output.jpg" | |
| imageio.imsave(results_path, initial_result) | |
| if API_KEY: | |
| print(f"Uploading result image to broyang/interior-ai-outputs/{results_path}") | |
| try: | |
| api.upload_file( | |
| path_or_fileobj=results_path, | |
| path_in_repo=results_path, | |
| repo_id="broyang/interior-ai-outputs", | |
| repo_type="dataset", | |
| token=API_KEY, | |
| run_as_future=True, | |
| ) | |
| except Exception as e: | |
| print(f"Error uploading file to Hugging Face Hub: {e}") | |
| else: | |
| print("Hugging Face API Key not found, skipping file upload.") | |
| except Exception as e: | |
| print(f"Error saving or uploading image: {e}") | |
| # Clean up CUDA memory | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| if initial_result: | |
| print(f"CUDA memory allocated after generation: {torch.cuda.max_memory_allocated(device='cuda') / 1e9:.2f} GB") | |
| else: | |
| print(f"CUDA memory allocated: {torch.cuda.max_memory_allocated(device='cuda') / 1e9:.2f} GB") | |
| if initial_result is None: | |
| # Return a blank image or an error message if no result was generated | |
| # This might happen if an unimplemented mode was selected | |
| print("No result generated for the selected mode.") | |
| return Image.new('RGB', (512, 512), (255, 255, 255)) # Return a blank white image | |
| return initial_result | |