from transformers import AutoModelForMaskedLM
from transformers import AutoTokenizer
import spacy
import pytextrank
from nlp_entities import *
import torch
import streamlit as st
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
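# NOTE: output_hidden_states=True makes the model return "hidden_states", which the
# pooling helpers below read the last layer from; the masked-LM logits themselves
# are never used.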
# Streamlit inputs
tags = st.text_input("Input tags separated by commas")
text = st.text_input("Input text to classify")
topkp = st.slider("Number of key phrases to extract from text", 10,30,20)
#Methods for tag processing
def pool_embeddings(out, tok):
    embeddings = out["hidden_states"][-1]
    attention_mask = tok['attention_mask']
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    masked_embeddings = embeddings * mask
    summed = torch.sum(masked_embeddings, 1)
    summed_mask = torch.clamp(mask.sum(1), min=1e-9)
    mean_pooled = summed / summed_mask
    return mean_pooled
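# Illustrative sketch (not part of the app flow; the phrase strings are arbitrary):
#   toks = tokenizer(["event planning", "machine learning"], max_length=64,
#                    truncation=True, padding='max_length', return_tensors='pt')
#   with torch.no_grad():
#       out = model(**toks)
#   vecs = pool_embeddings(out, toks)  # shape (2, hidden_size): one mean-pooled
#                                      # vector per phrase, padding tokens excluded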
import pandas as pd
def get_transcript(file):
    data = pd.read_json(file)
    transcript = data['results'].values[1][0]['transcript']
    transcript = transcript.lower()
    return transcript
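# get_transcript appears intended for speech-to-text output JSON ('results' ->
# 'transcript'); it is not called anywhere else in this script.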
def concat_tokens_tags(sentences):
    tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
    for sentence in sentences:
        # encode each sentence and append to dictionary
        new_tokens = tokenizer.encode_plus(sentence, max_length=64,
                                           truncation=True, padding='max_length',
                                           return_tensors='pt')
        tokens['input_ids'].append(new_tokens['input_ids'][0])
        tokens['attention_mask'].append(new_tokens['attention_mask'][0])
        tokens['KPS'].append(sentence)
    # reformat list of tensors into single tensor
    tokens['input_ids'] = torch.stack(tokens['input_ids'])
    tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
    return tokens
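# concat_tokens_tags returns stacked [n_tags, 64] 'input_ids'/'attention_mask'
# tensors plus a 'KPS' list holding the original strings; 'KPS' is popped before
# the tensors are passed to the model.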
# Process tags
if tags:
    tags = [x.lower().strip() for x in tags.split(",")]
    tags_tokens = concat_tokens_tags(tags)
    tags_tokens.pop("KPS")
    with torch.no_grad():
        outputs_tags = model(**tags_tokens)
    pools_tags = pool_embeddings(outputs_tags, tags_tokens).detach().numpy()
    token_dict = {}
    for tag, embedding in zip(tags, pools_tags):
        token_dict[tag] = embedding
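# token_dict now maps each lowercased tag to its mean-pooled embedding (a 1-D
# numpy vector); these are the reference vectors the extracted key phrases are
# compared against below.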
# Code for processing the text, extracting key phrases, and computing their distance to each tag
def concat_tokens(sentences):
    tokens = {'input_ids': [], 'attention_mask': [], 'KPS': {}}
    for sentence, values in sentences.items():
        weight = values['weight']
        # encode each sentence and append to dictionary
        new_tokens = tokenizer.encode_plus(sentence, max_length=64,
                                           truncation=True, padding='max_length',
                                           return_tensors='pt')
        tokens['input_ids'].append(new_tokens['input_ids'][0])
        tokens['attention_mask'].append(new_tokens['attention_mask'][0])
        tokens['KPS'][sentence] = weight
    # reformat list of tensors into single tensor
    tokens['input_ids'] = torch.stack(tokens['input_ids'])
    tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
    return tokens
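# Same tokenization as concat_tokens_tags, except 'KPS' is a dict mapping each key
# phrase to its extraction weight, so the weight can scale the cosine similarity later.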
def calculate_weighted_embed_dist(out, tokens, weight, text, kp_dict, idx, exclude_text=False, exclude_words=False):
    sim_dict = {}
    pools = pool_embeddings_count(out, tokens, idx).detach().numpy()
    for key in kp_dict.keys():
        if exclude_text and text in key:
            continue
        if exclude_words and True in [x in key for x in text.split(" ")]:
            continue
        sim_dict[key] = cosine_similarity(
            pools,
            [kp_dict[key]]
        )[0][0] * weight
    return sim_dict
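# For key phrase number idx, this computes cosine similarity between its pooled
# embedding and every tag embedding in kp_dict, scaled by the phrase weight
# (e.g. similarity 0.9 with weight 0.2 contributes 0.18 to that tag's score).
# exclude_text / exclude_words optionally skip tags that contain the phrase or
# one of its words.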
def pool_embeddings_count(out, tok, idx):
    embeddings = out["hidden_states"][-1][idx:idx+1, :, :]
    attention_mask = tok['attention_mask'][idx]
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    masked_embeddings = embeddings * mask
    summed = torch.sum(masked_embeddings, 1)
    summed_mask = torch.clamp(mask.sum(1), min=1e-9)
    mean_pooled = summed / summed_mask
    return mean_pooled
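# Same mean pooling as pool_embeddings, restricted to the single row idx; returns
# a [1, hidden_size] tensor for one key phrase.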
def extract_tokens(text, top_kp=30):
    kps = return_ners_and_kp([text], ret_ne=True)['KP']
    # only process the top_kp key phrases
    kps = sorted(kps.items(), key=lambda x: x[1]['weight'], reverse=True)[:top_kp]
    kps = {x: y for x, y in kps}
    return concat_tokens(kps)
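# return_ners_and_kp comes from the local nlp_entities module (presumably the
# spacy/pytextrank pipeline imported above); only the top_kp phrases by weight
# are kept before tokenization.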
#Process text and classify it
if text and tags:
    text = text.lower()
    t1_tokens = extract_tokens(text, topkp)
    t1_kps = t1_tokens.pop("KPS")
    with torch.no_grad():
        outputs = model(**t1_tokens)
    tag_distance = None
    for i, kp in enumerate(t1_kps):
        if tag_distance is None:
            tag_distance = calculate_weighted_embed_dist(outputs, t1_tokens, t1_kps[kp], kp, token_dict, i, exclude_text=False, exclude_words=False)
        else:
            curr = calculate_weighted_embed_dist(outputs, t1_tokens, t1_kps[kp], kp, token_dict, i, exclude_text=False, exclude_words=False)
            tag_distance = {x: tag_distance[x] + curr[x] for x in tag_distance.keys()}
    tag_distance = sorted(tag_distance.items(), key=lambda x: x[1], reverse=True)
    tag_distance = {x: y for x, y in tag_distance}
    st.json(tag_distance)
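# The displayed scores are sums of weight-scaled cosine similarities over all
# extracted key phrases, so the tags are ranked but the values are not
# normalized to [0, 1].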