cswamy committed on
Commit
84c08ac
1 Parent(s): 98656aa

first commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ bertbasecased_finetuned_conll.pth filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import re
4
+ import torch
5
+
6
+ from model import create_bertcased_ner
7
+ from timeit import default_timer as timer
8
+ from typing import Tuple, Dict
9
+
10
# Load the entity label vocabulary: one label per line of class_names.txt
with open("class_names.txt", "r") as labels_file:
    class_names = [raw_line.strip() for raw_line in labels_file]

# Build the token-classification model together with its tokenizer
model, tokenizer = create_bertcased_ner(class_names)

# Restore the saved weights from the .pth file, mapping every tensor onto the CPU
model.load_state_dict(
    torch.load(f="bertbasecased_finetuned_conll.pth", map_location=torch.device("cpu"))
)
23
+
24
+ # Predict function
25
def predict(new_text:str,
            model:torch.nn.Module,
            tokenizer,
            device:torch.device) -> Tuple[list, float]:
    """
    Function for named entity recognition on new text.
    Args:
        new_text(str): A new sentence to classify entities on.
        model(torch.nn.Module): Trained pytorch model for NER.
        tokenizer: tokenizer for the model.
        device(torch.device): Device setting
    Returns:
        Tuple of (results_list, pred_time). results_list is a list of dicts
        with keys "word" (the entity text, punctuation removed, consecutive
        words of one entity joined by a space) and "pred" (one of "Person",
        "Organisation", "Location", "Miscellaneous"). pred_time is the
        elapsed time in seconds rounded to 5 decimals.
    """
    # Start timer
    start_time = timer()

    # split() with no argument collapses runs of whitespace, so repeated
    # spaces, tabs or newlines never produce empty "words".
    new_text_tokens = new_text.split()
    if not new_text_tokens:
        return [], round(timer() - start_time, 5)

    # truncation=True keeps over-long inputs within the model's maximum length
    tokenized_sample = tokenizer(new_text_tokens,
                                 is_split_into_words=True,
                                 truncation=True)
    input_to_model = {k: torch.tensor(v).unsqueeze(dim=0).to(device)
                      for k, v in tokenized_sample.items()}

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        outputs = model(**input_to_model)
    preds_list = torch.argmax(outputs.logits, dim=-1).squeeze(dim=0).tolist()

    # Remove CLS and SEP positions from both lists
    word_ids = tokenized_sample.word_ids()[1:-1]
    preds_list = preds_list[1:-1]

    # Walk the sub-word tokens once. Class ids follow class_names.txt:
    # 0 is the non-entity class, odd ids are B- (begin) tags and even ids
    # are I- (inside) tags.
    results_list = []
    current_word = None       # index of the last word already handled
    last_entity_word = None   # index of the last word added to results_list[-1]
    for word, pred in zip(word_ids, preds_list):
        # Skip non-entities and the remaining sub-word pieces of a word
        # that has already been handled.
        if pred == 0 or word is None or word == current_word:
            continue
        current_word = word

        cleaned = re.sub(r'[^\w\s]', '', new_text_tokens[word])
        # An I- tag extends the previous entity only when it directly
        # follows it; otherwise (including a stray leading I- tag) it
        # starts a new entity instead of clobbering an unrelated one.
        continues_entity = pred % 2 == 0 and last_entity_word == word - 1
        if continues_entity:
            results_list[-1]["word"] += ' ' + cleaned
        else:
            results_list.append({"word": cleaned, "pred": pred})
        last_entity_word = word

    # Finally convert predictions to entity categories
    # Person, Organization, Location and Miscellaneous
    # (ids 1-2, 3-4, 5-6 and everything above, respectively)
    categories = ("Person", "Organisation", "Location", "Miscellaneous")
    for entity in results_list:
        entity["pred"] = categories[min((entity["pred"] - 1) // 2,
                                        len(categories) - 1)]

    # Calculate prediction time
    pred_time = round(timer() - start_time, 5)

    return results_list, pred_time
94
+
95
+ # Create custom display function
96
def display(results_list):
    """
    Render NER results as an HTML table.
    Args:
        results_list: Iterable of dicts, each with a "word" and a "pred" key.
    Returns:
        String with a two-column (Word / Prediction) HTML table, one row
        per item. Cell contents are HTML-escaped.
    """
    # Local import keeps this function self-contained
    from html import escape

    # Words originate from user-typed text, so escape them before they are
    # embedded in markup.
    rows = "".join(
        f"<tr><td>{escape(str(item['word']))}</td>"
        f"<td>{escape(str(item['pred']))}</td></tr>"
        for item in results_list
    )

    return f"<table><tr><th>Word</th><th>Prediction</th></tr>{rows}</table>"
106
+
107
# Create examples list
examples_list = ["Barack Obama was the 44th President of the United States.",
                 "Islington is a borough in the city of London.",
                 "United Nations is headquartered in New York City."]

# Create Gradio app
title = "Named Entity Recognition 🔎"
description = "Bert finetuned model for named entity recognition!"
article = "Finetuned on the conll2003 dataset"


def predict_for_gradio(new_text):
    """
    Adapter between the Gradio interface and predict/display.

    Gradio calls the interface function with one value per input component
    (here a single textbox), while predict also needs the model, tokenizer
    and device. The results list is rendered through display so the HTML
    output component receives markup rather than a list of dicts.
    Args:
        new_text: Sentence typed into the textbox.
    Returns:
        Tuple of (HTML table string, prediction time in seconds).
    """
    results_list, pred_time = predict(new_text, model, tokenizer, torch.device("cpu"))
    return display(results_list), pred_time


demo = gr.Interface(fn=predict_for_gradio,
                    inputs=gr.Textbox(placeholder="Enter sentence here..."),
                    outputs=[gr.HTML(),
                             gr.Number(label="Prediction_time (sec)")],
                    examples=examples_list,
                    title=title,
                    description=description,
                    article=article)

# Launch gradio
demo.launch()
bertbasecased_finetuned_conll.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d0bcbad40b0b3f525fccddac6366579b3e80e1565ed799ccea8b7556e9f8c1ca
3
+ size 430991289
class_names.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ O
2
+ B-PER
3
+ I-PER
4
+ B-ORG
5
+ I-ORG
6
+ B-LOC
7
+ I-LOC
8
+ B-MISC
9
+ I-MISC
model.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
2
+
3
def create_bertcased_ner(class_names):
    """
    Build a token-classification model and tokenizer from the
    "bert-base-cased" checkpoint.
    Args:
        class_names: List of label names; a label's position in the list
            becomes its class id.
    Returns:
        Tuple of (model, tokenizer).
    """
    checkpoint = "bert-base-cased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Two-way mapping between class ids and label names
    id2label = dict(enumerate(class_names))
    label2id = {name: idx for idx, name in enumerate(class_names)}

    # Load the checkpoint with the label mappings attached
    model = AutoModelForTokenClassification.from_pretrained(
        checkpoint,
        id2label=id2label,
        label2id=label2id,
    )

    return model, tokenizer
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
torch==1.12.0
gradio==3.1.4
# model.py imports transformers; pinned to a release contemporary with torch 1.12 (TODO confirm the version used for fine-tuning)
transformers==4.21.0