# Commit c77178a by weiren119: "Feat: append json result rendering and new examples"
"""
Original Algorithm:
- https://github.com/GreenCUBIC/AudiogramDigitization
Source:
- huggingface app
- https://huggingface.co/spaces/aravinds1811/neural-style-transfer/blob/main/app.py
- https://huggingface.co/spaces/keras-io/ocr-for-captcha/blob/main/app.py
- https://huggingface.co/spaces/hugginglearners/image-style-transfer/blob/main/app.py
- https://tmabraham.github.io/blog/gradio_hf_spaces_tutorial
- huggingface push
- https://huggingface.co/welcome
"""
import os
import sys
from pathlib import Path
from PIL import Image
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import gradio as gr
sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
from digitizer.digitization import generate_partial_annotation, extract_thresholds
# Directory holding the bundled example audiograms. It is resolved relative to
# this file (like the ``src`` directory added to sys.path above) so the examples
# are found regardless of the working directory the app is launched from.
EXAMPLES_PATH = Path(__file__).resolve().parent / 'examples'
# NOTE(review): the three constants below are not referenced anywhere in this
# file; they look like leftovers from the keras-io OCR captcha app listed in the
# module docstring. Kept in case other code imports them.
max_length = 5
img_width = 200
img_height = 50
def load_image(path, zoom=1):
    """Read the image file at *path* and wrap it in an ``OffsetImage`` artist.

    Args:
        path: Location of the image file, passed unchanged to ``plt.imread``.
        zoom: Scale factor forwarded to ``OffsetImage``.

    Returns:
        A new ``OffsetImage`` suitable for placement with ``AnnotationBbox``.
    """
    pixels = plt.imread(path)
    return OffsetImage(pixels, zoom=zoom)
def plot_audiogram(digital_result):
    """Render digitized audiogram thresholds as a matplotlib figure.

    Args:
        digital_result: Output of ``extract_thresholds``; anything accepted by
            ``pd.DataFrame``. Rows are expected to carry ``conduction``
            ("air"/"bone"), ``ear`` ("left"/"right"), ``masking`` (bool),
            ``frequency`` and ``threshold`` fields -- inferred from the column
            accesses below; TODO confirm against the digitizer's output schema.

    Returns:
        A ``matplotlib.figure.Figure``. It has one symbol image per threshold
        and, for air conduction, a thin line joining the thresholds of each
        (ear, masking) group: red for the right ear, blue for the left. An
        empty result yields the labelled grid with no symbols.
    """
    # Build the Figure directly rather than via pyplot. pyplot keeps every
    # figure it creates alive until plt.close(), so a long-running server would
    # leak one figure per request. Using the Axes API also avoids pyplot's
    # global "current figure" state.
    from matplotlib.figure import Figure

    thresholds = pd.DataFrame(digital_result)
    # Resolve symbol images relative to this file, not the working directory.
    symbols_dir = Path(__file__).resolve().parent / "src" / "digitizer" / "assets" / "symbols"

    fig = Figure()
    ax = fig.add_subplot(111)

    # x axis: log-scaled frequency, ticks and label placed on top
    ax.set_xscale('log')
    ax.xaxis.tick_top()
    ax.xaxis.set_major_formatter(plt.FuncFormatter('{:.0f}'.format))
    ax.set_xlabel('Frequency (Hz)')
    ax.xaxis.set_label_position('top')
    ax.set_xlim(125, 16000)
    ax.set_xticks([250, 500, 1000, 2000, 4000, 8000, 16000])

    # y axis: inverted so larger thresholds are drawn lower
    ax.set_ylim(-20, 120)
    ax.invert_yaxis()
    ax.set_ylabel('Threshold (dB HL)')
    ax.grid()

    # Nothing detected: pd.DataFrame([]) has no columns, so the column filters
    # below would raise AttributeError. Return the empty grid instead.
    if thresholds.empty:
        return fig

    for conduction in ("air", "bone"):
        for masking in (True, False):
            for ear in ("left", "right"):
                symbol_name = f"{ear}_{conduction}_{'masked' if masking else 'unmasked'}"
                selection = thresholds[
                    (thresholds.conduction == conduction)
                    & (thresholds.ear == ear)
                    & (thresholds.masking == masking)
                ].sort_values("frequency")
                symbol_path = str(symbols_dir / f"{symbol_name}.png")

                # Plot the symbols (a fresh artist per point: artists cannot be shared)
                for point in selection.itertuples(index=False):
                    marker = AnnotationBbox(
                        load_image(symbol_path, zoom=0.1),
                        (point.frequency, point.threshold),
                        frameon=False,
                    )
                    ax.add_artist(marker)

                # Add joining line for air conduction thresholds
                if conduction == "air":
                    ax.plot(
                        selection.frequency,
                        selection.threshold,
                        color="red" if ear == "right" else "blue",
                        linewidth=0.5,
                    )
    return fig
# Function for Audiogram Digit Recognition
def audiogram_digit_recognition(img_path):
    """Digitize an audiogram image and return its re-rendered plot and raw data.

    Args:
        img_path: Path to the uploaded image (the Gradio input uses
            ``type='filepath'``). Presumably ``None`` when the user submits
            without uploading anything -- TODO confirm for the Gradio version.

    Returns:
        ``[figure, digital_result]``: the figure built by ``plot_audiogram``
        (for the ``gr.Plot`` output) and the raw ``extract_thresholds`` result
        (for the ``gr.JSON`` output), in the order of the Interface outputs.

    Raises:
        ValueError: If no image path was provided.
    """
    # Fail fast with a readable message instead of handing a missing path to
    # extract_thresholds.
    if not img_path:
        raise ValueError("No audiogram image provided; please upload an image first.")
    # gpu=False: presumably forces CPU inference in the digitizer -- confirm.
    digital_result = extract_thresholds(img_path, gpu=False)
    return [plot_audiogram(digital_result), digital_result]
# Output components, in the order audiogram_digit_recognition returns them:
# the re-rendered audiogram and the raw thresholds as JSON.
output = [gr.Plot(), gr.JSON()]
examples = [
    f'{EXAMPLES_PATH}/audiogram_example01.png',
    f'{EXAMPLES_PATH}/audiogram_example02.png',
]
iface = gr.Interface(
    fn=audiogram_digit_recognition,
    # gr.Image replaces the deprecated gr.inputs.Image namespace; 'filepath'
    # makes Gradio pass the function a file path rather than an array.
    inputs=gr.Image(type='filepath'),
    outputs=output,
    title=" AudiogramDigitization",
    description="facilitate the digitization of audiology reports based on pytorch",
    # Built from adjacent literals (no backslash continuations). Author links
    # use mailto: so they open a mail client, and every <a> tag is closed.
    article=(
        'Algorithm Authors: <a href="mailto:francoischarih@sce.carleton.ca">Francois Charih</a> '
        'and <a href="mailto:jrgreen@sce.carleton.ca">James R. Green</a>. '
        'Based on the AudiogramDigitization '
        '<a href="https://github.com/GreenCUBIC/AudiogramDigitization">github repo</a>'
    ),
    examples=examples,
    allow_flagging='never',
    cache_examples=False,
)
iface.launch(
    enable_queue=True, debug=False, inbrowser=False
)