# Hugging Face Space — status: Running
import gradio as gr | |
from PIL import Image | |
from gradio_app.inference import run_inference | |
from gradio_app.components import ( | |
CONTENT_DESCRIPTION, CONTENT_OUTTRO, | |
CONTENT_IN_1, CONTENT_IN_2, | |
CONTENT_OUT_1, CONTENT_OUT_2, | |
list_reference_files, list_mapping_files, | |
list_classifier_files, list_edgeface_files | |
) | |
from glob import glob | |
import os | |
def create_image_io_row():
    """Build the row that pairs the image uploader with the results panel.

    Must be called inside an active Gradio Blocks context.

    Returns:
        tuple: ``(image_input, output)`` — a ``gr.Image`` that hands uploads
        to callbacks as PIL images, and a ``gr.HTML`` panel used to render
        the inference results.
    """
    row_classes = ["image-io-row"]
    with gr.Row(elem_classes=row_classes):
        upload = gr.Image(type="pil", label="Upload Image")
        results_panel = gr.HTML(
            label="Inference Results",
            elem_classes=["results-container"],
        )
    return upload, results_panel
def create_model_settings_row():
    """Build the row holding the model-file pickers and the inference knobs.

    The left column offers one dropdown per model artifact. Each dropdown's
    choices are a ``"Select a file"`` placeholder followed by whatever the
    matching ``list_*_files()`` helper returns, and it starts on a fixed
    default path. The right column holds the detection-algorithm,
    accelerator, resolution and similarity-threshold controls.

    Must be called inside an active Gradio Blocks context.

    Returns:
        tuple: ``(ref_dict, index_map, classifier_model, edgeface_model,
        algorithm, accelerator, resolution, similarity_threshold)`` — the
        created Gradio components, in that order.
    """
    placeholder = "Select a file"

    def file_dropdown(list_files, label, default_path):
        # Placeholder entry first, then the files discovered by the helper.
        return gr.Dropdown(
            choices=[placeholder] + list_files(),
            label=label,
            value=default_path,
        )

    with gr.Row():
        with gr.Column():
            with gr.Group(elem_classes=["section-group"]):
                gr.Markdown("### Model Files", elem_classes=["section-title"])
                ref_dict = file_dropdown(
                    list_reference_files,
                    "Reference Dict JSON",
                    "data/reference_data/reference_image_data.json",
                )
                index_map = file_dropdown(
                    list_mapping_files,
                    "Index to Class Mapping JSON",
                    "ckpts/index_to_class_mapping.json",
                )
                classifier_model = file_dropdown(
                    list_classifier_files,
                    "Classifier Model (.pth)",
                    "ckpts/SlimFace_efficientnet_b3_full_model.pth",
                )
                edgeface_model = file_dropdown(
                    list_edgeface_files,
                    "EdgeFace Model (.pt)",
                    "ckpts/idiap/edgeface_s_gamma_05.pt",
                )
        with gr.Column():
            with gr.Group(elem_classes=["section-group"]):
                gr.Markdown("### Advanced Settings", elem_classes=["section-title"])
                algorithm = gr.Dropdown(
                    label="Detection Algorithm",
                    choices=["yolo", "mtcnn", "retinaface"],
                    value="yolo",
                )
                accelerator = gr.Dropdown(
                    label="Accelerator",
                    choices=["auto", "cpu", "cuda", "mps"],
                    value="auto",
                )
                resolution = gr.Slider(
                    label="Image Resolution",
                    minimum=128,
                    maximum=512,
                    step=32,
                    value=300,
                )
                similarity_threshold = gr.Slider(
                    label="Similarity Threshold",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.3,
                )
    return (
        ref_dict,
        index_map,
        classifier_model,
        edgeface_model,
        algorithm,
        accelerator,
        resolution,
        similarity_threshold,
    )
def _read_text_file(path):
    """Return the full contents of the text file at *path*, decoded as UTF-8.

    A ``with`` block closes the handle as soon as the read finishes. A bare
    ``open(path).read()`` would leave closing to the garbage collector. The
    explicit encoding keeps the result independent of the platform's default
    locale encoding.
    """
    with open(path, encoding="utf-8") as handle:
        return handle.read()


# Load the app's local stylesheet once at import time. The path is resolved
# relative to the current working directory, so the app is expected to be
# launched from the repository root.
CSS = _read_text_file("apps/gradio_app/static/styles.css")
# Directory scanned for the example-image gallery (relative to the CWD).
_EXAMPLES_DIR = "apps/assets/examples"
# Lower-case file extensions treated as example images.
_EXAMPLE_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".webp"})


def _list_example_images(directory=_EXAMPLES_DIR):
    """Return the sorted paths of image files directly inside *directory*.

    Extensions are matched case-insensitively against
    ``_EXAMPLE_IMAGE_EXTENSIONS``. Names starting with ``"."`` are skipped,
    mirroring ``glob``'s ``*``, which does not match hidden files. Results
    are sorted so the gallery order does not depend on the filesystem.

    Returns ``[]`` when *directory* does not exist.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if not name.startswith(".")
        and os.path.splitext(name)[1].lower() in _EXAMPLE_IMAGE_EXTENSIONS
        and os.path.isfile(os.path.join(directory, name))
    )


def _load_example_image(path):
    """Open the example image at *path* with its pixel data loaded eagerly.

    ``Image.open`` is lazy and keeps the file open until the pixels are read.
    Calling ``load()`` reads them immediately, which lets Pillow close the
    file it opened (for single-frame images).
    """
    image = Image.open(path)
    image.load()
    return image


def create_interface():
    """Create the Gradio interface for SlimFace.

    The layout has four parts:

    - intro text;
    - the upload/results row;
    - the model-file and settings row;
    - an example-image gallery (one "Use <name>" button per image, which
      loads it into the uploader) followed by a "Run Inference" button wired
      to ``run_inference``, then the closing content.

    Returns:
        gr.Blocks: the assembled (not yet launched) app.
    """
    with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
        gr.Markdown("# SlimFace Demonstration")
        gr.Markdown(CONTENT_DESCRIPTION)
        gr.Markdown(CONTENT_IN_1)
        gr.HTML(CONTENT_IN_2)

        image_input, output = create_image_io_row()
        (
            ref_dict,
            index_map,
            classifier_model,
            edgeface_model,
            algorithm,
            accelerator,
            resolution,
            similarity_threshold,
        ) = create_model_settings_row()

        # Add example image gallery as a row of columns
        with gr.Group():
            gr.Markdown("### Example Images")
            example_images = _list_example_images()
            if example_images:
                with gr.Row(elem_classes=["example-row"]):
                    for img_path in example_images:
                        img_name = os.path.basename(img_path)
                        with gr.Column(min_width=120):
                            gr.Image(
                                value=img_path,
                                label=img_name,
                                type="filepath",
                                height=100,
                                elem_classes=["example-image"],
                            )
                            # Bind the path as a default argument. A plain
                            # closure would late-bind img_path, and every
                            # button would load the last image.
                            gr.Button(f"Use {img_name}").click(
                                fn=lambda path=img_path: _load_example_image(path),
                                outputs=image_input,
                            )
            else:
                gr.Markdown(f"No example images found in {_EXAMPLES_DIR}/")

        with gr.Row():
            submit_btn = gr.Button(
                "Run Inference", variant="primary", elem_classes=["centered-button"]
            )
        submit_btn.click(
            fn=run_inference,
            inputs=[
                image_input,
                ref_dict,
                index_map,
                classifier_model,
                edgeface_model,
                algorithm,
                accelerator,
                resolution,
                similarity_threshold,
            ],
            outputs=output,
        )

        gr.Markdown(CONTENT_OUTTRO)
        gr.HTML(CONTENT_OUT_1)
        gr.Markdown(CONTENT_OUT_2)
    return demo
def main():
    """Build the SlimFace Gradio app and start serving it (blocks until stopped)."""
    create_interface().launch()


# Only launch the server when this file is executed directly, not on import.
if __name__ == "__main__":
    main()