FcgFormerApp / app.py
lycaoduong's picture
Upload 4 files
b4b217e verified
raw
history blame
No virus
2.68 kB
import numpy as np
import gradio as gr
import io
from FcgEngine.engine import PredictorCls
from FcgEngine.utils import plot_result, plot_spectra, plot_self_attention_map
import os
from huggingface_hub import login
# Deployment configuration is read from environment variables
# (API_MODEL: model location passed to PredictorCls, API_TOKEN: Hub token).
model_path = os.getenv("API_MODEL")
auth_token = os.getenv("API_TOKEN")
# Only authenticate when a token is actually configured: login(token=None)
# falls back to an interactive prompt, which cannot be answered on a headless
# server and would block or crash start-up.
if auth_token:
    login(token=auth_token)
# Single shared predictor, loaded once at start-up and reused by process().
engine = PredictorCls(model_path=model_path)
def process(array_binary, th, option):
    """Run the predictor on an uploaded spectrum and build the three result plots.

    Args:
        array_binary: Value of the ``gr.File`` input (``type="binary"``,
            ``file_count="multiple"``), i.e. a list of raw file contents; only
            the first file is used. A single bytes payload is also accepted.
        th: Confidence threshold from the slider, forwarded to
            ``engine.get_result``.
        option: Radio choice; ``'Positive only'`` requests ``pos_only=True``.

    Returns:
        ``(spectra_fg, prediction_fg, attention_fg)`` — figures in the same
        order as the ``ops`` output components.

    Raises:
        gr.Error: if no file was uploaded, so the UI shows a readable message
            instead of a TypeError/IndexError traceback.
    """
    # Nothing uploaded: the File component yields None (or an empty list).
    if not array_binary:
        raise gr.Error("Please upload a .npy spectra file.")
    # file_count="multiple" delivers a list; tolerate a single payload as well.
    payload = array_binary[0] if isinstance(array_binary, (list, tuple)) else array_binary
    # np.load keeps its default allow_pickle=False, so an untrusted upload
    # cannot trigger pickle deserialization.
    spectra = np.load(io.BytesIO(payload))
    # NOTE(review): the upload label expects shape [1, n]; the array is passed
    # to the engine unchanged — confirm PredictorCls requirements.
    outputs = engine(spectra)
    if option == 'Positive only':
        fcn_groups, probabilities = engine.get_result(outputs, th=th, pos_only=True)
    else:
        fcn_groups, probabilities = engine.get_result(outputs, th=th)
    prediction_fg = plot_result(fcn_groups, probabilities)
    spectra_fg = plot_spectra(spectra)
    # Presumably returns attention cached by the forward pass above — TODO
    # confirm; relies on no other request running between the two calls.
    att = engine.model.get_self_attention(layer_value=1)
    # Take the first element of the leading axis and sum over the next one
    # (assumed batch and heads respectively — TODO confirm).
    att = np.sum(att[0], axis=0)
    attention_fg = plot_self_attention_map(spectra, att)
    return spectra_fg, prediction_fg, attention_fg
def clear():
    """Reset the signal, prediction and attention plots to an empty state."""
    empty_plots = (None,) * 3
    return empty_plots
# Gradio UI: upload + options + buttons on the left, result tabs on the right.
with gr.Blocks(title="FcgFormer APP - Ohlabs") as block:
    with gr.Row():
        gr.Markdown("## Transformer based for Spectra Characterization")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                spectra_input = gr.File(file_count="multiple",
                                        label="Select File (*.npy)-signal with shape [1, n], n: signal length",
                                        type="binary", file_types=['.npy'])
            with gr.Row():
                with gr.Accordion("Advanced options", open=False):
                    threshold = gr.Slider(label="Control Confidence", minimum=0.05, maximum=0.95, value=0.5, step=0.05)
                    draw_option = gr.Radio(["Positive only", "All classes"], value="Positive only", label="Draw Prediction option")
            with gr.Row():
                with gr.Column():
                    run_button = gr.Button(value="Run", variant="primary")
                with gr.Column():
                    clear_button = gr.Button(value="Clear")
        with gr.Column():
            with gr.Tab("Prediction"):
                predicted_plot = gr.Plot(label="Result")
            with gr.Tab("Attention map"):
                attention_plot = gr.Plot(label="Attention Map")
            with gr.Tab("Input Spectra"):
                signal_plot = gr.Plot(label="Spectra Signal")
    # Input order matches process(array_binary, th, option); output order
    # matches its (spectra_fg, prediction_fg, attention_fg) return value.
    ips = [spectra_input, threshold, draw_option]
    ops = [signal_plot, predicted_plot, attention_plot]
    run_button.click(fn=process, inputs=ips, outputs=ops)
    clear_button.click(fn=clear, inputs=[], outputs=ops)

# Only start the (blocking) server when run as a script, not on import.
if __name__ == "__main__":
    block.launch(server_name='0.0.0.0', share=False)