Epsilon617 committed on
Commit: cb013a1
Parent: 479afcd

reformatting the outputs as dataframe

Prediction_Head/__pycache__/MTGGenre_head.cpython-310.pyc ADDED
Binary file (1.67 kB)
 
__pycache__/app.cpython-310.pyc CHANGED
Binary files a/__pycache__/app.cpython-310.pyc and b/__pycache__/app.cpython-310.pyc differ
 
app.py CHANGED
@@ -10,6 +10,9 @@ import logging
 
 import json
 import os
+import re
+
+import pandas as pd
 
 import importlib
 modeling_MERT = importlib.import_module("MERT-v1-95M.modeling_MERT")
@@ -36,8 +39,7 @@ inputs = [
 live_inputs = [
     gr.Audio(source="microphone",streaming=True, type="filepath"),
 ]
-# outputs = [gr.components.Textbox()]
-# outputs = [gr.components.Textbox(), transcription_df]
+
 title = "One Model for All Music Understanding Tasks"
 description = "An example of using the [MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M) model as backbone to conduct multiple music understanding tasks with the universal represenation."
 article = "The tasks include EMO, GS, MTGInstrument, MTGGenre, MTGTop50, MTGMood, NSynthI, NSynthP, VocalSetS, VocalSetT. \n\n More models can be referred at the [map organization page](https://huggingface.co/m-a-p)."
@@ -46,6 +48,17 @@ audio_examples = [
     # ["input/example-2.wav"],
 ]
 
+df_init = pd.DataFrame(columns=['Task', 'Top 1', 'Top 2', 'Top 3'])
+transcription_df = gr.DataFrame(value=df_init, label="Output Dataframe", row_count=(
+    0, "dynamic"), max_rows=30, wrap=True, overflow_row_behaviour='paginate')
+# outputs = [gr.components.Textbox()]
+outputs = [ transcription_df]
+
+df_init_live = pd.DataFrame(columns=['Task', 'Top 1', 'Top 2', 'Top 3'])
+transcription_df_live = gr.DataFrame(value=df_init_live, label="Output Dataframe", row_count=(
+    0, "dynamic"), max_rows=30, wrap=True, overflow_row_behaviour='paginate')
+outputs_live = [transcription_df_live]
+
 # Load the model and the corresponding preprocessor config
 # model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
 # processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public",trust_remote_code=True)
@@ -105,6 +118,7 @@ for task in TASKS:
 
 model.to(device)
 
+
 def model_infernce(inputs):
     waveform, sample_rate = torchaudio.load(inputs)
 
@@ -112,7 +126,7 @@ def model_infernce(inputs):
 
     # make sure the sample_rate aligned
     if resample_rate != sample_rate:
-        print(f'setting rate from {sample_rate} to {resample_rate}')
+        # print(f'setting rate from {sample_rate} to {resample_rate}')
         resampler = T.Resample(sample_rate, resample_rate)
         waveform = resampler(waveform)
 
@@ -129,13 +143,16 @@ def model_infernce(inputs):
     all_layer_hidden_states = all_layer_hidden_states.mean(dim=2)
 
     task_output_texts = ""
+    df = pd.DataFrame(columns=['Task', 'Top 1', 'Top 2', 'Top 3'])
+    df_objects = []
+
     for task in TASKS:
         num_class = len(ID2CLASS[task].keys())
         if MERT_BEST_LAYER_IDX[task] == 'all':
            logits = CLASSIFIERS[task](all_layer_hidden_states) # [1, 87]
         else:
            logits = CLASSIFIERS[task](all_layer_hidden_states[:, MERT_BEST_LAYER_IDX[task]])
-        print(f'task {task} logits:', logits.shape, 'num class:', num_class)
+        # print(f'task {task} logits:', logits.shape, 'num class:', num_class)
 
         sorted_idx = torch.argsort(logits, dim = -1, descending=True)[0] # batch =1
         sorted_prob,_ = torch.sort(nn.functional.softmax(logits[0], dim=-1), dim=-1, descending=True)
@@ -145,33 +162,40 @@ def model_infernce(inputs):
         top_n_show = 3 if num_class >= 3 else num_class
         task_output_texts = task_output_texts + f"TASK {task} output:\n" + "\n".join([str(ID2CLASS[task][str(sorted_idx[idx].item())])+f', probability: {sorted_prob[idx].item():.2%}' for idx in range(top_n_show)]) + '\n'
         task_output_texts = task_output_texts + '----------------------\n'
-    # output_texts = "\n".join([id2cls[str(idx.item())].replace('genre---', '') for idx in sorted_idx[:5]])
-    # logger.warning(all_layer_hidden_states.shape)
-
-    # return f"device {device}, sample reprensentation: {str(all_layer_hidden_states[12, 0, :10])}"
-    # return f"device: {device}\n" + output_texts
-    return task_output_texts
+
+        row_elements = [task]
+        for idx in range(top_n_show):
+            print(ID2CLASS[task])
+            # print('id', str(sorted_idx[idx].item()))
+            output_class_name = str(ID2CLASS[task][str(sorted_idx[idx].item())])
+            output_class_name = re.sub(r'^\w+---', '', output_class_name)
+            output_class_name = re.sub(r'^\w+\/\w+---', '', output_class_name)
+            # print('output name', output_class_name)
+            output_prob = f' {sorted_prob[idx].item():.2%}'
+            row_elements.append(output_class_name+output_prob)
+        # fill empty elment
+        for _ in range(4 - len(row_elements)):
+            row_elements.append(' ')
+        df_objects.append(row_elements)
+    df = pd.DataFrame(df_objects, columns=['Task', 'Top 1', 'Top 2', 'Top 3'])
+    return df
 
 def convert_audio(inputs, microphone):
     if (microphone is not None):
         inputs = microphone
+    df = model_infernce(inputs)
+    return df
 
-    text = model_infernce(inputs)
-
-    return text
-
 def live_convert_audio(microphone):
     if (microphone is not None):
         inputs = microphone
-
-    text = model_infernce(inputs)
-
-    return text
+    df = model_infernce(inputs)
+    return df
 
 audio_chunked = gr.Interface(
     fn=convert_audio,
     inputs=inputs,
-    outputs=[gr.components.Textbox()],
+    outputs=outputs,
     allow_flagging="never",
     title=title,
     description=description,
@@ -182,7 +206,7 @@ audio_chunked = gr.Interface(
 live_audio_chunked = gr.Interface(
     fn=live_convert_audio,
     inputs=live_inputs,
-    outputs=[gr.components.Textbox()],
+    outputs=outputs_live,
     allow_flagging="never",
     title=title,
     description=description,
@@ -204,5 +228,5 @@ with demo:
             "Live Streaming Music"
         ]
     )
-    demo.queue(concurrency_count=1, max_size=5)
+    # demo.queue(concurrency_count=1, max_size=5)
     demo.launch(show_api=False)
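
For reference, a minimal sketch of the output wiring this commit introduces: the inference function now builds one row per task (top-3 class names with probabilities, MTG-style prefixes like "genre---" stripped) and returns a pandas DataFrame, which Gradio renders through a gr.DataFrame component. This is not code from the repo: `fake_inference` and its labels are placeholders, and the gr.DataFrame/gr.Audio arguments simply mirror the Gradio 3.x API already used in app.py above (they may not exist in newer Gradio releases).

```python
# Hypothetical, minimal sketch of the DataFrame output path added in this commit.
# Only the Gradio wiring and the label clean-up mirror the real app.py;
# fake_inference() stands in for model_infernce() and the MERT classifiers.
import re

import gradio as gr
import pandas as pd

COLUMNS = ['Task', 'Top 1', 'Top 2', 'Top 3']


def strip_tag_prefix(name: str) -> str:
    # MTG-style tags look like "genre---rock" or "mood/theme---happy";
    # drop the prefix so only the readable class name is shown.
    name = re.sub(r'^\w+---', '', name)
    name = re.sub(r'^\w+/\w+---', '', name)
    return name


def fake_inference(audio_path):
    # Placeholder predictions; the real app fills these from the task classifiers.
    rows = [
        ['MTGGenre',
         strip_tag_prefix('genre---rock') + ' 42.00%',
         strip_tag_prefix('genre---pop') + ' 23.00%',
         strip_tag_prefix('genre---jazz') + ' 11.00%'],
        ['NSynthI', 'guitar 61.00%', 'keyboard 20.00%', 'flute 7.00%'],
    ]
    # One row per task, same columns as the gr.DataFrame component below.
    return pd.DataFrame(rows, columns=COLUMNS)


df_init = pd.DataFrame(columns=COLUMNS)
transcription_df = gr.DataFrame(value=df_init, label="Output Dataframe",
                                row_count=(0, "dynamic"), max_rows=30, wrap=True,
                                overflow_row_behaviour='paginate')

demo = gr.Interface(fn=fake_inference,
                    inputs=gr.Audio(source="upload", type="filepath"),
                    outputs=[transcription_df],
                    allow_flagging="never")

if __name__ == "__main__":
    demo.launch()
```

Returning the DataFrame directly from the callback (rather than the old concatenated text blob) lets Gradio paginate and wrap the per-task results, which is why both `convert_audio` and `live_convert_audio` now just pass the frame through.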