import html

import matplotlib.cm as cm
import numpy as np
import pandas as pd
import streamlit as st
import torch
from matplotlib.colors import LinearSegmentedColormap
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline

# Colormap for token scores: 0.0 -> opaque red (0.8, 0, 0, alpha 1) and
# 1.0 -> fully transparent white, so low-scoring tokens stand out while
# high-scoring ones fade into the background.
cdict = {'red':   [[0.0, 0.8, 0.8], [1.0, 1.0, 1.0]],
         'green': [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],
         'blue':  [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],
         'alpha': [[0.0, 1.0, 1.0], [1.0, 0.0, 0.0]]}
cmap = LinearSegmentedColormap('codemap', segmentdata=cdict, N=256)


def value2rgba(x, cmap=cmap, alpha_mult=1.0):
    """Map a score to an ``(r, g, b, a)`` tuple through ``cmap``.

    Args:
        x: score to look up in the colormap (the colormap spans 0.0-1.0).
        cmap: matplotlib colormap to use; defaults to the module's ``cmap``.
        alpha_mult: factor applied to the colormap's alpha channel.

    Returns:
        A tuple of three ints in 0-255 followed by a plain Python float alpha.
        ``str()`` of the result is used directly as the argument list of a
        CSS ``rgba(...)`` color elsewhere in this file.
    """
    c = cmap(x)
    rgb = (np.array(c[:-1]) * 255).astype(int)
    # float(): the colormap returns NumPy scalars, and since NumPy 2.0 their
    # repr is "np.float64(...)", which str() of the tuple would leak into CSS.
    a = float(c[-1] * alpha_mult)
    return tuple(rgb.tolist() + [a])


def highlight_token_scores(tokens, scores, sep=' ', **kwargs):
    """Render tokens as HTML with each token's background colored by its score.

    Args:
        tokens: decoded token strings, in order.
        scores: one score per token; mapped to a color with ``value2rgba``.
        sep: string placed between consecutive token spans.
        **kwargs: forwarded to ``value2rgba`` (e.g. a different ``cmap``).

    Returns:
        An HTML string, meant for ``st.markdown(..., unsafe_allow_html=True)``.
    """
    spans = []
    for token, score in zip(tokens, scores):
        # Tokens are raw source code and may contain '<', '&' or quotes.
        token = html.escape(token)
        rgba = str(value2rgba(score, alpha_mult=0.8, **kwargs))
        spans.append(f'<span style="background-color: rgba{rgba}">{token}</span>')
    # <pre> keeps the code's indentation and newlines, and (as a CommonMark
    # raw-HTML block) is not cut short by blank lines inside the code.
    return '<pre>' + sep.join(spans) + '</pre>'

def color_dataframe(row):
    """Return the per-cell CSS for one row of the suggestions table.

    The "tokens" and "scores" cells are tinted with the color that
    ``value2rgba`` assigns to the row's score (alpha scaled by 0.8, as for the
    highlighted code); every other cell gets ``"background-color: None"``.

    Args:
        row: one table row as passed by ``Styler.apply(..., axis=1)``; must
            have a "scores" entry.

    Returns:
        A list of CSS strings, one per label in ``row.index``, in order.
    """
    tint = "background-color: rgba" + str(value2rgba(row["scores"], alpha_mult=0.8))
    return [
        tint if label in ("tokens", "scores") else "background-color: None"
        for label in row.index
    ]

@st.cache(allow_output_mutation=True)
def load_tokenizer(model_ckpt):
    """Load the tokenizer for ``model_ckpt``, memoized across script reruns.

    ``allow_output_mutation=True`` tells ``st.cache`` not to check whether
    the returned object was mutated between reruns.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    return tokenizer

@st.cache(allow_output_mutation=True)
def load_model(model_ckpt):
    """Load the causal language model for ``model_ckpt``, memoized across reruns.

    ``allow_output_mutation=True`` tells ``st.cache`` not to check whether
    the returned model was mutated between reruns.
    """
    return AutoModelForCausalLM.from_pretrained(model_ckpt)

def calculate_scores(probs, token_ids):
    """Score each token against the candidates the model ranked above it.

    Row ``i`` of ``probs`` is the next-token distribution after token ``i``,
    so it is scored against ``token_ids[i + 1]``. A token's score is its
    probability divided by the mean probability of every candidate ranked at
    or above it. The score is therefore at most 1.0, and exactly 1.0 when the
    token is the top prediction.

    Args:
        probs: 2-D array ``(seq_len, vocab_size)`` of next-token probabilities
            (in this file: ``softmax`` of the model logits).
        token_ids: 1-D array of the ``seq_len`` input token ids.

    Returns:
        ``(scores, sorted_ids)``:

        - ``scores`` is a list of ``seq_len`` values. The first token has
          nothing predicting it and gets 1.0; a token id missing from the
          vocabulary axis scores NaN.
        - ``sorted_ids`` is a ``(seq_len - 1, vocab_size)`` array of
          vocabulary ids, from most to least probable, per predicted position.
    """
    # The last row predicts past the end of the input, and the first token has
    # no prediction: align each remaining row with the token it predicts.
    probs = probs[:-1]
    targets = token_ids[1:]
    rows = np.arange(len(targets))

    # Sort once and gather the sorted probabilities from that order, instead of
    # paying for a second full sort of the (seq_len x vocab_size) matrix.
    sorted_ids = np.argsort(probs, axis=-1)[:, ::-1]
    sorted_probs = np.take_along_axis(probs, sorted_ids, axis=-1)

    # 0-based position of the actual token in each descending ranking.
    hits = sorted_ids == targets[:, None]
    found = hits.any(axis=-1)
    rank = np.argmax(hits, axis=-1)

    token_probs = sorted_probs[rows, rank]
    # Mean probability over the (rank + 1) candidates ranked at or above the token.
    mean_above = np.cumsum(sorted_probs, axis=-1)[rows, rank] / (rank + 1)
    scores = np.where(found, token_probs / mean_above, np.nan)
    return [1.] + list(scores), sorted_ids

def calculate_loss(logits, labels):
    """Turn per-token cross-entropy into a score where 1.0 means "easy to predict".

    The logits at position ``i`` predict ``labels[i + 1]``. Each token's loss
    is divided by the largest loss in the sequence and subtracted from 1, so
    the worst-predicted token scores 0.0 and a perfectly predicted one scores
    1.0.

    Args:
        logits: ``(seq_len, vocab_size)`` tensor (in this file: the model's
            logits for the single input sequence).
        labels: ``(seq_len,)`` tensor of input token ids.

    Returns:
        A list of ``seq_len`` scores. The first token has no prediction and
        gets 1.0.
    """
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    loss = loss.detach().cpu()
    if loss.numel() == 0:
        # Single-token input: nothing is predicted, and torch.max would raise
        # on the empty tensor.
        return [1.]
    max_loss = torch.max(loss)
    if max_loss == 0:
        # Every token predicted perfectly: avoid 0/0 -> NaN.
        norm_loss = torch.ones_like(loss)
    else:
        norm_loss = 1 - (loss / max_loss)
    return [1.] + list(norm_loss.numpy())

# Starter snippet pre-filled in the "Edit code" text area. It lacks the
# nn.Module base class, `self` parameters and the super().__init__() call that
# solution_code has — presumably on purpose, so the model has suspicious
# tokens to highlight.
default_code = """\
from torch import nn
from transformers import Model

class Transformer:
    def __init__(config):
        self.model = Model(config)

    def forward(inputs):
        return self.model(inputs)"""

# Corrected version of default_code, shown under the "Possible solution" heading.
solution_code = """\
from torch import nn
from transformers import Model

class Transformer(nn.Module):
    def __init__(self, config):
        super(Transformer, self).__init__()
        self.config = config
        self.model = Model(config)

    def forward(self, inputs):
        return self.model(inputs)
"""

# Page configuration goes before any other st.* output call, as (older)
# Streamlit versions require.
st.set_page_config(layout="wide", page_icon=':parrot:')

# NOTE(review): nothing visible in this file draws from np.random; the seed is
# kept as-is.
np.random.seed(42)
model_ckpt = "lvwerra/codeparrot"
# Both loaders are wrapped in st.cache, so script reruns reuse the loaded objects.
tokenizer, model = load_tokenizer(model_ckpt), load_model(model_ckpt)
# Page title; unsafe_allow_html lets the raw <h1> tag render.
st.markdown("<h1>CodeParrot 🦜</h1>", unsafe_allow_html=True)
st.markdown('##')

col1, col2 = st.columns(2)
col1.subheader("Edit code")
code = col1.text_area(label="", value=default_code, height=220,).strip()

# Enforce the 1024-token limit BEFORE the forward pass. The old check used
# len(inputs['input_ids']), which is the batch size (always 1), so it never
# fired, and it ran after the model call and never truncated anything.
max_tokens = 1024
input_ids = tokenizer(code, return_tensors='pt')["input_ids"]
if input_ids.shape[1] == 0:
    # Nothing to score; skip the model call entirely.
    st.warning("Enter some code to get token highlights.")
    st.stop()
if input_ids.shape[1] > max_tokens:
    st.warning("Your input is longer than the maximum 1024 tokens and will be truncated.")
    input_ids = input_ids[:, :max_tokens]

token_list = [tokenizer.decode(t) for t in input_ids[0]]

with torch.no_grad():
    logits = model(input_ids=input_ids).logits[0]
    probs = softmax(logits, dim=-1)
    loss = calculate_loss(logits, input_ids[0])
norm_probs, sorted_token_ids = calculate_scores(probs.numpy(), input_ids[0].numpy())

st.sidebar.title("Settings:")
highlight_mode = st.sidebar.radio("Highlight mode:", ["Probability heuristics", "Scaled loss per token"])
scores = norm_probs if highlight_mode == "Probability heuristics" else loss
suggestion_threshold = st.sidebar.slider("Suggestion threshold", 0.0, 1.0, 0.2)

col2.subheader("Highlighted code")
col2.markdown('##')
html_string = highlight_token_scores(token_list, scores, sep="")
col2.markdown(html_string, unsafe_allow_html=True)
col2.markdown('##')

st.subheader("Model suggestions")
# Column "top-k" holds the model's k-th most likely token for each position;
# position 0 has no preceding context, hence the placeholder.
top_k = {
    f"top-{i+1}": ["No prediction for first token"]
    + [repr(tokenizer.decode(idx)) for idx in sorted_token_ids[:, i]]
    for i in range(5)
}
df = pd.DataFrame({"tokens": [repr(t) for t in token_list], "scores": scores, **top_k})
df.index.name = "position"
# Build a fresh frame rather than reset_index(inplace=True) on a filtered slice.
df_filter = df.loc[df["scores"] <= suggestion_threshold].reset_index()
df_filter = df_filter[["tokens", "scores", "position", "top-1", "top-2", "top-3", "top-4", "top-5",]]
st.dataframe(df_filter.style.apply(color_dataframe, axis=1))
st.markdown('##')

st.subheader("Possible solution")
st.code(solution_code)