tombryan committed
Commit 207077e • 1 Parent(s): df8bdb6

Initial commit for ndjv app

Files changed (4)
  1. README.md +2 -2
  2. __pycache__/app.cpython-38.pyc +0 -0
  3. app.py +136 -4
  4. requirements.txt +3 -0
README.md CHANGED
@@ -3,10 +3,10 @@ title: Newsdejavu
  emoji: 🌖
  colorFrom: yellow
  colorTo: green
- sdk: streamlit
+ sdk: gradio
  sdk_version: 1.33.0
  app_file: app.py
- pinned: false
+ pinned: false
  license: apache-2.0
  ---

__pycache__/app.cpython-38.pyc ADDED
Binary file (2.71 kB)
 
app.py CHANGED
@@ -1,6 +1,138 @@
- # Just trying out a comment to test commits
-
- import streamlit as st
-
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ import gradio as gr
+ import numpy as np
+ import os
+ import requests
+ from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
+ from sentence_transformers import SentenceTransformer
+
+ from typing import List
+
+ NER_MODEL_PATH = 'dell-research-harvard/historical_newspaper_ner'
+ EMBED_MODEL_PATH = 'dell-research-harvard/same-story'
+ AZURE_VM_ALABAMA = os.environ.get('AZURE_VM_ALABAMA')
+
+
+ def find_sep_token(tokenizer):
+     """Return the separator token (preceded by the EOS token, if one exists) for the given tokenizer."""
+     if 'eos_token' in tokenizer.special_tokens_map:
+         sep = " " + tokenizer.special_tokens_map['eos_token'] + " " + tokenizer.special_tokens_map['sep_token'] + " "
+     else:
+         sep = " " + tokenizer.special_tokens_map['sep_token'] + " "
+     return sep
+
+
+ def find_mask_token(tokenizer):
+     """Return the mask token for the given tokenizer."""
+     return tokenizer.special_tokens_map['mask_token']
+
+
+ if gr.NO_RELOAD:
+     # Load models once at startup; gr.NO_RELOAD skips this block on hot reloads.
+     ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_PATH)
+     ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH, return_tensors="pt",
+                                                   max_length=256, truncation=True)
+     token_classifier = pipeline(task="ner",
+                                 model=ner_model, tokenizer=ner_tokenizer,
+                                 ignore_labels=[], aggregation_strategy='max')
+
+     embedding_tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_PATH)
+     embedding_model = SentenceTransformer(EMBED_MODEL_PATH)
+     embed_mask_tok = find_mask_token(embedding_tokenizer)
+     embed_sep_tok = find_sep_token(embedding_tokenizer)
+
+ # with open(REF_INDEX_PATH, 'r') as f:
+ #     news_paths = [l.strip() for l in f.readlines()]
+
+
+ def handle_punctuation_for_generic_mask(word):
+     """If punctuation comes before the word, return it before the mask; otherwise return it after the mask."""
+     if word[0] in [".", ",", "!", "?"]:
+         return word[0] + " [MASK]"
+     elif word[-1] in [".", ",", "!", "?"]:
+         return "[MASK]" + word[-1]
+     else:
+         return "[MASK]"
+
+
+ def handle_punctuation_for_entity_mask(word, entity_group):
+     """Same as handle_punctuation_for_generic_mask, but masks with the specific entity label."""
+     if word[0] in [".", ",", "!", "?"]:
+         return word[0] + " " + entity_group
+     elif word[-1] in [".", ",", "!", "?"]:
+         return entity_group + word[-1]
+     else:
+         return entity_group
+
+
+ def replace_words_with_entity_tokens(ner_output_dict: List[dict],
+                                      desired_labels: List[str] = ['PER', 'ORG', 'LOC', 'MISC'],
+                                      all_masks_same: bool = True) -> str:
+     if not all_masks_same:
+         new_word_list = [subdict["word"] if subdict["entity_group"] not in desired_labels
+                          else handle_punctuation_for_entity_mask(subdict["word"], subdict["entity_group"])
+                          for subdict in ner_output_dict]
+     else:
+         new_word_list = [subdict["word"] if subdict["entity_group"] not in desired_labels
+                          else handle_punctuation_for_generic_mask(subdict["word"])
+                          for subdict in ner_output_dict]
+     return " ".join(new_word_list)
+
+
+ def mask(ner_output_list: List[dict], desired_labels: List[str] = ['PER', 'ORG', 'LOC', 'MISC'],
+          all_masks_same: bool = True) -> str:
+     return replace_words_with_entity_tokens(ner_output_list, desired_labels, all_masks_same)
+
+
+ def ner(text: List[str]) -> List[dict]:
+     # The pipeline returns one result list per input text; we process a single text.
+     results = token_classifier(text)
+     return results[0]
+
+
+ def ner_and_mask(text: List[str], labels_to_mask: List[str] = ['PER', 'ORG', 'LOC', 'MISC'], all_masks_same: bool = True) -> str:
+     ner_output_list = ner(text)
+     return mask(ner_output_list, labels_to_mask, all_masks_same)
+
+
+ def embed(text: str) -> np.ndarray:
+     data = []
+     # Swap in the correct [MASK] / [SEP] tokens for the embedding tokenizer
+     text = text.replace('[MASK]', embed_mask_tok)
+     text = text.replace('[SEP]', embed_sep_tok)
+     data.append(text)
+
+     embedding = embedding_model.encode(data, show_progress_bar=False, batch_size=1)
+     # L2-normalise each row so dot products equal cosine similarity
+     embedding = embedding / np.linalg.norm(embedding, axis=1, keepdims=True)
+     return embedding
+
+
+ def query(sentence: str) -> str:
+     mask_results = ner_and_mask([sentence])
+     embedding = embed(mask_results)
+
+     assert embedding.shape == (1, 768)
+     embedding = embedding[0].astype(np.float64)
+     req = {"vector": list(embedding), 'nn': 5}
+
+     # Send embedding to the retrieval service on the Azure VM
+     response = requests.post(f"http://{AZURE_VM_ALABAMA}/retrieve", json=req)
+     doc = response.json()
+     article = doc['bboxes'][doc['article_id']]
+     return article['raw_text']
+
+
+ if __name__ == "__main__":
+     demo = gr.Interface(
+         fn=query,
+         inputs=["text"],
+         outputs=["text"],
+     )
+     demo.launch()
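
Note: query() assumes a retrieval service on the Azure VM that accepts a JSON body of the form {"vector": [...768 floats...], "nn": 5} at /retrieve and answers with a document carrying bboxes and article_id. That service is not part of this commit. The sketch below stubs the contract with Flask purely for local testing; the field names are taken from query() above, while the stub itself, its host/port, and the sample payload are hypothetical.

# stub_retrieve.py -- hypothetical local stand-in for the /retrieve endpoint.
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route("/retrieve", methods=["POST"])
def retrieve():
    req = request.get_json()
    assert len(req["vector"]) == 768   # query() sends a 768-dim unit vector
    _nn = req["nn"]                    # number of neighbours requested (5 in app.py)
    # A real service would run nearest-neighbour search here. We return one
    # fixed document shaped the way query() reads it:
    #     doc['bboxes'][doc['article_id']]['raw_text']
    return jsonify({
        "article_id": 0,
        "bboxes": [{"raw_text": "Placeholder article text."}],
    })

if __name__ == "__main__":
    app.run(host="127.0.0.1", port=8080)

With AZURE_VM_ALABAMA=127.0.0.1:8080 in the environment, query() round-trips against this stub.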
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ transformers
+ sentence-transformers
+ numpy
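
As a quick end-to-end sanity check of the masking and embedding steps, one might run something like the snippet below. It is illustrative only: the exact masked spans depend on what dell-research-harvard/historical_newspaper_ner tags, and importing app triggers the model loading in the gr.NO_RELOAD block.

# sanity_check.py -- illustrative; downloads the models on first run.
from app import ner_and_mask, embed

masked = ner_and_mask(["John Smith spoke in Selma yesterday."])
print(masked)   # e.g. "[MASK] [MASK] spoke in [MASK] yesterday ." (hypothetical)

vec = embed(masked)
print(vec.shape)   # expected (1, 768), matching the assert in query()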