# Hugging Face Space source file (Space status: Running).
# File size: 5,896 bytes - commit fdd10b2
import gradio as gr
from gradio_client import Client, handle_file
import re
import time
import os
from dotenv import load_dotenv
# Read variables from a .env file into the process environment.
load_dotenv()

# Hugging Face access token taken from the environment; None when it is unset.
hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")

# Shared client for the remote OOTDiffusion Space, authenticated with hf_token.
client = Client("levihsu/OOTDiffusion", hf_token=hf_token)
def generate_outfit(model_image, garment_image, n_samples=1, n_steps=20, image_scale=2, seed=-1):
    """Request a try-on image from the remote Space via ``client.predict``.

    Args:
        model_image: File path of the person image, or None if not provided.
        garment_image: File path of the garment image, or None if not provided.
        n_samples: Number of samples, forwarded to the API unchanged.
        n_steps: Number of steps, forwarded to the API unchanged.
        image_scale: Scale value, forwarded to the API unchanged.
        seed: Seed forwarded to the API as an int. ``None`` (for example an
            emptied number field) is treated as -1.

    Returns:
        A ``(image, message)`` pair. On success ``image`` is the value taken
        from the API response and ``message`` is None; on failure ``image``
        is None and ``message`` describes the problem.
    """
    if model_image is None or garment_image is None:
        return None, "Please upload both model and garment images"

    # A cleared seed field arrives as None; fall back to the "random" value.
    seed = -1 if seed is None else int(seed)

    max_retries = 3
    # Longest pause we are willing to block the request for before retrying.
    # Quota waits can be up to an hour; sleeping that long would only hang
    # the caller, so longer waits are reported immediately instead.
    max_blocking_wait_seconds = 30

    for attempt in range(max_retries):
        try:
            # Use the client to predict
            result = client.predict(
                vton_img=handle_file(model_image),
                garm_img=handle_file(garment_image),
                n_samples=n_samples,
                n_steps=n_steps,
                image_scale=image_scale,
                seed=seed,
                api_name="/process_hd"
            )
        except Exception as e:
            error_msg = str(e)
            if "exceeded your GPU quota" not in error_msg:
                return None, f"Error: {error_msg}"

            wait_time_match = re.search(r'retry in (\d+:\d+:\d+)', error_msg)
            wait_time = wait_time_match.group(1) if wait_time_match else "60:00"  # Default to 1 hour
            # Convert "[H:]M:S" to seconds: the rightmost field is seconds,
            # each field to the left is worth 60x more.
            wait_seconds = sum(
                int(part) * 60 ** power
                for power, part in enumerate(reversed(wait_time.split(':')))
            )

            can_retry = attempt < max_retries - 1
            if can_retry and wait_seconds <= max_blocking_wait_seconds:
                time.sleep(wait_seconds)  # Wait, then actually retry
                continue
            return None, f"GPU quota exceeded. Please wait {wait_time} before trying again."

        # If result is a list, use its first item
        if isinstance(result, list):
            if not result:
                return None, "API returned unexpected format"
            result = result[0]

        # If result is a dictionary, the image is expected under 'image'
        if isinstance(result, dict):
            if 'image' in result:
                return result['image'], None
            return None, "API returned unexpected format"

        return result, None

    # Defensive fallback: every loop iteration above returns or retries, and
    # the final attempt always returns, so this is not expected to be reached.
    return None, "Error: request failed after retries"
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""
## Outfit Diffusion - Try On Virtual Outfits
⚠️ **Note**: This demo uses free GPU quota which is limited. To avoid errors:
- Use lower values for Steps (10-15) and Scale (1-2)
- Wait between attempts if you get a quota error
- Sign up for a Hugging Face account for more quota
""")

    with gr.Row():
        # Left column: the two input images, each with clickable examples.
        with gr.Column():
            model_image = gr.Image(
                label="Upload Model Image (person wearing clothes)",
                type="filepath",
                height=300
            )
            model_examples = [
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/ba5ba7978e7302e8ab5eb733cc7221394c4e6faf/model_5.png",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/40dade4a04a827c0fdf63c6c70b42ef26480f391/01861_00.jpg",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/3c4639c5fab3cdcd3239609dca5afee7b0677286/model_6.png",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/0089171df270f4532eec3d80a8f36cc8218c6840/01008_00.jpg"
            ]
            gr.Examples(examples=model_examples, inputs=model_image)

            garment_image = gr.Image(
                label="Upload Garment Image (clothing item)",
                type="filepath",
                height=300
            )
            garment_examples = [
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/180d4e2a1139071a8685a5edee7ab24bcf1639f5/03244_00.jpg",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/584dda2c5ee1d8271a6cd06225c07db89c79ca03/04825_00.jpg",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/a51938ec99f13e548d365a9ca6d794b6fe7462af/049949_1.jpg",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/2d64241101189251ce415df84dc9205cda9a36ca/03032_00.jpg",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/44aee6b576cae51eeb979311306375b56b7e0d8b/02305_00.jpg",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/578dfa869dedb649e91eccbe566fc76435bb6bbe/049920_1.jpg"
            ]
            gr.Examples(examples=garment_examples, inputs=garment_image)

        # Right column: the generated image and any error message.
        with gr.Column():
            output_image = gr.Image(label="Generated Output")
            error_text = gr.Markdown()  # Add error display

    # Generation settings.
    # NOTE(review): the nesting of this row was reconstructed from a source
    # whose indentation was lost - confirm it matches the intended layout.
    with gr.Row():
        with gr.Column():
            n_samples = gr.Slider(
                label="Number of Samples",
                minimum=1,
                maximum=5,
                step=1,
                value=1
            )
            n_steps = gr.Slider(
                label="Steps (lower = faster, try 10-15)",
                minimum=1,
                maximum=50,
                step=1,
                value=10  # Reduced default
            )
            image_scale = gr.Slider(
                label="Scale (lower = faster, try 1-2)",
                minimum=1,
                maximum=5,
                step=1,
                value=1  # Reduced default
            )
            seed = gr.Number(
                label="Random Seed (-1 for random)",
                value=-1,
                # Seeds are whole numbers: round the field and deliver an int
                # so a float or fractional value is never sent to the API.
                precision=0
            )

    generate_button = gr.Button("Generate Outfit")

    # Set up the action for the button: generate_outfit returns
    # (image, message), mapped onto the output image and the error text.
    generate_button.click(
        fn=generate_outfit,
        inputs=[model_image, garment_image, n_samples, n_steps, image_scale, seed],
        outputs=[output_image, error_text]
    )

# Launch the app
demo.launch()
|