Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,13 +4,17 @@ from datasets import load_dataset
|
|
4 |
from transformers import AutoTokenizer
|
5 |
import matplotlib.pyplot as plt
|
6 |
import seaborn as sns
|
|
|
|
|
|
|
|
|
7 |
|
8 |
|
9 |
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
|
10 |
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)
|
11 |
|
12 |
|
13 |
-
def analyze_sentence(index):
|
14 |
row = dataset[index]
|
15 |
text = row['text']
|
16 |
tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
|
@@ -18,21 +22,29 @@ def analyze_sentence(index):
|
|
18 |
seq_len = attn_map_shape[1]
|
19 |
attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
|
20 |
fig = plt.figure(figsize=(0.5 + 0.5 * len(tokenized), 0.4 * len(tokenized)))
|
21 |
-
|
|
|
|
|
|
|
22 |
plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90);
|
23 |
plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0);
|
24 |
plt.ylabel('TARGET')
|
25 |
plt.xlabel('SOURCE')
|
26 |
plt.grid()
|
|
|
27 |
return fig
|
28 |
|
29 |
demo = gr.Blocks()
|
30 |
with demo:
|
31 |
with gr.Row():
|
32 |
-
|
|
|
|
|
|
|
33 |
btn = gr.Button("Run")
|
34 |
output = gr.Plot(label="Plot", container=True)
|
35 |
-
|
|
|
36 |
|
37 |
|
38 |
if __name__ == "__main__":
|
|
|
4 |
from transformers import AutoTokenizer
|
5 |
import matplotlib.pyplot as plt
|
6 |
import seaborn as sns
|
7 |
+
from enum import Enum
|
8 |
+
|
9 |
+
class VisType(Enum):
|
10 |
+
ALL = 'ALL'
|
11 |
|
12 |
|
13 |
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
|
14 |
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)
|
15 |
|
16 |
|
17 |
+
def analyze_sentence(index, vis_type):
|
18 |
row = dataset[index]
|
19 |
text = row['text']
|
20 |
tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
|
|
|
22 |
seq_len = attn_map_shape[1]
|
23 |
attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
|
24 |
fig = plt.figure(figsize=(0.5 + 0.5 * len(tokenized), 0.4 * len(tokenized)))
|
25 |
+
plot_data = attn_maps[:, 1:, 1:]
|
26 |
+
if vis_type == VisType.ALL:
|
27 |
+
plot_data = attn_maps.sum(0)
|
28 |
+
sns.heatmap(plot_data)
|
29 |
plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90);
|
30 |
plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0);
|
31 |
plt.ylabel('TARGET')
|
32 |
plt.xlabel('SOURCE')
|
33 |
plt.grid()
|
34 |
+
metrics = {k: v for k, v in record.items() if x not in ['text', 'attention_maps', 'attention_maps_shape']}
|
35 |
return fig
|
36 |
|
37 |
demo = gr.Blocks()
|
38 |
with demo:
|
39 |
with gr.Row():
|
40 |
+
sentence_dropdown = gr.Dropdown(label="Sentence",
|
41 |
+
choices=[x.split('</s> ')[1] for x in dataset['text']],
|
42 |
+
value=0, min_width=500, type='index')
|
43 |
+
vis_dropdown = gr.Dropdown(label="Visualization", choices=list(VisType), value=0, type='index')
|
44 |
btn = gr.Button("Run")
|
45 |
output = gr.Plot(label="Plot", container=True)
|
46 |
+
metrics = gr.Label("Metrics")
|
47 |
+
btn.click(analyze_sentence, [sentence_dropdown, vis_dropdown], [output, metrics])
|
48 |
|
49 |
|
50 |
if __name__ == "__main__":
|