# Flux / app.py
# Latest commit by nimraaaajhduksy: "Update app.py" (28ec59d, verified)
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from ultralytics import YOLO
from PIL import Image
import numpy as np
import os
# ───── Model Loading ─────
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is only selected together with CUDA; the CPU path stays in float32.
dtype = torch.float16 if device == "cuda" else torch.float32

# Required for gated access model: the token must belong to an account that
# has been granted access to FLUX.1-Kontext-dev.
auth_token = os.getenv("HUGGINGFACE_TOKEN")

# Load FLUX (Kontext)
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",
    token=auth_token,  # `use_auth_token` is the deprecated name of this argument
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
)

if device == "cuda":
    # Model CPU offload moves each sub-model to the GPU on demand, so the
    # pipeline must NOT be moved to CUDA first (that would load everything
    # onto the GPU and defeat the memory saving).
    pipe.enable_model_cpu_offload()
else:
    # No accelerator: offloading has nothing to offload to, just stay on CPU.
    pipe.to(device)

# Load YOLOv8
yolo = YOLO("yolov8x.pt")
# ───── Helper Function ─────
def detect_and_remove(image, mode, user_items=None):
    """Detect objects in a room photo and ask FLUX to remove some of them.

    Args:
        image: PIL image of the room, or ``None`` when nothing was uploaded.
        mode: ``"Remove all furniture"`` removes every detected furniture
            item; ``"Select items to remove"`` removes only the entries of
            ``user_items`` that were actually detected.
        user_items: Labels chosen by the user (only used in select mode).

    Returns:
        The edited PIL image resized back to the input size, or the input
        ``image`` unchanged when there is nothing to remove (unknown mode,
        empty selection, or no matching detection). ``None`` if ``image``
        is ``None``.
    """
    if image is None:
        return None

    remove_all = mode == "Remove all furniture"
    remove_selected = mode == "Select items to remove"

    # Cheap exits first: do not run detection when the answer cannot matter.
    if not (remove_all or remove_selected):
        return image
    if remove_selected and not user_items:
        return image

    original_size = image.size
    # Normalise to 3-channel RGB so images with alpha or palettes are accepted.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    img_resized = rgb_image.resize((512, 512))

    # Pass the PIL image itself: ultralytics treats NumPy input as BGR, so an
    # RGB array built from a PIL image would be read with swapped channels.
    res = yolo.predict(source=img_resized, imgsz=512, device=device, save=False)[0]
    labels = {res.names[int(c)] for c in res.boxes.cls}

    if remove_all:
        # Furniture keywords, matched per word so multi-word class names
        # (e.g. the COCO class "dining table") are recognised as well.
        furniture_keywords = {
            'sofa', 'couch', 'chair', 'table', 'bed',
            'lamp', 'tv', 'cabinet', 'desk', 'stool',
        }
        targets = sorted(
            label for label in labels
            if furniture_keywords.intersection(label.lower().split())
        )
        if not targets:
            return image
        prompt = (
            "Remove all of the following furniture items from the room: "
            + ", ".join(targets)
        )
    else:
        # Keep the user's ordering, but drop anything that was not detected.
        targets = [item for item in user_items if item in labels]
        if not targets:
            return image
        if len(targets) == 1:
            prompt = f"Remove {targets[0]} from the room"
        else:
            prompt = "Remove " + ", ".join(targets[:-1]) + " and " + targets[-1] + " from the room"

    # Run the image-editing pipeline on the resized image.
    result = pipe(
        prompt=prompt,
        image=img_resized,
        guidance_scale=7.0,
        num_inference_steps=35,
    ).images[0]

    # Restore the caller's original resolution.
    return result.resize(original_size)
# ───── Gradio Interface ─────
def interface(image, mode, selections):
    """Click handler for the Run button: delegate to ``detect_and_remove``."""
    edited = detect_and_remove(image, mode, user_items=selections)
    return edited
def get_detected_labels(image):
    """Return the sorted, de-duplicated YOLO class names found in ``image``.

    Args:
        image: PIL image to analyse, or ``None``.

    Returns:
        A sorted list of unique label strings; empty when ``image`` is
        ``None`` or nothing is detected.
    """
    if image is None:
        return []

    # Normalise to 3-channel RGB so images with alpha or palettes are accepted.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    img_resized = rgb_image.resize((512, 512))

    # Pass the PIL image itself: ultralytics treats NumPy input as BGR, so an
    # RGB array built from a PIL image would be read with swapped channels.
    res = yolo.predict(source=img_resized, imgsz=512, device=device, save=False)[0]
    return sorted({res.names[int(c)] for c in res.boxes.cls})
with gr.Blocks() as demo:
    gr.Markdown("## 🪄 Furniture Remover using YOLO + FLUX")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Room Image")
            mode_input = gr.Radio(
                ["Remove all furniture", "Select items to remove"],
                label="Mode",
                value="Remove all furniture",
            )
            detected_box = gr.CheckboxGroup(
                choices=[],
                label="Select items to remove (detected)",
                visible=False,
            )
            submit = gr.Button("Run")
        with gr.Column():
            result_image = gr.Image(label="Output Image")

    def update_labels(image, mode):
        """Refresh the checkbox list after the image or the mode changes.

        Returns a pair matching ``outputs=[detected_box, result_image]``:
        an update for the checkbox group and ``None`` to clear the previous
        result image.
        """
        select_mode = mode == "Select items to remove"
        if image is None or not select_mode:
            # The checkbox list is hidden here, so skip detection entirely
            # and clear any previous selection.
            return gr.update(visible=False, choices=[], value=[]), None

        labels = get_detected_labels(image)
        # Clear any previous selection so it cannot refer to labels that are
        # absent from the new choices.
        return gr.update(visible=True, choices=labels, value=[]), None

    mode_input.change(
        fn=update_labels,
        inputs=[image_input, mode_input],
        outputs=[detected_box, result_image],
    )
    image_input.change(
        fn=update_labels,
        inputs=[image_input, mode_input],
        outputs=[detected_box, result_image],
    )
    submit.click(
        fn=interface,
        inputs=[image_input, mode_input, detected_box],
        outputs=result_image,
    )

demo.launch()