joaogante HF staff commited on
Commit
2723972
1 Parent(s): 31b8889

playing around with gradio

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import functools
2
- import matplotlib
3
- matplotlib.use('Agg')
4
- import matplotlib.pyplot as plt
5
 
6
  import gradio as gr
 
 
7
 
8
 
 
9
  BENCHMARK_DATA = {
10
  "Greedy Search": {
11
  "DistilGPT2": {
@@ -29,9 +29,9 @@ BENCHMARK_DATA = {
29
  "A100": [],
30
  },
31
  "T5 Small": {
32
- "T4": [1, 2, 3, 4],
33
- "3090": [],
34
- "A100": [],
35
  },
36
  "T5 Base": {
37
  "T4": [],
@@ -137,10 +137,19 @@ BENCHMARK_DATA = {
137
 
138
 
139
  def get_plot(model_name, generate_type):
140
- data = BENCHMARK_DATA[generate_type][model_name]["T4"]
141
- plt.plot(data)
142
- plt.title(model_name)
143
- return plt.gcf()
 
 
 
 
 
 
 
 
 
144
 
145
  demo = gr.Blocks()
146
 
 
1
  import functools
 
 
 
2
 
3
  import gradio as gr
4
+ import seaborn as sns
5
+ import pandas as pd
6
 
7
 
8
+ # benchmark order: pytorch, tf eager, tf xla; units = ms
9
  BENCHMARK_DATA = {
10
  "Greedy Search": {
11
  "DistilGPT2": {
 
29
  "A100": [],
30
  },
31
  "T5 Small": {
32
+ "T4": [99.88, 1527.73, 18.78],
33
+ "3090": [55.09, 665.70, 9.25],
34
+ "A100": [124.91, 1642.07, 13.72],
35
  },
36
  "T5 Base": {
37
  "T4": [],
 
137
 
138
 
139
  def get_plot(model_name, generate_type):
140
+ df = pd.DataFrame(BENCHMARK_DATA[generate_type][model_name])
141
+ df["framework"] = ["PyTorch", "TF (Eager Execition)", "TF (XLA)"]
142
+ df = pd.melt(df, id_vars=["framework"], value_vars=["T4", "3090", "A100"])
143
+
144
+ g = sns.catplot(
145
+ data=df, kind="bar",
146
+ x="variable", y="value", hue="framework",
147
+ ci="sd", palette="dark", alpha=.6, height=6
148
+ )
149
+ g.despine(left=True)
150
+ # g.set_axis_labels("", "Body mass (g)")
151
+ # g.legend.set_title("")
152
+ return g.gcf()
153
 
154
  demo = gr.Blocks()
155