juancopi81 commited on
Commit
04107a3
1 Parent(s): fa0462c

Change inputs to list

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. utils.py +10 -4
app.py CHANGED
@@ -118,7 +118,7 @@ with demo:
118
  midi_audio = gr.Audio()
119
 
120
  btn.click(inference,
121
- inputs="final_audio.wav",
122
  outputs=[midi_file, midi_audio])
123
 
124
  gr.Markdown(article)
 
118
  midi_audio = gr.Audio()
119
 
120
  btn.click(inference,
121
+ inputs=["final_audio.wav"],
122
  outputs=[midi_file, midi_audio])
123
 
124
  gr.Markdown(article)
utils.py CHANGED
@@ -1,14 +1,13 @@
1
 
2
- import tempfile
3
  import collections
4
 
5
- import note_seq
6
-
7
  import pandas as pd
 
8
  import matplotlib.pyplot as plt
9
  from matplotlib.patches import Rectangle
10
  from PIL import Image
11
 
 
12
  class AudioIOReadError(BaseException): # pylint:disable=g-bad-exception-name
13
  pass
14
 
@@ -31,8 +30,15 @@ def dataframe_to_pianoroll_img(df):
31
  fig = plt.figure(figsize=(8, 5))
32
  ax = fig.add_subplot(111)
33
  ax.scatter(df.start_time, df.pitch, c="white")
 
 
 
34
  for _, row in df.iterrows():
35
- ax.add_patch(Rectangle((row["start_time"], row["pitch"]-0.4), row["duration"], 0.4, color="black"))
 
 
 
 
36
  plt.xlabel('time (sec.)', fontsize=18)
37
  plt.ylabel('pitch (MIDI)', fontsize=16)
38
  return fig
 
1
 
 
2
  import collections
3
 
 
 
4
  import pandas as pd
5
+ import plotly.express as px
6
  import matplotlib.pyplot as plt
7
  from matplotlib.patches import Rectangle
8
  from PIL import Image
9
 
10
+ import note_seq
11
  class AudioIOReadError(BaseException): # pylint:disable=g-bad-exception-name
12
  pass
13
 
 
30
  fig = plt.figure(figsize=(8, 5))
31
  ax = fig.add_subplot(111)
32
  ax.scatter(df.start_time, df.pitch, c="white")
33
+ n_colors = len(df.instrument.unique())
34
+ colorscale = px.colors.sample_colorscale("viridis", [n/(n_colors -1) for n in range(n_colors)])
35
+ colordict = {f:colorscale[i] for i, f in enumerate(df.instrument.unique())}
36
  for _, row in df.iterrows():
37
+ ax.add_patch(Rectangle((row["start_time"],
38
+ row["pitch"]-0.4),
39
+ row["duration"],
40
+ 0.4,
41
+ color=colordict[row["instrument"]]))
42
  plt.xlabel('time (sec.)', fontsize=18)
43
  plt.ylabel('pitch (MIDI)', fontsize=16)
44
  return fig