leandro commited on
Commit
16b19cc
1 Parent(s): ad022dd

add app and requirements

Browse files
Files changed (2) hide show
  1. app.py +162 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from transformers import pipeline
4
+ import numpy as np
5
+ import pandas as pd
6
+ import matplotlib.cm as cm
7
+ import html
8
+ from torch.nn.functional import softmax
9
+ import torch
10
+ from matplotlib.colors import LinearSegmentedColormap
11
+
12
+ cdict = {'red': [[0.0, 0.8, 0.8],
13
+ [1.0, 1.0, 1.0]],
14
+ 'green': [[0.0, 0.0, 0.0],
15
+ [1.0, 1.0, 1.0]],
16
+ 'blue': [[0.0, 0.0, 0.0],
17
+ [1.0, 1.0, 1.0]],
18
+ 'alpha':[[0.0, 1.0, 1.0],
19
+ [1.0, 0.0, 0.0]]}
20
+
21
+ cmap = LinearSegmentedColormap('codemap', segmentdata=cdict, N=256)
22
+
23
+ def value2rgba(x, cmap=cmap, alpha_mult=1.0):
24
+ c = cmap(x)
25
+ rgb = (np.array(c[:-1]) * 255).astype(int)
26
+ a = c[-1] * alpha_mult
27
+ return tuple(rgb.tolist() + [a])
28
+
29
+ def highlight_token_scores(tokens, scores, sep=' ', **kwargs):
30
+ html_code,spans = [''], []#['<span style="font-family: monospace;">'], []
31
+ for t, s in zip(tokens, scores):
32
+ t = html.escape(t)
33
+ t = t.replace("\n", " \n")
34
+ c = str(value2rgba(s, alpha_mult=0.8, **kwargs))
35
+ spans.append(f'<span title="{s:.3f}" style="background-color: rgba{c};">{t}</span>')
36
+ html_code.append(sep.join(spans))
37
+ return '<pre><code>' + ''.join(html_code)
38
+
39
+ def color_dataframe(row):
40
+ styles = []
41
+ c = str(value2rgba(row["scores"], alpha_mult=0.8))
42
+ for key in row.index:
43
+ if key in {"tokens", "scores"}:
44
+ styles.append(f"background-color: rgba{c}")
45
+ else:
46
+ styles.append(f"background-color: None")
47
+ return styles
48
+
49
+ @st.cache(allow_output_mutation=True)
50
+ def load_tokenizer(model_ckpt):
51
+ return AutoTokenizer.from_pretrained(model_ckpt)
52
+
53
+ @st.cache(allow_output_mutation=True)
54
+ def load_model(model_ckpt):
55
+ model = AutoModelForCausalLM.from_pretrained(model_ckpt)
56
+ return model
57
+
58
+ def calculate_scores(probs, token_ids):
59
+ probs = probs[:-1]
60
+ token_ids = token_ids[1:]
61
+ sorted_ids = np.argsort(probs, axis=-1)[:, ::-1]
62
+ sorted_probs = np.sort(probs, axis=-1)[:, ::-1]
63
+ selected_token_mask = sorted_ids == token_ids[:, None]
64
+ masked_probs = np.ma.array(sorted_probs, mask=~selected_token_mask)
65
+ token_probs = masked_probs.sum(axis=1).data
66
+
67
+ masked_indices = np.cumsum(selected_token_mask[:, ::-1], axis=-1)[:, ::-1].astype(bool)
68
+ masked_probs = np.ma.array(sorted_probs, mask=~masked_indices)
69
+ token_rank = masked_indices.sum(axis=-1)
70
+ cumulative_probs = masked_probs.sum(axis=1).data/token_rank
71
+ scores = token_probs/cumulative_probs
72
+ return [1.] + list(scores), sorted_ids
73
+
74
+ def calculate_loss(logits, labels):
75
+ shift_logits = logits[..., :-1, :].contiguous()
76
+ shift_labels = labels[..., 1:].contiguous()
77
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
78
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
79
+ norm_loss = 1 - (loss/torch.max(loss))
80
+ return [1.] + list(norm_loss.numpy())
81
+
82
+ default_code = """\
83
+ from torch import nn
84
+ from transformers import Model
85
+
86
+ class Transformer:
87
+ def __init__(config):
88
+ self.model = Model(config)
89
+
90
+ def forward(inputs):
91
+ return self.model(inputs)"""
92
+
93
+ solution_code = """\
94
+ from torch import nn
95
+ from transformers import Model
96
+
97
+ class Transformer(nn.Module):
98
+ def __init__(self, config):
99
+ super(Transformer, self).__init__()
100
+ self.config = config
101
+ self.model = Model(config)
102
+
103
+ def forward(self, inputs):
104
+ return self.model(inputs)
105
+ """
106
+
107
+ st.set_page_config(page_icon=':parrot:', layout="wide")
108
+
109
+ np.random.seed(42)
110
+ model_ckpt = "lvwerra/codeparrot"
111
+ tokenizer = load_tokenizer(model_ckpt)
112
+ model = load_model(model_ckpt)
113
+ st.markdown("<h1 style='text-align: center;'>CodeParrot 🦜</h1>", unsafe_allow_html=True)
114
+ st.markdown('##')
115
+
116
+ col1, col2 = st.columns(2)
117
+
118
+ col1.subheader("Edit code")
119
+ code = col1.text_area(label="", value=default_code, height=220,).strip()
120
+ inputs = tokenizer(code, return_tensors='pt')
121
+ token_list = [tokenizer.decode(t) for t in inputs["input_ids"][0]]
122
+
123
+ with torch.no_grad():
124
+ logits = model(input_ids=inputs["input_ids"]).logits[0]
125
+ probs = softmax(logits, dim=-1)
126
+
127
+ loss = calculate_loss(logits, inputs["input_ids"][0])
128
+ norm_probs, sorted_token_ids = calculate_scores(probs.numpy(), inputs["input_ids"][0].numpy())
129
+
130
+ if len(inputs['input_ids'])>1024:
131
+ st.warning("Your input is longer than the maximum 1024 tokens and will be truncated.")
132
+
133
+ st.sidebar.title("Settings:")
134
+ if st.sidebar.radio("Highlight mode:", ["Probability heuristics", "Scaled loss per token"]) == "Probability heuristics":
135
+ scores = norm_probs
136
+ else:
137
+ scores = loss
138
+
139
+ suggestion_threshold = st.sidebar.slider("Suggestion threshold", 0.0, 1.0, 0.2)
140
+
141
+ col2.subheader("Highlighted code")
142
+ col2.markdown('##')
143
+ html_string = highlight_token_scores(token_list, scores, sep="")
144
+ col2.markdown(html_string, unsafe_allow_html=True)
145
+ col2.markdown('##')
146
+
147
+ st.subheader("Model suggestions")
148
+ top_k = {}
149
+ for i in range(5):
150
+ top_k[f"top-{i+1}"] = ["No prediction for first token"] + [repr(tokenizer.decode(idx)) for idx in sorted_token_ids[:, i]]
151
+ df = pd.DataFrame({"tokens": [repr(t) for t in token_list], "scores": scores, **top_k})
152
+ df.index.name = "position"
153
+ df_filter = df.loc[df["scores"]<=suggestion_threshold]
154
+ df_filter.reset_index(inplace=True)
155
+ df_filter = df_filter[["tokens", "scores", "position", "top-1", "top-2", "top-3", "top-4", "top-5",]]
156
+ df_filter = df_filter.style.apply(color_dataframe, axis=1)
157
+ st.dataframe(df_filter)
158
+
159
+ st.markdown('##')
160
+
161
+ st.subheader("Possible solution")
162
+ st.code(solution_code)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers==4.12.2
2
+ pandas
3
+ matplotlib
4
+ torch