newsdejavu / app.py
tombryan's picture
wrong name
f8367ae
raw
history blame
6.1 kB
import os
from typing import List, Sequence, Tuple, Union

import gradio as gr
import numpy as np
import requests
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
NER_MODEL_PATH = 'dell-research-harvard/historical_newspaper_ner'
EMBED_MODEL_PATH = 'dell-research-harvard/same-story'

# Retrieval servers are discovered from environment variables whose name
# contains "AZURE_VM". The last underscore-separated token of the variable
# name is used as the (upper-case) state key; its capitalized form is what
# the UI dropdown shows.
# NOTE(review): only the last token is kept, so a name like
# AZURE_VM_NEW_YORK would register as "YORK" — confirm the env naming scheme.
AZURE_VMS = {}
AVAILABLE_STATES = ['All States']
for env_name, env_value in os.environ.items():
    if 'AZURE_VM' not in env_name:
        continue
    state_key = env_name.split('_')[-1]
    AZURE_VMS[state_key] = env_value
    AVAILABLE_STATES.append(state_key.capitalize())

AVAILABLE_YEARS = ['All Years']
def find_sep_token(tokenizer):
    """
    Build the space-padded separator string for ``tokenizer``.

    If the tokenizer's special-token map defines an ``eos_token``, the
    separator is ``" <eos> <sep> "``; otherwise it is ``" <sep> "``.
    Raises KeyError if no ``sep_token`` is defined.
    """
    specials = tokenizer.special_tokens_map
    pieces = [specials['sep_token']]
    if 'eos_token' in specials:
        # EOS goes in front of SEP.
        pieces.insert(0, specials['eos_token'])
    return " " + " ".join(pieces) + " "
def find_mask_token(tokenizer):
    """Return ``tokenizer.special_tokens_map['mask_token']`` (KeyError if absent)."""
    return tokenizer.special_tokens_map['mask_token']
if gr.NO_RELOAD:
    # Gradio skips code under gr.NO_RELOAD when hot-reloading the app, so the
    # models below are loaded once rather than on every source edit.

    # --- Named-entity recognition model + pipeline ---
    ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_PATH)
    # NOTE(review): return_tensors / max_length / truncation are given to
    # from_pretrained rather than to the tokenizer call; confirm they actually
    # limit input length during pipeline inference.
    ner_tokenizer = AutoTokenizer.from_pretrained(
        NER_MODEL_PATH,
        return_tensors="pt",
        max_length=256,
        truncation=True,
    )
    # ignore_labels=[] keeps every span (including non-entities), so the masking
    # helpers can rebuild the whole text from the pipeline output.
    token_classifier = pipeline(
        task="ner",
        model=ner_model,
        tokenizer=ner_tokenizer,
        ignore_labels=[],
        aggregation_strategy='max',
    )

    # --- Sentence-embedding model and its special tokens ---
    embedding_tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_PATH)
    embedding_model = SentenceTransformer(EMBED_MODEL_PATH)
    embed_mask_tok = find_mask_token(embedding_tokenizer)
    embed_sep_tok = find_sep_token(embedding_tokenizer)
# with open(REF_INDEX_PATH, 'r') as f:
# news_paths = [l.strip() for l in f.readlines()]
def handle_punctuation_for_generic_mask(word):
    """
    Replace ``word`` with the generic ``[MASK]`` token, keeping edge punctuation.

    If ``word`` starts with one of ``. , ! ?``, that character is kept in front
    of the mask, separated by a space (``".Smith"`` -> ``". [MASK]"``).
    Otherwise, if it ends with one of them, the character is appended directly
    after the mask (``"Smith,"`` -> ``"[MASK],"``). Leading punctuation takes
    precedence when both are present. An empty ``word`` yields ``"[MASK]"``.
    """
    if not word:
        # Guard: word[0] / word[-1] would raise IndexError on an empty span.
        return "[MASK]"
    punctuation = (".", ",", "!", "?")
    if word[0] in punctuation:
        return word[0] + " [MASK]"
    if word[-1] in punctuation:
        return "[MASK]" + word[-1]
    return "[MASK]"
def handle_punctuation_for_entity_mask(word, entity_group):
    """
    Replace ``word`` with its ``entity_group`` label, keeping edge punctuation.

    Mirrors the generic-mask rule, but emits the entity label (e.g. ``"PER"``)
    instead of ``[MASK]``. A leading ``. , ! ?`` is kept before the label with
    a space (``".Smith"`` -> ``". PER"``). Otherwise a trailing one is appended
    directly after it (``"Smith!"`` -> ``"PER!"``). An empty ``word`` yields
    the bare label.
    """
    if not word:
        # Guard: word[0] / word[-1] would raise IndexError on an empty span.
        return entity_group
    punctuation = (".", ",", "!", "?")
    if word[0] in punctuation:
        return word[0] + " " + entity_group
    if word[-1] in punctuation:
        return entity_group + word[-1]
    return entity_group
def replace_words_with_entity_tokens(ner_output_dict: List[dict],
                                     desired_labels: Sequence[str] = ('PER', 'ORG', 'LOC', 'MISC'),
                                     all_masks_same: bool = True) -> str:
    """
    Rebuild text from NER spans, masking the spans of the requested entity types.

    Args:
        ner_output_dict: Entity spans as produced by ``ner()``; each dict must
            have ``"word"`` and ``"entity_group"`` keys.
        desired_labels: Entity groups to mask. Spans with any other group are
            kept verbatim. (Tuple default avoids a shared mutable default.)
        all_masks_same: If True, masked spans become the generic ``[MASK]``
            token; otherwise each becomes its entity-group label (e.g. ``PER``).
            Edge punctuation is preserved by the ``handle_punctuation_*`` helpers.

    Returns:
        All span words joined with single spaces.
    """
    new_words = []
    for span in ner_output_dict:
        word, group = span["word"], span["entity_group"]
        if group not in desired_labels:
            new_words.append(word)
        elif all_masks_same:
            new_words.append(handle_punctuation_for_generic_mask(word))
        else:
            new_words.append(handle_punctuation_for_entity_mask(word, group))
    return " ".join(new_words)
def mask(ner_output_list: List[dict], desired_labels: Sequence[str] = ('PER', 'ORG', 'LOC', 'MISC'),
         all_masks_same: bool = True) -> str:
    """
    Mask the entity spans of one text's NER output and return the masked string.

    Thin wrapper around ``replace_words_with_entity_tokens``. The input is the
    list of span dicts for a single text, not a list of such lists, and the
    result is a single string.
    """
    return replace_words_with_entity_tokens(ner_output_list, desired_labels, all_masks_same)
def ner(text: Union[str, List[str]]) -> List[dict]:
    """
    Run the NER pipeline and return the entity spans of the first input text.

    Args:
        text: A list of texts (only the first one's result is returned) or a
            single string.

    Returns:
        The aggregated span dicts from ``token_classifier`` for the first text.
        The masking helpers read their ``"word"`` and ``"entity_group"`` keys.
    """
    if isinstance(text, str):
        # A bare string makes the pipeline return a flat list of spans, so
        # results[0] would be one span dict instead of the whole list.
        text = [text]
    results = token_classifier(text)
    return results[0]
def ner_and_mask(text: List[str], labels_to_mask: List[str] = ['PER', 'ORG', 'LOC', 'MISC'], all_masks_same: bool = True) -> List[str]:
ner_output_list = ner(text)
return mask(ner_output_list, labels_to_mask, all_masks_same)
def embed(text: str) -> np.ndarray:
    """
    Encode one masked text into an L2-normalised sentence embedding.

    The generic ``[MASK]`` / ``[SEP]`` placeholders are first rewritten into
    the embedding tokenizer's own mask and separator tokens.

    Returns:
        A 2-D array with one row (the embedding of ``text``), scaled to unit
        L2 norm. An all-zero embedding is returned unchanged rather than
        becoming NaN.
    """
    text = text.replace('[MASK]', embed_mask_tok).replace('[SEP]', embed_sep_tok)
    embedding = embedding_model.encode([text], show_progress_bar=False, batch_size=1)
    norms = np.linalg.norm(embedding, axis=1, keepdims=True)
    # Avoid 0/0 -> NaN; NaN would otherwise be serialized into the retrieval
    # request as non-standard JSON.
    norms[norms == 0] = 1.0
    return embedding / norms
# Seconds to wait for a retrieval VM, so an unreachable server surfaces as an
# error instead of hanging the request indefinitely.
_RETRIEVE_TIMEOUT_S = 60


def _format_retrieved_doc(doc: dict) -> Tuple[str, str, str, str, str]:
    """Extract (newspaper name, location, date, article text, PDF link) from a /retrieve response."""
    article = doc['bboxes'][int(doc['article_id'])]
    lccn = doc['lccn']
    # Place ids are read from doc['lccn'], the same object the title comes from.
    place_ids = lccn.get('dpedia_ids') or []
    location = place_ids[0].replace('%2C_', ', ') if place_ids else 'Unknown'
    # NOTE(review): str.replace swaps every 'jp2' occurrence, not just the file
    # extension — confirm the scan URLs never contain 'jp2' elsewhere.
    pdf_link = doc['scan']['jp2_url'].replace('jp2', 'pdf')
    return lccn['title'], location, doc['scan']['date'], article['raw_text'], pdf_link


def query(sentence: str, state: str, years: List[str]) -> Tuple[str, str, str, str, str]:
    """
    Find the newspaper article most similar to ``sentence`` in one state's index.

    The text is NER-masked, embedded, and posted to the retrieval VM that is
    registered in ``AZURE_VMS`` for ``state``.

    Args:
        sentence: Free text of a news article.
        state: A dropdown label from ``AVAILABLE_STATES``. It must name a
            specific state; searching 'All States' is not implemented.
        years: Selected years from the UI. Currently ignored; no year filter is
            sent to the retrieval service.

    Returns:
        (newspaper name, location, date, article OCR text, PDF link).

    Raises:
        gr.Error: If no state, 'All States', or an unconfigured state is chosen.
        RuntimeError: If the embedding does not have the expected (1, 768) shape.
        requests.HTTPError: If the retrieval service answers with an error status.
    """
    # Validate the target first: without a VM address there is nothing to query.
    if not state or state == 'All States':
        raise gr.Error("Searching all states is not supported yet; please select a single state.")
    vm_address = AZURE_VMS.get(state.upper())
    if vm_address is None:
        raise gr.Error(f"No retrieval server is configured for state '{state}'.")

    masked_text = ner_and_mask([sentence])
    embedding = embed(masked_text)
    if embedding.shape != (1, 768):
        raise RuntimeError(f"Unexpected embedding shape {embedding.shape}; expected (1, 768).")
    req = {"vector": embedding[0].astype(np.float64).tolist(), 'nn': 5}

    response = requests.post(f"http://{vm_address}/retrieve", json=req, timeout=_RETRIEVE_TIMEOUT_S)
    response.raise_for_status()
    return _format_retrieved_doc(response.json())
if __name__ == "__main__":
    # Input widgets, in the same order as query()'s positional parameters.
    query_inputs = [
        gr.Textbox(lines=10, label="News Article"),
        gr.Dropdown(AVAILABLE_STATES, label="States to Search"),
        gr.CheckboxGroup(AVAILABLE_YEARS, label="Years to Search"),
    ]
    # Output widgets, one per element of the 5-tuple query() returns.
    query_outputs = [
        gr.Textbox(label="Newspaper Name"),
        gr.Textbox(label="Location"),
        gr.Textbox(label="Date"),
        gr.Textbox(lines=10, label="Article Text OCR"),
        gr.Textbox(label="PDF Link"),
    ]
    demo = gr.Interface(fn=query, inputs=query_inputs, outputs=query_outputs)
    demo.launch()