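# NOTE: the lines below appear to be residue from a web file viewer, not program
# text; they are kept as comments so that the module parses.
# johnbradley's picture
# Cleanup readme and app header
# 4ab5f2f
# raw
# history blame
# No virus
# 2.53 kB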
import os
import tempfile
import json
import numpy as np
import gradio as gr
import cv2
from drexel_metadata.gen_metadata import gen_metadata
from PIL import Image
import urllib.request
from huggingface_hub import hf_hub_download
# Fetch the model weights into output/enhanced (skipped when already cached locally)
hf_hub_download(
    repo_id="imageomics/Drexel-metadata-generator",
    filename="model_final.pth",
    local_dir="output/enhanced",
)
# Remote sample images offered as clickable examples in the UI
EXAMPLE_URLS = [
    "http://www.tubri.org/HDR/INHS/INHS_FISH_59422.jpg",
    "http://www.tubri.org/HDR/INHS/INHS_FISH_76560.jpg",
]

# Download each sample into the working directory, named after the last
# path component of its URL, and register it as an example.
EXAMPLES = []
for example_url in EXAMPLE_URLS:
    file_name = os.path.basename(example_url)
    urllib.request.urlretrieve(example_url, filename=file_name)
    # The examples argument expects a nested list: one inner list per example
    EXAMPLES += [[file_name]]
def create_temp_file_path(prefix, suffix):
    """Create an empty temporary file and return its path.

    The file is left on disk (it is not deleted automatically), so the
    caller is responsible for removing it when it is no longer needed.

    prefix: text the generated file name starts with
    suffix: text the generated file name ends with (e.g. ".png")
    """
    descriptor, path = tempfile.mkstemp(prefix=prefix, suffix=suffix)
    # Only the path is needed; release the open descriptor right away
    os.close(descriptor)
    return path
def run_inference(input_img):
    """Run metadata generation on an uploaded image.

    input_img: image as a NumPy array, as supplied by the gr.Image input
        (per the Gradio docs this is expected to be (height, width, 3) -
        TODO confirm for the Gradio version in use).

    Returns a tuple (visfname, maskfname, json_metadata): the temporary
    file paths handed to gen_metadata as its visualization and mask
    targets, and the gen_metadata result serialized as a JSON string.

    The temporary input file is always removed, including when saving the
    image or running inference fails. On failure the temporary output
    files created so far are removed as well, and the exception propagates.
    """
    # gen_metadata takes a file path, so persist the array as a PNG first
    tmpfile = create_temp_file_path(prefix="input_", suffix=".png")
    created_outputs = []
    try:
        Image.fromarray(input_img).save(tmpfile)
        # Create temp filenames for output images
        visfname = create_temp_file_path(prefix="vis_", suffix=".png")
        created_outputs.append(visfname)
        maskfname = create_temp_file_path(prefix="mask_", suffix=".png")
        created_outputs.append(maskfname)
        # Run inference
        result = gen_metadata(tmpfile, device='cpu', maskfname=maskfname, visfname=visfname)
        json_metadata = json.dumps(result)
    except BaseException:
        # Nothing will be returned to the caller, so do not leave the
        # output placeholders behind
        for path in created_outputs:
            if os.path.exists(path):
                os.remove(path)
        raise
    finally:
        # Cleanup of the input copy happens on success and on failure
        os.remove(tmpfile)
    return visfname, maskfname, json_metadata
def try_remove_preamble(readme_md):
    """Try to remove the huggingface preamble from README markdown.

    If the text opens with a front-matter block (a first line of "---"
    closed by a later "---" line), that block is dropped first, so a "#"
    inside it (hex colour, YAML comment, URL fragment) is never mistaken
    for the first heading. The result then starts at the first "#" of the
    remaining text.

    readme_md: full README text.

    Returns the text from the first "#" after any front matter. When no
    "#" remains, the text with the front matter removed is returned; text
    without front matter and without "#" is returned unchanged. An
    unterminated front-matter block is not stripped.
    """
    body = readme_md
    lines = readme_md.splitlines(keepends=True)
    if lines and lines[0].strip() == "---":
        for pos, line in enumerate(lines[1:], start=1):
            if line.strip() == "---":
                # Keep only what follows the closing fence
                body = "".join(lines[pos + 1:])
                break
    idx = body.find("#")
    if idx >= 0:
        return body[idx:]
    return body
def read_app_header_markdown():
    """Return the text of README.md with its preamble stripped.

    README.md is opened relative to the current working directory and the
    contents are passed through try_remove_preamble.
    """
    # Decode explicitly as UTF-8 so non-ASCII characters in the README do not
    # depend on the platform's default encoding
    with open('README.md', encoding='utf-8') as infile:
        return try_remove_preamble(infile.read())
# Assemble the demo UI: one uploaded image in; a visualization image,
# a mask image and JSON metadata out.
dm_app = gr.Interface(
    fn=run_inference,
    # The only input is a single image upload panel
    inputs=[gr.Image()],
    # Output consists of a visualization image, a mask image, and JSON metadata
    outputs=[
        gr.Image(label='visualization'),
        gr.Image(label='mask'),
        gr.JSON(label="JSON metadata"),
    ],
    examples=EXAMPLES,
    # The README text (preamble removed) is shown as the markdown explaining the app
    description=read_app_header_markdown(),
    # Do not save users' results or prompt users to save them
    allow_flagging="never",
)
# Listen on all network interfaces rather than only on localhost
dm_app.launch(server_name="0.0.0.0")