# cybersalience/util.py
# (provenance: "feat: integrate write-up", commit 9b67bb3, by paulbricman)
import dis
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
import streamlit as st
import re
import plotly.express as px
import pandas as pd
def attend(corpus, query, model, tokenizer, blacklist=False):
    """Score each corpus token by the attention the query pays to it.

    Tokenizes ``corpus + '\\n\\n' + query`` as one sequence, reads the
    model's attention maps, and averages the attention that the query
    tokens direct at the corpus tokens.

    Args:
        corpus: Text whose tokens receive salience scores.
        query: Text whose attention over the corpus is measured.
        model: Transformers model created with ``output_attentions=True``
            (its last output is the per-layer attention tuple).
        tokenizer: Matching tokenizer for ``model``.
        blacklist: If True, drop a small set of punctuation token ids
            before averaging.

    Returns:
        Tuple ``(corpus_tokens, attention)`` — the corpus tokens (special
        tokens skipped) and a 1-D numpy array of mean attention weights.
    """
    # Token ids pruned when blacklist=True — presumably sentence punctuation
    # ('.', '?', '!') in this tokenizer's vocabulary; TODO confirm.
    token_blacklist = [119, 136, 106]
    full_ids = tokenizer(corpus + '\n\n' + query,
                         return_tensors='pt')['input_ids']
    query_ids = tokenizer(query,
                          return_tensors='pt')['input_ids']
    corpus_ids = tokenizer(corpus + '\n\n',
                           return_tensors='pt')['input_ids']
    # Second-to-last attention layer ([-2]); model(full_ids)[-1] is the
    # tuple of per-layer attention tensors.
    attention = [[e.detach().numpy()[0]]
                 for e in model(full_ids)[-1]][-2]
    # Mean over heads, then strip the first/last (special) token rows/cols.
    attention = np.array([e[1:-1]
                          for e in np.mean(attention, axis=(0, 1))[1:-1]])
    if blacklist:
        # Indices shifted by -1 because the leading special token was
        # already stripped from the attention matrix above.
        prune_idx = [e_idx - 1 for e_idx, e in enumerate(
            corpus_ids[0]) if e in token_blacklist]
        valid = [r for r in range(attention.shape[0]) if r not in prune_idx]
        attention = attention[valid][:, valid]
        corpus_ids = [[e for e in corpus_ids[0] if e not in token_blacklist]]
    # Keep rows for the query tokens, columns for the corpus tokens
    # (both minus their special-token padding), then average over rows.
    attention = [e[:len(corpus_ids[0]) - 2]
                 for e in attention[-(len(query_ids[0]) - 2):]]
    attention = np.mean(attention, axis=0)
    corpus_tokens = tokenizer.convert_ids_to_tokens(
        corpus_ids[0], skip_special_tokens=True)
    return corpus_tokens, attention
def plot_attention(attention, corpus_tokens):
    """Show the attention weights as a matrix plot.

    Columns are labelled with the corpus tokens; every row is labelled
    'query' since rows correspond to query positions.
    """
    plt.matshow(attention)
    positions = np.arange(len(corpus_tokens))
    plt.xticks(positions, corpus_tokens)
    row_count = len(attention)
    plt.yticks(np.arange(row_count), ['query'] * row_count)
    plt.show()
def softmax(x, temperature):
    """Temperature-scaled softmax over ``x``.

    Numerically stable: shifts by the max before exponentiating, so large
    inputs no longer overflow to inf/nan (the shift cancels in the ratio).

    Args:
        x: Array-like of scores.
        temperature: Divisor applied to the scores; lower values sharpen
            the distribution.

    Returns:
        Numpy array of the same length summing to 1.
    """
    z = np.asarray(x, dtype=float) / temperature
    e_x = np.exp(z - np.max(z))
    return e_x / e_x.sum()
def render_html(corpus_tokens, attention, focus=0.99):
    """Render tokens as an HTML paragraph, highlighting salient ones.

    Tokens whose attention weight exceeds one of three thresholds
    (scaled by ``focus``) are wrapped in ``glow-small``/``glow-medium``/
    ``glow-large`` spans. WordPiece continuation pieces (``##...``) are
    merged back into their preceding token and spacing around punctuation
    is normalized.

    Args:
        corpus_tokens: Token strings aligned with ``attention``.
        attention: Per-token salience weights.
        focus: Multiplier on the highlight thresholds.

    Returns:
        An HTML string ``<p>...</p>``.
    """
    # Build pieces in a list and join once instead of quadratic `+=`.
    parts = []
    for token, weight in zip(corpus_tokens, attention):
        if token in '.!?':  # sentence punctuation is never highlighted
            parts.append(' ' + token)
        elif weight > 0.015 * focus:
            parts.append(' <span class="glow-large">' + token + '</span>')
        elif weight > 0.01 * focus:
            parts.append(' <span class="glow-medium">' + token + '</span>')
        elif weight > 0.005 * focus:
            parts.append(' <span class="glow-small">' + token + '</span>')
        else:
            parts.append(' ' + token)
    raw = ''.join(parts)
    # Merge WordPiece continuations: ' ##lo' -> 'lo' glued to previous token.
    raw = re.sub(r'\s##', '', raw)
    # No space before closing punctuation; none after an opening paren.
    raw = re.sub(r'\s(\.|,|!|\?|;|\))', r'\1', raw)
    raw = re.sub(r'\(\s', r'(', raw)
    # Tighten hyphens and apostrophes.
    raw = re.sub(r'\s(-|\'|’)\s', r'\1', raw)
    # A continuation piece that got highlighted: drop the space and '##'
    # while keeping the span markup.
    raw = re.sub(r'\s<span class="glow-(small|medium|large)">##',
                 r'<span class="glow-\1">', raw)
    return '<p>' + raw.strip() + '</p>'
@st.cache(allow_output_mutation=True)
def load(model='distilbert-base-cased'):
    """Load and cache the tokenizer/model pair for the given checkpoint.

    The model is created with ``output_attentions=True`` so its forward
    pass returns attention maps (required by ``attend``).

    Args:
        model: Checkpoint name passed to the transformers Auto classes.

    Returns:
        Tuple ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model)
    attention_model = AutoModel.from_pretrained(model, output_attentions=True)
    return tokenizer, attention_model