juancopi81 commited on
Commit
9f994e9
1 Parent(s): ae31dd7

Change plot function

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -2
  2. utils.py +3 -5
requirements.txt CHANGED
@@ -9,5 +9,4 @@ clu==0.0.7
9
  # pin Orbax to use Checkpointer
10
  orbax==0.0.2
11
  pytube
12
- pydub
13
- plotly==5.11.0
 
9
  # pin Orbax to use Checkpointer
10
  orbax==0.0.2
11
  pytube
12
+ pydub
 
utils.py CHANGED
@@ -2,7 +2,6 @@
2
  import collections
3
 
4
  import pandas as pd
5
- import plotly.express as px
6
  import matplotlib.pyplot as plt
7
  from matplotlib.patches import Rectangle
8
  from PIL import Image
@@ -30,10 +29,9 @@ def dataframe_to_pianoroll_img(df):
30
  fig = plt.figure(figsize=(8, 5))
31
  ax = fig.add_subplot(111)
32
  ax.scatter(df.start_time, df.pitch, c="white")
33
- n_colors = len(df.instrument.unique())
34
- colorscale = px.colors.sample_colorscale("viridis", [n/(n_colors) for n in range(n_colors)])
35
- print("colorscale", colorscale)
36
- colordict = {f:colorscale[i] for i, f in enumerate(df.instrument.unique())}
37
  for _, row in df.iterrows():
38
  ax.add_patch(Rectangle((row["start_time"],
39
  row["pitch"]-0.4),
 
2
  import collections
3
 
4
  import pandas as pd
 
5
  import matplotlib.pyplot as plt
6
  from matplotlib.patches import Rectangle
7
  from PIL import Image
 
29
  fig = plt.figure(figsize=(8, 5))
30
  ax = fig.add_subplot(111)
31
  ax.scatter(df.start_time, df.pitch, c="white")
32
+ colors = plt.get_cmap("tab20")
33
+ print(df["instrument"].unique())
34
+ colordict = {inst: colors[i] for i, inst in enumerate(df["instrument"].unique())}
 
35
  for _, row in df.iterrows():
36
  ax.add_patch(Rectangle((row["start_time"],
37
  row["pitch"]-0.4),