import dis
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
import streamlit as st
import re
import plotly.express as px
import pandas as pd
def attend(corpus, query, model, tokenizer, blacklist=False):
    """Average the attention that the query tokens pay to each corpus token.

    ``corpus + '\\n\\n' + query`` is tokenized and passed through ``model``.
    The second-to-last entry of ``model(...)[-1]`` is averaged over its two
    leading axes, the first and last positions are dropped, and the rows
    belonging to the query are averaged into a single vector over the corpus
    positions.

    Args:
        corpus: Text whose tokens receive a score.
        query: Text appended after the corpus; its rows are averaged.
        model: Callable whose output's last element is a sequence of objects
            supporting ``.detach().numpy()``.
            NOTE(review): presumably per-layer attention tensors shaped
            (batch, heads, seq, seq) -- confirm against the loaded model.
        tokenizer: Callable returning a mapping with ``'input_ids'`` and
            offering ``convert_ids_to_tokens``.
        blacklist: When true, corpus positions whose id is in the hard-coded
            id list are removed from both the scores and the returned tokens.

    Returns:
        ``(corpus_tokens, attention)``: the corpus tokens (special tokens
        skipped) and the averaged score vector.
    """
    # NOTE(review): presumably the vocabulary ids of '.', '?' and '!' for the
    # default checkpoint -- confirm before using a different tokenizer.
    banned_ids = [119, 136, 106]

    def encode(text):
        return tokenizer(text, return_tensors='pt')['input_ids']

    full_ids = encode(corpus + '\n\n' + query)
    query_ids = encode(query)
    corpus_ids = encode(corpus + '\n\n')

    # Only the second-to-last entry of the model's final output is used.
    chosen = tuple(model(full_ids)[-1])[-2]
    averaged = np.mean([chosen.detach().numpy()[0]], axis=(0, 1))

    # Drop the first and last position along both axes
    # (presumably the special start/end tokens -- TODO confirm).
    attention = np.array([row[1:-1] for row in averaged[1:-1]])

    if blacklist:
        # Positions are shifted by one because the first position was
        # dropped from ``attention`` above but is still in ``corpus_ids``.
        pruned = {
            pos - 1
            for pos, token_id in enumerate(corpus_ids[0])
            if token_id in banned_ids
        }
        keep = [pos for pos in range(attention.shape[0]) if pos not in pruned]
        attention = attention[keep][:, keep]
        corpus_ids = [
            [token_id for token_id in corpus_ids[0]
             if token_id not in banned_ids]
        ]

    # Each id sequence still carries its two boundary tokens, hence ``- 2``.
    corpus_len = len(corpus_ids[0]) - 2
    query_len = len(query_ids[0]) - 2

    # The query occupies the trailing rows; keep only the corpus columns.
    query_rows = [row[:corpus_len] for row in attention[-query_len:]]
    attention = np.mean(query_rows, axis=0)

    corpus_tokens = tokenizer.convert_ids_to_tokens(
        corpus_ids[0], skip_special_tokens=True)
    return corpus_tokens, attention
def plot_attention(attention, corpus_tokens):
    """Display ``attention`` as a matrix plot labelled with the corpus tokens.

    Args:
        attention: Data handed to ``plt.matshow``; its length sets the number
            of y ticks.
        corpus_tokens: Labels for the x ticks, one per token.
    """
    plt.matshow(attention)

    # One x tick per corpus token, labelled with the token itself.
    column_positions = np.arange(len(corpus_tokens))
    plt.xticks(column_positions, corpus_tokens)

    # One y tick per row of ``attention``; every row carries the same label.
    row_count = len(attention)
    plt.yticks(np.arange(row_count), ['query'] * row_count)

    plt.show()
def softmax(x, temperature):
    """Return the temperature-scaled softmax of ``x``.

    The values are divided by ``temperature``, exponentiated and normalised
    by the sum over *all* elements (no axis is singled out).

    Args:
        x: Array-like of scores (a NumPy array or a plain sequence).
        temperature: Non-zero divisor applied to the scores before
            exponentiation.

    Returns:
        Array of the same shape as ``x`` whose elements sum to 1. An empty
        input is returned as an empty array.
    """
    scaled = np.asarray(x) / temperature
    if scaled.size == 0:
        # Nothing to normalise; ``max()`` would raise on an empty array.
        return scaled
    # Subtracting the maximum leaves the result unchanged mathematically but
    # keeps ``exp`` from overflowing to inf (which would yield NaN).
    e_x = np.exp(scaled - scaled.max())
    return e_x / e_x.sum()
def render_html(corpus_tokens, attention, focus=0.99):
    """Join word-piece tokens back into readable text.

    Tokens are concatenated with single spaces, after which word-piece
    continuations (``##``) are merged into the preceding token and the
    spacing around punctuation, parentheses, hyphens and apostrophes is
    tightened. The result is stripped and terminated with a newline.

    Args:
        corpus_tokens: Sequence of token strings.
        attention: Per-token scores, indexed in step with ``corpus_tokens``.
        focus: Multiplier applied to the three score cut-offs
            (0.015, 0.01 and 0.005).

    Returns:
        The rendered string, ending in ``'\\n'``.
    """
    # Cut-offs from the strongest to the weakest tier, scaled by ``focus``.
    cutoffs = (0.015 * focus, 0.01 * focus, 0.005 * focus)

    pieces = []
    for idx, token in enumerate(corpus_tokens):
        # '.', '!' and '?' are never compared against the cut-offs.
        if token not in '.!?' and any(
                attention[idx] > cutoff for cutoff in cutoffs):
            # NOTE(review): every tier wraps the token in the same (empty)
            # prefix/suffix, so a token above a cut-off currently renders
            # exactly like one below it. Presumably tier-specific markup
            # belongs here -- confirm against the intended output.
            pieces.append(' ' + token + '')
        else:
            pieces.append(' ' + token)
    raw = ''.join(pieces)

    # Glue word-piece continuations onto the preceding piece.
    raw = re.sub(r'\s##', '', raw)
    # No space before closing punctuation ...
    raw = re.sub(r'\s(\.|,|!|\?|;|\))', r'\1', raw)
    # ... nor after an opening parenthesis ...
    raw = re.sub(r'\(\s', r'(', raw)
    # ... nor around hyphens and apostrophes.
    raw = re.sub(r'\s(-|\'|’)\s', r'\1', raw)

    # The wrapper is an empty prefix plus a trailing newline.
    return '' + raw.strip() + '\n'
@st.cache(allow_output_mutation=True)
def load(model='distilbert-base-cased'):
    """Load the tokenizer and model for a pretrained checkpoint.

    The call is wrapped in ``st.cache(allow_output_mutation=True)``.
    NOTE(review): newer Streamlit releases reportedly deprecate ``st.cache``
    in favour of ``st.cache_resource`` -- confirm against the pinned version.

    Args:
        model: Checkpoint name or path passed to ``from_pretrained``.

    Returns:
        ``(tokenizer, model)``; the model is created with
        ``output_attentions=True``.
    """
    checkpoint = model
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint, output_attentions=True)
    return tokenizer, encoder