import contextlib
import os
import tempfile
import json
import numpy as np
import gradio as gr
import cv2
from drexel_metadata.gen_metadata import gen_metadata
from PIL import Image
import urllib.request
from huggingface_hub import hf_hub_download

# Download model if not already cached locally
hf_hub_download(
    repo_id="imageomics/Drexel-metadata-generator",
    filename="model_final.pth",
    local_dir="output/enhanced",
)

EXAMPLE_URLS = [
    'http://www.tubri.org/HDR/INHS/INHS_FISH_59422.jpg',
    'http://www.tubri.org/HDR/INHS/INHS_FISH_76560.jpg',
]


def _fetch_example(url):
    """Download ``url`` into the working directory and return the local file name.

    The download is skipped when the file already exists from a previous run.
    Data is first written to a ``.part`` file and then renamed, so an
    interrupted download never leaves a truncated image under the final name.
    """
    file_name = os.path.basename(url)
    if not os.path.exists(file_name):
        partial_name = file_name + ".part"
        urllib.request.urlretrieve(url, partial_name)
        os.replace(partial_name, file_name)
    return file_name


# According to the docs examples should be a nested list (one inner list per
# example, one entry per input component).
EXAMPLES = [[_fetch_example(example_url)] for example_url in EXAMPLE_URLS]


def create_temp_file_path(prefix, suffix):
    """Create an empty named temporary file and return its path.

    The file is created with ``delete=False`` so it outlives this call; the
    caller is responsible for removing it.
    """
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=suffix, delete=False) as tmpfile:
        return tmpfile.name


def _remove_quietly(path):
    """Delete ``path``, ignoring errors (e.g. the file is already gone)."""
    with contextlib.suppress(OSError):
        os.remove(path)


def run_inference(input_img):
    """Run the Drexel metadata generator on one uploaded image.

    Args:
        input_img: Image from the ``gr.Image`` input as a NumPy array of shape
            (height, width, channels), or None if the user submitted without
            uploading an image.

    Returns:
        Tuple ``(visfname, maskfname, json_metadata)``: paths to the
        visualization and mask PNGs written by ``gen_metadata``, and its
        result serialized as a JSON string.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        # Surface a readable message in the UI instead of a PIL traceback.
        raise gr.Error("Please upload an image before submitting.")

    # gen_metadata works on file paths, so the input is round-tripped through
    # a temporary PNG and the outputs are written to fresh temp paths.
    tmpfile = create_temp_file_path(prefix="input_", suffix=".png")
    visfname = create_temp_file_path(prefix="vis_", suffix=".png")
    maskfname = create_temp_file_path(prefix="mask_", suffix=".png")
    try:
        Image.fromarray(input_img).save(tmpfile)
        result = gen_metadata(tmpfile, device='cpu', maskfname=maskfname, visfname=visfname)
    except Exception:
        # On failure the output placeholders will never be served; drop them.
        _remove_quietly(visfname)
        _remove_quietly(maskfname)
        raise
    finally:
        # The input copy is no longer needed whether or not inference succeeded.
        _remove_quietly(tmpfile)

    # The output files are intentionally left in place: Gradio serves them
    # from these paths after this function returns.
    return visfname, maskfname, json.dumps(result)


def try_remove_preamble(readme_md):
    """Strip the Hugging Face preamble from README markdown.

    Everything before the first ``#`` is dropped. If the text contains no
    ``#``, it is returned unchanged.
    """
    idx = readme_md.find("#")
    if idx >= 0:
        return readme_md[idx:]
    return readme_md


def read_app_header_markdown():
    """Return README.md (minus its preamble) for use as the app description."""
    # Explicit encoding: the locale default is not UTF-8 on every host.
    with open('README.md', encoding='utf-8') as infile:
        return try_remove_preamble(infile.read())


dm_app = gr.Interface(
    description=read_app_header_markdown(),
    fn=run_inference,
    # Input shows markdown explaining the app and a single image upload panel
    inputs=[
        gr.Image(),
    ],
    # Output consists of a visualization image, a masked image, and JSON metadata
    outputs=[
        gr.Image(label='visualization'),
        gr.Image(label='mask'),
        gr.JSON(label="JSON metadata"),
    ],
    allow_flagging="never",  # Do not save user's results or prompt for users to save the results
    examples=EXAMPLES,
)
dm_app.launch(server_name="0.0.0.0")