File size: 5,181 Bytes
ee78287
 
 
 
8f9eca2
131839e
ee78287
 
 
 
 
8f9eca2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee78287
 
 
 
131839e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee78287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e7c333
 
 
ee78287
 
 
 
0e7c333
 
 
 
 
 
 
 
 
 
 
 
 
ee78287
 
0e7c333
 
 
ee78287
 
 
 
 
 
 
 
 
0e7c333
ee78287
 
 
 
 
 
0e7c333
ee78287
 
0e7c333
ee78287
 
 
0e7c333
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
from PIL import Image
import numpy as np
import torch
import os
from diffusers import FluxKontextPipeline  # Fixed: Import the correct pipeline
from diffusers.utils import load_image
from huggingface_hub import HfApi, login
from huggingface_hub.utils import HfHubHTTPError
from ultralytics import YOLO

# HuggingFace token setup
import os

# Option 1: Set your token directly (not recommended for production)
# token = "hf_your_token_here"  # Replace with your actual token

# Option 2: Read the token from the HF_TOKEN environment variable (recommended)
token = os.getenv("HF_TOKEN")

# Option 3: Skip login if no token (will work for public models)
if token:
    try:
        login(token=token)
        api = HfApi(token=token)
        user = api.whoami()
        print("βœ… HuggingFace token valid. Logged in as:", user["name"])
    except (HfHubHTTPError, OSError, ValueError) as e:
        # NOTE(review): depending on the huggingface_hub version, an invalid token
        # surfaces as HfHubHTTPError, as a plain requests.HTTPError re-raised by
        # whoami() (an OSError subclass), or as ValueError from login(); network
        # failures are OSErrors as well. None of them should abort app start-up.
        print("❌ Invalid or expired HuggingFace token.")
        print("Error:", e)
        # Fall back to anonymous access so `api` is bound on every path.
        api = HfApi()
else:
    print("⚠️ No HuggingFace token found. Using public access only.")
    api = HfApi()

# ───── Load FLUX-Kontext Pipeline ─────
device = "cuda" if torch.cuda.is_available() else "cpu"

_KONTEXT_QUANTIZED_REPO = "HighCWu/FLUX.1-Kontext-dev-bnb-hqq-4bit"
_KONTEXT_OFFICIAL_REPO = "black-forest-labs/FLUX.1-Kontext-dev"


def _load_kontext_pipeline():
    """Load FLUX-Kontext, preferring the 4-bit quantized repo.

    Any exception while loading the quantized checkpoint is reported and the
    official checkpoint is loaded instead; a failure there propagates.
    """
    try:
        loaded = FluxKontextPipeline.from_pretrained(
            _KONTEXT_QUANTIZED_REPO,
            torch_dtype=torch.bfloat16,
        )
    except Exception as err:
        print(f"⚠️ Quantized model failed: {err}")
        print("πŸ“₯ Falling back to official model...")
        loaded = FluxKontextPipeline.from_pretrained(
            _KONTEXT_OFFICIAL_REPO,
            torch_dtype=torch.bfloat16,
        )
        print("βœ… FLUX-Kontext official pipeline loaded")
    else:
        print("βœ… FLUX-Kontext quantized pipeline loaded")
    return loaded


pipe = _load_kontext_pipeline()
pipe.to(device)

# ───── Load YOLOv8 Segmentation Model ─────
yolo = YOLO('yolov8s-seg.pt')  # can be replaced with your own weights
print("βœ… YOLOv8s-seg loaded")

# ───── Valid Items to Detect ─────
# Lower-case, singular labels that run_yolo() keeps from the detector's output;
# every other detected class is discarded.
valid_room_items = [
    "sofa", "couch", "chair", "table", "bed",
    "lamp", "tv", "cabinet", "desk", "stool",
    "curtain", "carpet", "painting", "mirror", "shelf",
    "pillow", "cushion", "potted plant", "plant", "vase",
    "rug", "bowl", "book",
]

# ───── Helper Functions ─────
def run_yolo(image):
    """Detect removable room items in *image* with the YOLOv8 segmentation model.

    Args:
        image: PIL image uploaded by the user.

    Returns:
        A sorted list of unique labels, each a member of ``valid_room_items``.
    """
    # Pass the PIL image itself rather than np.array(image): Ultralytics assumes
    # numpy sources are BGR, so an RGB array had its red/blue channels swapped.
    img_resized = image.convert("RGB").resize((512, 512))
    result = yolo.predict(source=img_resized, imgsz=512, device=device, save=False)[0]

    allowed = set(valid_room_items)
    found = set()
    for cls_id in result.boxes.cls:
        # Crude singularisation: strip trailing 's' characters from the class name.
        label = result.names[int(cls_id)].lower().rstrip('s')
        if label not in allowed:
            # Multi-word classes such as COCO's "dining table" fall back to their
            # last word ("table").
            label = label.rsplit(' ', 1)[-1]
        if label in allowed:
            found.add(label)
    # Sort last: the previous list(set(...)) threw away the ordering, so the
    # checkbox choices came out in arbitrary order.
    return sorted(found)

def generate_prompt(mode, selections=None):
    """Build the natural-language edit instruction for the FLUX-Kontext pipeline.

    Args:
        mode: ``'all'`` to clear the whole room, or ``'select'`` to remove only
            the items named in *selections*.
        selections: Item names to remove in ``'select'`` mode; ignored for
            ``'all'``.

    Returns:
        The prompt string.

    Raises:
        ValueError: If *mode* is neither ``'all'`` nor ``'select'``.
    """
    if mode == 'all':
        return ("Remove everything from the room including all furniture like "
                "bed, sofa, couch, table, lamp, chairs, curtains, decor items, "
                "etc., except the walls and carpet.")
    if mode != 'select':
        raise ValueError("Invalid mode")

    if not selections:
        return "Don't remove anything"
    if len(selections) == 1:
        return f"Remove {selections[0]} from the room"

    # "a, b and c": comma-join every item but the last, then attach it with "and".
    *leading, final = selections
    return "Remove " + " and ".join([", ".join(leading), final]) + " from the room"

def process(image, mode, selections):
    """Remove objects from *image* with FLUX-Kontext using an auto-generated prompt.

    Args:
        image: Uploaded PIL image, or None when nothing has been uploaded.
        mode: ``'all'`` or ``'select'`` (see ``generate_prompt``).
        selections: Labels ticked in the checkbox group (used in ``'select'`` mode).

    Returns:
        The edited image resized back to the input's size; the input image
        unchanged when ``'select'`` mode has nothing ticked; or None when there
        is no input or the pipeline raises.

    Raises:
        ValueError: From ``generate_prompt`` if *mode* is not recognised.
    """
    import traceback

    if image is None:
        return None

    # Nothing selected: running 35 diffusion steps on "Don't remove anything"
    # would still regenerate (and alter) the photo, so hand the input back as-is.
    if mode == 'select' and not selections:
        return image

    original_size = image.size
    # NOTE(review): squashing to 512x512 ignores the aspect ratio; the output is
    # stretched back to original_size below.
    img_resized = image.resize((512, 512))
    prompt = generate_prompt(mode, selections)
    print("🧠 Auto-prompt:", prompt)

    try:
        result = pipe(
            prompt=prompt,
            image=img_resized,
            guidance_scale=7.0,
            num_inference_steps=35,
        ).images[0]
        return result.resize(original_size)
    except Exception as e:
        # Keep the UI responsive (output is cleared) but keep the stack trace in
        # the logs instead of only the message.
        print(f"Error during processing: {e}")
        traceback.print_exc()
        return None

def interface_main(image, mode):
    """Compute visibility updates for the object checklist and the Run button.

    Returns a ``(detected_items_update, run_button_update)`` pair. With an image
    present in ``'select'`` mode, the checklist is shown, filled with YOLO
    detections and left unticked. Otherwise it is hidden. The Run button is
    always made visible.
    """
    wants_checklist = image is not None and mode == 'select'
    if wants_checklist:
        checklist_update = gr.update(visible=True, choices=run_yolo(image), value=[])
    else:
        checklist_update = gr.update(visible=False)
    return checklist_update, gr.update(visible=True)

# ───── Gradio UI ─────
with gr.Blocks() as demo:
    gr.Markdown("## πŸ›‹οΈ Room Object Remover using FLUX + YOLOv8")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Room Image", type="pil")
            mode = gr.Radio(choices=["all", "select"], label="Choose Mode", value="all")
            # Hidden until "select" mode is chosen; choices come from run_yolo().
            detected_items = gr.CheckboxGroup(choices=[], label="Select objects to remove", visible=False)
            run_button = gr.Button("Run")

        with gr.Column():
            output_image = gr.Image(label="Output Image", interactive=False)

    # Refresh the checklist when either the mode OR the image changes. With only
    # the mode listener, uploading a new photo while already in "select" mode
    # left the checklist showing the previous photo's detections.
    mode.change(fn=interface_main, inputs=[image_input, mode], outputs=[detected_items, run_button])
    image_input.change(fn=interface_main, inputs=[image_input, mode], outputs=[detected_items, run_button])
    run_button.click(fn=process, inputs=[image_input, mode, detected_items], outputs=[output_image])

if __name__ == "__main__":
    demo.launch()