lycaoduong committed
Commit b4b217e
1 Parent(s): 8f4b3c6

Upload 4 files

Files changed (4):
  1. FcgEngine/engine.py +95 -0
  2. FcgEngine/utils.py +38 -0
  3. app.py +72 -0
  4. pre-requirements.txt +4 -0
FcgEngine/engine.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ import json
+ import onnxruntime
+ import numpy as np
+ import cv2
+ from transformers import AutoModelForImageClassification, AutoConfig
+ import torch
+
+
+ class PredictorClsOnnx(object):
+     def __init__(self, model_dir, device='cpu', gpu_allocate=1):
+         model_bin = os.path.join(model_dir, 'model.bin')
+         with open(os.path.join(model_dir, 'configs.json')) as f:
+             configs = json.load(f)
+         self.ids = list(configs["ids"])
+         self.signal_size = configs["signal_len"]
+         if device == 'cuda':
+             providers = [
+                 ('CUDAExecutionProvider', {
+                     'device_id': 0,
+                     'arena_extend_strategy': 'kNextPowerOfTwo',
+                     'gpu_mem_limit': gpu_allocate * 1024 * 1024 * 1024,
+                     'cudnn_conv_algo_search': 'EXHAUSTIVE',
+                     'do_copy_in_default_stream': True,
+                 })
+             ]
+             session = onnxruntime.InferenceSession(model_bin, None, providers=providers)
+         else:
+             session = onnxruntime.InferenceSession(model_bin, providers=['CPUExecutionProvider'])
+         session.get_modelmeta()
+         self.input_name = session.get_inputs()[0].name
+         self.output_name = session.get_outputs()[0].name
+         self.model = session
+         self.device = device
+
+     def __call__(self, signal):
+         # Min-max normalize, resize to the model's expected length, then run
+         # a single forward pass through the ONNX session.
+         signal = np.expand_dims(signal, axis=0)
+         sig_max = np.max(signal, axis=1)  # renamed to avoid shadowing the max/min builtins
+         sig_min = np.min(signal, axis=1)
+         signal = (signal.astype(np.float32) - sig_min) / (sig_max - sig_min)
+         signal = cv2.resize(signal, (self.signal_size, 1), interpolation=cv2.INTER_CUBIC)
+         input_blob = np.expand_dims(signal, axis=0).astype(np.float32)
+         outputs = self.model.run([self.output_name], {self.input_name: input_blob})[0][0]
+         return outputs
+
+     def decode(self, result, th=0.5):
+         predict_cls = np.where(result >= th)[0]
+         description = "Signal contains"
+         for ids in predict_cls:
+             description += ' '
+             cls_name = self.ids[ids]
+             prob = result[ids]
+             description += '{} ({:.4f});'.format(cls_name.capitalize(), prob)
+         return description
+
+
+ class PredictorCls(object):
+     def __init__(self, model_path='lycaoduong/FcgFormer', device='cpu'):
+         self.model = AutoModelForImageClassification.from_pretrained(model_path, trust_remote_code=True)
+         self.model.to(device)
+         config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+         self.ids = list(config.cls_name.keys())
+         self.device = device
+
+     def __call__(self, spectra):
+         tensor = self.model.to_pt_tensor(spectra).to(self.device)
+         with torch.no_grad():
+             o = self.model(tensor)['logits']
+             outputs = torch.sigmoid(o).cpu().numpy()  # multi-label: independent sigmoid per class
+         return outputs[0]
+
+     def get_result(self, result, th=0.5, pos_only=False):
+         if pos_only:
+             predict_cls = np.where(result >= th)[0]
+         else:
+             result[result < th] = 0.0  # keep all classes, but zero out sub-threshold scores
+             predict_cls = np.where(result >= 0)[0]
+         fcn_groups = []
+         probabilities = []
+         for ids in predict_cls:
+             cls_name = self.ids[ids]
+             prob = result[ids]
+             fcn_groups.append(cls_name)
+             probabilities.append(prob)
+         return fcn_groups, probabilities
+
+     def decode(self, result, th=0.5):
+         predict_cls = np.where(result >= th)[0]
+         description = "Signal contains"
+         for ids in predict_cls:
+             description += ' '
+             cls_name = self.ids[ids]
+             prob = result[ids]
+             description += '{} ({:.4f});'.format(cls_name.capitalize(), prob)
+         return description
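
A minimal usage sketch for PredictorCls (not part of this commit): it assumes Hub access to the default lycaoduong/FcgFormer repo, and 'spectra.npy' is a hypothetical path holding a 1-D signal, mirroring what app.py feeds the engine.

    import numpy as np
    from FcgEngine.engine import PredictorCls

    engine = PredictorCls()                            # defaults: 'lycaoduong/FcgFormer' on CPU
    spectra = np.load('spectra.npy')                   # hypothetical input file
    scores = engine(spectra)                           # per-class sigmoid scores in [0, 1]
    groups, probs = engine.get_result(scores, th=0.5, pos_only=True)
    print(engine.decode(scores, th=0.5))               # e.g. "Signal contains Alkane (0.9712);"
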
FcgEngine/utils.py ADDED
@@ -0,0 +1,38 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+
+ def plot_self_attention_map(spectra, att_map, offset=400):
+     # Overlay the spectrum on the patch-to-patch attention heatmap; the
+     # first row/column (class token) is dropped before plotting.
+     if att_map is not None:
+         fig, ax1 = plt.subplots(figsize=(12, 12))
+         ax2 = plt.twinx().twiny()
+         ax2.set_xlim(offset, offset + len(spectra))
+         ax1.set_xlabel('Patch index', fontsize=18)
+         ax1.set_ylabel('Patch index', fontsize=18)
+         ax1.imshow(att_map[1:, 1:], cmap='inferno', interpolation='nearest', aspect="auto")
+         x = np.linspace(offset, offset + len(spectra), len(spectra))
+         ax2.set_xlabel('Wavelength', fontsize=18)
+         ax2.set_ylabel('Intensity (a.u.)', fontsize=18)
+         ax2.plot(x, spectra)
+         return fig
+
+
+ def plot_spectra(spectra):
+     fig = plt.figure()
+     plt.plot(spectra)
+     return fig
+
+
+ def plot_result(name, score):
+     # Horizontal bar chart of per-class confidences.
+     fig, ax = plt.subplots(figsize=(10, 10))
+     y_pos = np.arange(len(name))
+     ax.barh(y_pos, score, align='center')
+     ax.set_yticks(y_pos, labels=name)
+     for value, pos in zip(score, y_pos):
+         ax.text(value - 0.2, pos, '{:.2f}'.format(value), ha='right')
+     ax.invert_yaxis()  # labels read top-to-bottom
+     ax.set_xlabel('Confidence')
+     ax.set_title('Prediction')
+     return fig
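
A self-contained sketch of plot_result (not part of this commit) using synthetic values, so it runs without the model; the class names, scores, and output path below are illustrative only.

    from FcgEngine.utils import plot_result

    groups = ['alkane', 'ketone', 'alcohol']    # illustrative class names
    probs = [0.97, 0.81, 0.55]                  # illustrative confidences
    fig = plot_result(groups, probs)
    fig.savefig('prediction.png')               # placeholder output path
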
app.py ADDED
@@ -0,0 +1,72 @@
+ import numpy as np
+ import gradio as gr
+ import io
+ from FcgEngine.engine import PredictorCls
+ from FcgEngine.utils import plot_result, plot_spectra, plot_self_attention_map
+ import os
+ from huggingface_hub import login
+
+
+ model_path = os.getenv("API_MODEL")
+ auth_token = os.getenv("API_TOKEN")
+ login(token=auth_token)
+ engine = PredictorCls(model_path=model_path)
+
+
+ def process(array_binary, th, option):
+     # Gradio delivers the uploaded *.npy file(s) as raw bytes; decode the
+     # first file, run the classifier, and build the three output figures.
+     spectra = np.load(io.BytesIO(array_binary[0]))
+     outputs = engine(spectra)
+     if option == 'Positive only':
+         fcn_groups, probabilities = engine.get_result(outputs, th=th, pos_only=True)
+     else:
+         fcn_groups, probabilities = engine.get_result(outputs, th=th)
+     prediction_fg = plot_result(fcn_groups, probabilities)
+
+     spectra_fg = plot_spectra(spectra)
+
+     # Sum the per-head attention maps of one encoder layer for the overlay.
+     att = engine.model.get_self_attention(layer_value=1)
+     att = np.sum(att[0], axis=0)
+     attention_fg = plot_self_attention_map(spectra, att)
+
+     return spectra_fg, prediction_fg, attention_fg
+
+
+ def clear():
+     return None, None, None
+
+
+ block = gr.Blocks(title="FcgFormer APP - Ohlabs")
+ with block:
+     with gr.Row():
+         gr.Markdown("## Transformer-based Spectra Characterization")
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 spectra_input = gr.File(file_count="multiple",
+                                         label="Select file (*.npy): signal with shape [1, n], n = signal length",
+                                         type="binary", file_types=['.npy'])
+             with gr.Row():
+                 with gr.Accordion("Advanced options", open=False):
+                     threshold = gr.Slider(label="Confidence threshold", minimum=0.05, maximum=0.95, value=0.5, step=0.05)
+                     draw_option = gr.Radio(["Positive only", "All classes"], value="Positive only", label="Prediction display option")
+             with gr.Row():
+                 with gr.Column():
+                     run_button = gr.Button(value="Run", variant="primary")
+                 with gr.Column():
+                     clear_button = gr.Button(value="Clear")
+
+         with gr.Column():
+             with gr.Tab("Prediction"):
+                 predicted_plot = gr.Plot(label="Result")
+             with gr.Tab("Attention map"):
+                 attention_plot = gr.Plot(label="Attention Map")
+             with gr.Tab("Input Spectra"):
+                 signal_plot = gr.Plot(label="Spectra Signal")
+
+     ips = [spectra_input, threshold, draw_option]
+     ops = [signal_plot, predicted_plot, attention_plot]
+     run_button.click(fn=process, inputs=ips, outputs=ops)
+     clear_button.click(fn=clear, inputs=[], outputs=ops)
+
+ block.launch(server_name='0.0.0.0', share=False)
+
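
app.py reads the model repo id and Hub token from environment variables before launching. A hedged sketch of a local run (the variable names API_MODEL and API_TOKEN come from app.py; the values are placeholders, and the deployed API_MODEL may differ from the PredictorCls default):

    # Hypothetical local launch; never commit a real token.
    import os, subprocess

    os.environ["API_MODEL"] = "lycaoduong/FcgFormer"   # assumption: same repo as the class default
    os.environ["API_TOKEN"] = "hf_xxx"                 # placeholder token
    subprocess.run(["python", "app.py"])               # serves the Gradio UI on 0.0.0.0
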
pre-requirements.txt ADDED
@@ -0,0 +1,4 @@
+ onnxruntime-gpu==1.14.1
+ transformers==4.37.1
+ opencv-python==4.9.0.80
+ torch==2.1.2