import os
from typing import List, Tuple

import gradio as gr
import numpy as np
import requests
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer

# Hugging Face Hub paths for the NER model and the article-embedding model.
NER_MODEL_PATH = 'dell-research-harvard/historical_newspaper_ner'
EMBED_MODEL_PATH = 'dell-research-harvard/same-story'

# Discover per-state retrieval servers from environment variables of the
# form AZURE_VM_<STATE>; the state suffix doubles as the dropdown label.
AZURE_VMS = {}
AVAILABLE_STATES = ['All States']
for k, v in os.environ.items():
    if 'AZURE_VM' in k:
        state = k.split('_')[-1]
        AZURE_VMS[state] = v
        AVAILABLE_STATES.append(state.capitalize())

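# Illustration (hypothetical address): with AZURE_VM_TEXAS=20.1.2.3:8080 in
# the environment, AZURE_VMS would be {'TEXAS': '20.1.2.3:8080'} and the
# states dropdown would gain a 'Texas' entry.
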
AVAILABLE_YEARS = ['All Years']

# `requests` expects headers as a mapping, so wrap the browser user-agent
# string in a dict keyed by header name.
REQUEST_HEADERS = {
    'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Ubuntu Chromium/37.0.2062.94 Chrome/37.0.2062.94 Safari/537.36'
}


def find_sep_token(tokenizer):
    """Return the separator string for the given tokenizer."""

    if 'eos_token' in tokenizer.special_tokens_map:
        sep = " " + tokenizer.special_tokens_map['eos_token'] + " " + tokenizer.special_tokens_map['sep_token'] + " "
    else:
        sep = " " + tokenizer.special_tokens_map['sep_token'] + " "

    return sep


def find_mask_token(tokenizer):
    """Return the mask token for the given tokenizer."""

    return tokenizer.special_tokens_map['mask_token']

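# Illustration (depends on the tokenizer's special_tokens_map): for a
# BERT-style tokenizer, find_sep_token returns " [SEP] " and find_mask_token
# returns "[MASK]"; for a RoBERTa-style tokenizer, which also defines an
# eos_token, find_sep_token returns " </s> </s> ".
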
if gr.NO_RELOAD:
    # Load the models once at startup; gr.NO_RELOAD keeps this block from
    # re-running on every change in `gradio` hot-reload mode.
    ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_PATH)
    ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH, return_tensors="pt",
                                                  max_length=256, truncation=True)
    token_classifier = pipeline(task="ner",
                                model=ner_model, tokenizer=ner_tokenizer,
                                ignore_labels=[], aggregation_strategy='max')

    embedding_tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_PATH)
    embedding_model = SentenceTransformer(EMBED_MODEL_PATH)
    embed_mask_tok = find_mask_token(embedding_tokenizer)
    embed_sep_tok = find_sep_token(embedding_tokenizer)

    # Reusable session for image downloads, identified with the browser-style
    # headers defined above (previously the headers were never attached).
    img_download_session = requests.Session()
    img_download_session.headers.update(REQUEST_HEADERS)


def handle_punctuation_for_generic_mask(word):
    """If punctuation leads the word, keep it before the mask; otherwise keep any trailing punctuation after the mask."""

    if word[0] in [".", ",", "!", "?"]:
        return word[0] + " [MASK]"
    elif word[-1] in [".", ",", "!", "?"]:
        return "[MASK]" + word[-1]
    else:
        return "[MASK]"


def handle_punctuation_for_entity_mask(word, entity_group):
    """Same as handle_punctuation_for_generic_mask, but masks with the specific entity label instead of the generic [MASK]."""

    if word[0] in [".", ",", "!", "?"]:
        return word[0] + " " + entity_group
    elif word[-1] in [".", ",", "!", "?"]:
        return entity_group + word[-1]
    else:
        return entity_group

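# Illustration of the two maskers on a word with trailing punctuation:
#   handle_punctuation_for_generic_mask("Boston,")        -> "[MASK],"
#   handle_punctuation_for_entity_mask("Boston,", "LOC")  -> "LOC,"
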
def replace_words_with_entity_tokens(ner_output_dict: List[dict],
                                     desired_labels: List[str] = ['PER', 'ORG', 'LOC', 'MISC'],
                                     all_masks_same: bool = True) -> str:

    if not all_masks_same:
        new_word_list = [subdict["word"] if subdict["entity_group"] not in desired_labels
                         else handle_punctuation_for_entity_mask(subdict["word"], subdict["entity_group"])
                         for subdict in ner_output_dict]
    else:
        new_word_list = [subdict["word"] if subdict["entity_group"] not in desired_labels
                         else handle_punctuation_for_generic_mask(subdict["word"])
                         for subdict in ner_output_dict]

    return " ".join(new_word_list)


def mask(ner_output_list: List[dict], desired_labels: List[str] = ['PER', 'ORG', 'LOC', 'MISC'],
         all_masks_same: bool = True) -> str:

    return replace_words_with_entity_tokens(ner_output_list, desired_labels, all_masks_same)


def ner(text: List[str]) -> List[dict]:
    # The pipeline returns one entity list per input string; the app embeds a
    # single article at a time, so take the first.
    results = token_classifier(text)
    return results[0]


def ner_and_mask(text: List[str], labels_to_mask: List[str] = ['PER', 'ORG', 'LOC', 'MISC'],
                 all_masks_same: bool = True) -> str:
    ner_output_list = ner(text)

    return mask(ner_output_list, labels_to_mask, all_masks_same)

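# End-to-end illustration (hypothetical sentence and tags): if the NER model
# tags "John Smith" as PER and "Boston" as LOC, then
#   ner_and_mask(["John Smith visited Boston"])
# would return something like "[MASK] visited [MASK]".
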
def embed(text: str) -> np.ndarray:
    # Swap the generic placeholders for this tokenizer's own special tokens.
    text = text.replace('[MASK]', embed_mask_tok)
    text = text.replace('[SEP]', embed_sep_tok)
    data = [text]

    embedding = embedding_model.encode(data, show_progress_bar=False, batch_size=1)
    # L2-normalise so that inner-product search behaves like cosine similarity.
    embedding = embedding / np.linalg.norm(embedding, axis=1, keepdims=True)

    return embedding

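# e.g. embed("[MASK] visited [MASK]") returns a unit-norm array of shape
# (1, 768), matching the assertion in query() below.
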
def query(sentence: str, state: str, years: List[str]) -> Tuple[str, str, str, str, str]:
    # `years` is wired into the UI but not yet used for filtering, since
    # AVAILABLE_YEARS currently only contains 'All Years'.
    mask_results = ner_and_mask([sentence])
    embedding = embed(mask_results)

    assert embedding.shape == (1, 768)
    embedding = embedding[0].astype(np.float64)
    req = {"vector": list(embedding), "nn": 5}

    if state == 'All States':
        # Cross-state search is not implemented: each state's index lives on
        # its own VM, so ask the user to pick one rather than falling through
        # to an undefined vm_address.
        raise gr.Error("Searching all states at once is not yet supported; please select a state.")
    vm_address = AZURE_VMS[state.upper()]

    response = requests.post(f"http://{vm_address}/retrieve", json=req)

    doc = response.json()
    article = doc['bboxes'][int(doc['article_id'])]
    if len(doc['lccn']['dbpedia_ids']) == 0:
        location = 'Unknown'
    else:
        location = doc['lccn']['dbpedia_ids'][0].replace('%2C_', ', ')

    results = {
        'newspaper_name': doc['lccn']['title'],
        'location': location,
        'date': doc['scan']['date'],
        'article_text': article['raw_text'],
        'pdf_link': doc['scan']['jp2_url'].replace('jp2', 'pdf')
    }

    return results['newspaper_name'], results['location'], results['date'], results['article_text'], results['pdf_link']

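# Illustrative call (assumes an AZURE_VM_TEXAS server is configured):
#   query("Some OCR'd article text ...", "Texas", ['All Years'])
# masks named entities, embeds the result, POSTs {"vector": [...], "nn": 5}
# to http://<vm>/retrieve, and returns the newspaper name, location, date,
# OCR text, and PDF link of the nearest match.
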
if __name__ == "__main__":
    demo = gr.Interface(
        fn=query,
        inputs=[
            gr.Textbox(lines=10, label="News Article"),
            gr.Dropdown(AVAILABLE_STATES, label="States to Search"),
            gr.CheckboxGroup(AVAILABLE_YEARS, label="Years to Search")
        ],
        outputs=[
            gr.Textbox(label="Newspaper Name"),
            gr.Textbox(label="Location"),
            gr.Textbox(label="Date"),
            gr.Textbox(lines=10, label="Article Text OCR"),
            gr.Textbox(label="PDF Link")
        ]
    )

    demo.launch()