import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import FluxKontextPipeline
from huggingface_hub import HfApi, login
from huggingface_hub.utils import HfHubHTTPError
from ultralytics import YOLO
# ───── HuggingFace Token Setup ─────
# Option 1: Hard-code your token (not recommended for production)
# token = "hf_your_token_here"  # replace with your actual token
# Option 2: Read it from an environment variable (recommended)
token = os.getenv("HF_TOKEN")  # name must match your Space secret
# Option 3: Skip login when no token is set; public models still work

if token:
    try:
        login(token=token)
        api = HfApi(token=token)
        user = api.whoami()
        print("✅ HuggingFace token valid. Logged in as:", user["name"])
    except HfHubHTTPError as e:
        print("❌ Invalid or expired HuggingFace token.")
        print("Error:", e)
else:
    print("⚠️ No HuggingFace token found. Using public access only.")
    api = HfApi()
# ───── Load FLUX-Kontext Pipeline ─────
device = "cuda" if torch.cuda.is_available() else "cpu"

# Try the quantized variant first; fall back to the official weights if it fails
try:
    pipe = FluxKontextPipeline.from_pretrained(
        "HighCWu/FLUX.1-Kontext-dev-bnb-hqq-4bit",
        torch_dtype=torch.bfloat16,
    )
    print("✅ FLUX-Kontext quantized pipeline loaded")
except Exception as e:
    print(f"⚠️ Quantized model failed: {e}")
    print("📥 Falling back to official model...")
    pipe = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev",
        torch_dtype=torch.bfloat16,
    )
    print("✅ FLUX-Kontext official pipeline loaded")

pipe.to(device)
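
# Optional (assumption: enough host RAM): on GPUs with limited VRAM, model CPU
# offload can be used instead of the pipe.to(device) call above:
# pipe.enable_model_cpu_offload()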
# ───── Load YOLOv8 Segmentation Model ─────
yolo = YOLO('yolov8s-seg.pt')  # can be replaced with your own weights
print("✅ YOLOv8s-seg loaded")
# ───── Valid Items to Detect ─────
valid_room_items = [
    'sofa', 'couch', 'chair', 'table', 'bed', 'lamp', 'tv', 'cabinet', 'desk',
    'stool', 'curtain', 'carpet', 'painting', 'mirror', 'shelf',
    'pillow', 'cushion', 'potted plant', 'plant', 'vase', 'rug', 'bowl', 'book'
]
# ───── Helper Functions ─────
def run_yolo(image):
    """Detect objects in the room image and return the labels present in valid_room_items."""
    img_resized = image.resize((512, 512))
    result = yolo.predict(source=np.array(img_resized), imgsz=512, device=device, save=False)[0]
    all_labels = [result.names[int(c)] for c in result.boxes.cls]
    # Naive singularization: rstrip('s') strips every trailing 's', which is fine
    # for the labels above but would mangle words like 'glass'
    detected_labels = sorted(set(label.lower().rstrip('s') for label in all_labels))
    filtered = [item for item in detected_labels if item in valid_room_items]
    return filtered  # already deduplicated and sorted
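
# Illustrative usage (hypothetical path): run_yolo(Image.open("room.jpg"))
# might return e.g. ['chair', 'sofa', 'tv'], depending on the scene.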
def generate_prompt(mode, selections=None):
    if mode == 'all':
        return "Remove everything from the room including all furniture like bed, sofa, couch, table, lamp, chairs, curtains, decor items, etc., except the walls and carpet."
    elif mode == 'select':
        if not selections:
            return "Don't remove anything"
        elif len(selections) == 1:
            return f"Remove {selections[0]} from the room"
        else:
            return "Remove " + ", ".join(selections[:-1]) + " and " + selections[-1] + " from the room"
    else:
        raise ValueError("Invalid mode")
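
# Example prompts from generate_prompt:
#   generate_prompt('select', ['sofa'])         -> "Remove sofa from the room"
#   generate_prompt('select', ['sofa', 'lamp']) -> "Remove sofa and lamp from the room"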
def process(image, mode, selections):
    """Run the FLUX removal pipeline at 512x512, then upscale back to the input size."""
    if image is None:
        return None
    original_size = image.size
    img_resized = image.resize((512, 512))
    prompt = generate_prompt(mode, selections)
    print("🧠 Auto-prompt:", prompt)
    try:
        result = pipe(
            prompt=prompt,
            image=img_resized,
            guidance_scale=7.0,
            num_inference_steps=35,
        ).images[0]
        return result.resize(original_size)
    except Exception as e:
        print(f"Error during processing: {e}")
        return None
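
# Optional (assumption): for reproducible edits, a seeded generator can be
# passed to the pipe(...) call above, e.g.
#   generator=torch.Generator(device=device).manual_seed(42)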
def interface_main(image, mode):
    """Populate and show the object checkboxes only when 'select' mode is active."""
    if image is None:
        return gr.update(visible=False), gr.update(visible=True)
    if mode == 'select':
        detected = run_yolo(image)
        return gr.update(visible=True, choices=detected, value=[]), gr.update(visible=True)
    else:
        return gr.update(visible=False), gr.update(visible=True)
# ───── Gradio UI ─────
with gr.Blocks() as demo:
    gr.Markdown("## 🛋️ Room Object Remover using FLUX + YOLOv8")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Room Image", type="pil")
            mode = gr.Radio(choices=["all", "select"], label="Choose Mode", value="all")
            detected_items = gr.CheckboxGroup(choices=[], label="Select objects to remove", visible=False)
            run_button = gr.Button("Run")
        with gr.Column():
            output_image = gr.Image(label="Output Image", interactive=False)

    mode.change(fn=interface_main, inputs=[image_input, mode], outputs=[detected_items, run_button])
    run_button.click(fn=process, inputs=[image_input, mode, detected_items], outputs=[output_image])
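
    # Note (assumption): detections are only refreshed on mode change; to also
    # refresh them when a new image is uploaded in 'select' mode, the same
    # handler could be wired to the image component:
    # image_input.change(fn=interface_main, inputs=[image_input, mode], outputs=[detected_items, run_button])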
if __name__ == "__main__":
    demo.launch()