"""Gradio app that removes objects from room photos.

A YOLOv8 segmentation model lists which known room items appear in the
uploaded picture, and a FLUX.1-Kontext image-editing pipeline is prompted to
erase either everything or only the items the user ticks.
"""

import os

import gradio as gr
import numpy as np
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from huggingface_hub import HfApi, login
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
from ultralytics import YOLO

# ───── HuggingFace token setup ─────
# The token is read from the HF_TOKEN environment variable / secret; do not
# hard-code it in this file. Without a token, login is skipped and only
# anonymous access is attempted.
token = os.getenv("HF_TOKEN")

if token:
    try:
        login(token=token)
        api = HfApi(token=token)
        user = api.whoami()
        print("✅ HuggingFace token valid. Logged in as:", user["name"])
    except (HfHubHTTPError, ValueError) as e:
        # NOTE(review): some huggingface_hub releases appear to reject a bad
        # token with ValueError instead of an HTTP error - confirm against the
        # pinned version.
        print("❌ Invalid or expired HuggingFace token.")
        print("Error:", e)
        # Keep `api` defined even when authentication fails.
        api = HfApi()
else:
    print("⚠️ No HuggingFace token found. Using public access only.")
    api = HfApi()

# ───── Load FLUX-Kontext Pipeline ─────
device = "cuda" if torch.cuda.is_available() else "cpu"

QUANTIZED_MODEL_ID = "HighCWu/FLUX.1-Kontext-dev-bnb-hqq-4bit"
OFFICIAL_MODEL_ID = "black-forest-labs/FLUX.1-Kontext-dev"

# Try the quantized version first, fall back to the official weights if needed.
try:
    pipe = FluxKontextPipeline.from_pretrained(
        QUANTIZED_MODEL_ID,
        torch_dtype=torch.bfloat16,
    )
    print("✅ FLUX-Kontext quantized pipeline loaded")
except Exception as e:  # start-up boundary: any load failure triggers the fallback
    print(f"⚠️ Quantized model failed: {e}")
    print("📥 Falling back to official model...")
    pipe = FluxKontextPipeline.from_pretrained(
        OFFICIAL_MODEL_ID,
        torch_dtype=torch.bfloat16,
    )
    print("✅ FLUX-Kontext official pipeline loaded")

pipe.to(device)

# ───── Load YOLOv8 Segmentation Model ─────
yolo = YOLO("yolov8s-seg.pt")  # can be replaced with your own weights
print("✅ YOLOv8s-seg loaded")

# ───── Valid Items to Detect ─────
valid_room_items = [
    'sofa', 'couch', 'chair', 'table', 'bed', 'lamp', 'tv', 'cabinet',
    'desk', 'stool', 'curtain', 'carpet', 'painting', 'mirror', 'shelf',
    'pillow', 'cushion', 'potted plant', 'plant', 'vase', 'rug', 'bowl',
    'book',
]

# Detector class names that should count as one of the entries above.
# COCO-pretrained weights report tables as "dining table".
_LABEL_ALIASES = {"dining table": "table"}

# Both models are fed a square image of this side length.
IMAGE_SIDE = 512
WORKING_SIZE = (IMAGE_SIDE, IMAGE_SIDE)

# Sampling settings passed to the FLUX-Kontext pipeline.
GUIDANCE_SCALE = 7.0
NUM_INFERENCE_STEPS = 35


# ───── Helper Functions ─────
def _normalize_label(label):
    """Map a raw detector class name to an entry of ``valid_room_items``.

    Returns the matching item name, or ``None`` when the label is not a
    known room item.
    """
    name = label.lower()
    name = _LABEL_ALIASES.get(name, name)
    if name in valid_room_items:
        return name
    # Accept a simple plural by dropping exactly one trailing "s"
    # (rstrip("s") would also mangle names such as "glass").
    if name.endswith("s") and name[:-1] in valid_room_items:
        return name[:-1]
    return None


def run_yolo(image):
    """Return the sorted, de-duplicated room items detected in ``image``.

    Args:
        image: PIL image of the room.

    Returns:
        Alphabetically sorted list of names taken from ``valid_room_items``.
    """
    resized = image.resize(WORKING_SIZE)
    # Pass the PIL image itself: Ultralytics documents numpy inputs as BGR,
    # so an RGB array made from a PIL image would have its channels swapped.
    result = yolo.predict(
        source=resized,
        imgsz=IMAGE_SIDE,
        device=device,
        save=False,
    )[0]

    if result.boxes is None:  # defensive: no detection container at all
        return []

    found = set()
    for class_id in result.boxes.cls:
        item = _normalize_label(result.names[int(class_id)])
        if item is not None:
            found.add(item)
    # Sorted so the checkbox order is stable between runs.
    return sorted(found)


def generate_prompt(mode, selections=None):
    """Build the edit instruction sent to the FLUX-Kontext pipeline.

    Args:
        mode: ``"all"`` to clear the room, ``"select"`` to remove only the
            names listed in ``selections``.
        selections: item names to remove; only used in ``"select"`` mode.

    Returns:
        The prompt string.

    Raises:
        ValueError: if ``mode`` is neither ``"all"`` nor ``"select"``.
    """
    if mode == 'all':
        return "Remove everything from the room including all furniture like bed, sofa, couch, table, lamp, chairs, curtains, decor items, etc., except the walls and carpet."
    if mode != 'select':
        raise ValueError("Invalid mode")

    if not selections:
        return "Don't remove anything"
    if len(selections) == 1:
        return f"Remove {selections[0]} from the room"
    return "Remove " + ", ".join(selections[:-1]) + " and " + selections[-1] + " from the room"


def process(image, mode, selections):
    """Run the removal edit and return the result at the upload's size.

    Args:
        image: PIL image from the upload widget, or ``None``.
        mode: ``"all"`` or ``"select"``.
        selections: item names ticked by the user (``"select"`` mode).

    Returns:
        The edited PIL image resized to the original dimensions; the
        untouched input when ``"select"`` mode has nothing selected; or
        ``None`` when there is no image or the pipeline call fails.

    Raises:
        ValueError: if ``mode`` is not a supported mode.
    """
    if image is None:
        return None

    if mode == 'select' and not selections:
        # Nothing to remove: skip the expensive diffusion pass and the
        # down/up-scaling that would only degrade the picture.
        return image

    prompt = generate_prompt(mode, selections)
    print("🧠 Auto-prompt:", prompt)

    original_size = image.size
    try:
        edited = pipe(
            prompt=prompt,
            image=image.resize(WORKING_SIZE),
            guidance_scale=GUIDANCE_SCALE,
            num_inference_steps=NUM_INFERENCE_STEPS,
        ).images[0]
        return edited.resize(original_size)
    except Exception as e:  # UI boundary: report the failure, show no output
        print(f"Error during processing: {e}")
        return None


def interface_main(image, mode):
    """Refresh the item checkboxes for the current image and mode.

    Returns:
        A pair of ``gr.update`` objects for the checkbox group and the run
        button. The checkbox group is shown, filled with the detected items
        and cleared of any previous selection, only in ``"select"`` mode
        with an image present; the run button is always kept visible.
    """
    if image is not None and mode == 'select':
        detected = run_yolo(image)
        return gr.update(visible=True, choices=detected, value=[]), gr.update(visible=True)
    return gr.update(visible=False), gr.update(visible=True)


# ───── Gradio UI ─────
with gr.Blocks() as demo:
    gr.Markdown("## 🛋️ Room Object Remover using FLUX + YOLOv8")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Room Image", type="pil")
            mode = gr.Radio(choices=["all", "select"], label="Choose Mode", value="all")
            detected_items = gr.CheckboxGroup(choices=[], label="Select objects to remove", visible=False)
            run_button = gr.Button("Run")
        with gr.Column():
            output_image = gr.Image(label="Output Image", interactive=False)

    mode.change(fn=interface_main, inputs=[image_input, mode], outputs=[detected_items, run_button])
    # Re-run detection when the picture changes too; otherwise choosing
    # "select" before uploading would leave the checkbox list hidden.
    image_input.change(fn=interface_main, inputs=[image_input, mode], outputs=[detected_items, run_button])
    run_button.click(fn=process, inputs=[image_input, mode, detected_items], outputs=[output_image])

if __name__ == "__main__":
    demo.launch()