Cache is not working.
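The title refers to Streamlit's `st.cache`: it hashes a cached function's arguments and return value, and it cannot hash objects like torch parameters or tokenizers, so the model loaders in worker.py below pass `hash_funcs` that skip those types, plus `allow_output_mutation=True`. A minimal standalone sketch of that pattern (the `load_classifier` name is illustrative, not part of this commit):

import streamlit as st
import tokenizers
import torch
from transformers import AutoModelForSequenceClassification

@st.cache(
    hash_funcs={
        # Tell Streamlit to skip hashing these unhashable types.
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None,
    },
    allow_output_mutation=True,  # the returned model object is mutated at inference time
)
def load_classifier(model_dir):  # illustrative standalone version of init_worker_model
    return AutoModelForSequenceClassification.from_pretrained(model_dir)

One plausible reason the cache still misses in worker.py is that `st.cache` is applied to instance methods, so `self` is hashed as an argument as well; that is a guess, not something this commit states.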
Files changed:
- .DS_Store +0 -0
- app.py +16 -2
- models/.DS_Store +0 -0
- models/roberta-large/.DS_Store +0 -0
- models/roberta-large/config.json +40 -0
- models/roberta-large/merges.txt +0 -0
- models/roberta-large/pytorch_model.bin +3 -0
- models/roberta-large/runs/.DS_Store +0 -0
- models/roberta-large/runs/May10_08-20-25_aa60e833fd05/.DS_Store +0 -0
- models/roberta-large/runs/May10_08-20-25_aa60e833fd05/1652170830.6680446/events.out.tfevents.1652170830.aa60e833fd05.33.1 +3 -0
- models/roberta-large/runs/May10_08-20-25_aa60e833fd05/events.out.tfevents.1652170830.aa60e833fd05.33.0 +3 -0
- models/roberta-large/special_tokens_map.json +1 -0
- models/roberta-large/tokenizer.json +0 -0
- models/roberta-large/tokenizer_config.json +1 -0
- models/roberta-large/training_args.bin +3 -0
- models/roberta-large/vocab.json +0 -0
- requirements.txt +0 -104
- worker.py +134 -0
.DS_Store
ADDED
Binary file (8.2 kB)
app.py
CHANGED
@@ -1,4 +1,18 @@
 import streamlit as st
+from worker import WorkerClassifier
 
-
-
+MODEL_DIR = "./models/roberta-large"
+worker_clf = WorkerClassifier(MODEL_DIR)
+worker_clf.init_models()
+
+text = st.text_input(
+    "Worker Profile Description", "This candidate is a very warm and kind..."
+)
+
+proc_input, output = worker_clf.predict(text)
+
+st.write(f"**Text used to classify worker profile:**")
+st.write(proc_input)
+st.write("**Predicted Worker Profile:**")
+for i, o in zip(proc_input, output):
+    st.write(o[0])
models/.DS_Store
ADDED
Binary file (6.15 kB)
models/roberta-large/.DS_Store
ADDED
Binary file (6.15 kB)
models/roberta-large/config.json
ADDED
@@ -0,0 +1,40 @@
+{
+  "_name_or_path": "roberta-large",
+  "architectures": [
+    "RobertaForSequenceClassification"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "bos_token_id": 0,
+  "classifier_dropout": null,
+  "eos_token_id": 2,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 1024,
+  "id2label": {
+    "0": "lauren",
+    "1": "betty",
+    "2": "doris",
+    "3": "hailey"
+  },
+  "initializer_range": 0.02,
+  "intermediate_size": 4096,
+  "label2id": {
+    "betty": 1,
+    "doris": 2,
+    "hailey": 3,
+    "lauren": 0
+  },
+  "layer_norm_eps": 1e-05,
+  "max_position_embeddings": 514,
+  "model_type": "roberta",
+  "num_attention_heads": 16,
+  "num_hidden_layers": 24,
+  "pad_token_id": 1,
+  "position_embedding_type": "absolute",
+  "problem_type": "multi_label_classification",
+  "torch_dtype": "float32",
+  "transformers_version": "4.18.0",
+  "type_vocab_size": 1,
+  "use_cache": true,
+  "vocab_size": 50265
+}
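For reference, this config pins problem_type to multi_label_classification and carries the four-way label mapping; a quick sketch (assuming the checkpoint lives at ./models/roberta-large, as in app.py) of loading it and inspecting those fields:

from transformers import AutoModelForSequenceClassification

# from_pretrained picks these fields up from config.json automatically.
model = AutoModelForSequenceClassification.from_pretrained("./models/roberta-large")
print(model.config.problem_type)  # multi_label_classification
print(model.config.id2label)      # {0: 'lauren', 1: 'betty', 2: 'doris', 3: 'hailey'}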
models/roberta-large/merges.txt
ADDED
The diff for this file is too large to render.
models/roberta-large/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b975f3dc57dbc684675ad653ee20c79a9f27a099be183d81a710d62ea3c98e35
+size 1421592557
models/roberta-large/runs/.DS_Store
ADDED
Binary file (6.15 kB)
models/roberta-large/runs/May10_08-20-25_aa60e833fd05/.DS_Store
ADDED
Binary file (6.15 kB)
models/roberta-large/runs/May10_08-20-25_aa60e833fd05/1652170830.6680446/events.out.tfevents.1652170830.aa60e833fd05.33.1
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec8e28066ba7d13217a65b823ec6480a476684a1896182b144e8644e7b9315dc
+size 4805
models/roberta-large/runs/May10_08-20-25_aa60e833fd05/events.out.tfevents.1652170830.aa60e833fd05.33.0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c19271afb20b228c47c6d9eb82445f5f9fb624e837fff6218a9a24bc1c01a6e
+size 7723
models/roberta-large/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
+{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
models/roberta-large/tokenizer.json
ADDED
The diff for this file is too large to render.
models/roberta-large/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
+{"errors": "replace", "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": "<mask>", "add_prefix_space": false, "trim_offsets": true, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "roberta-large", "tokenizer_class": "RobertaTokenizer"}
models/roberta-large/training_args.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e02403a15cc97a74296579258c269102ed5b7ef6097019ab4dee44236402d973
+size 3055
models/roberta-large/vocab.json
ADDED
The diff for this file is too large to render.
requirements.txt
CHANGED
@@ -1,104 +0,0 @@
-altair==4.2.0
-appnope==0.1.3
-argon2-cffi==21.3.0
-argon2-cffi-bindings==21.2.0
-asttokens==2.0.5
-attrs==21.4.0
-backcall==0.2.0
-backports.zoneinfo==0.2.1
-beautifulsoup4==4.11.1
-bleach==5.0.0
-blinker==1.4
-cachetools==5.0.0
-certifi==2021.10.8
-cffi==1.15.0
-charset-normalizer==2.0.12
-click==8.0.4
-debugpy==1.6.0
-decorator==5.1.1
-defusedxml==0.7.1
-entrypoints==0.4
-executing==0.8.3
-fastjsonschema==2.15.3
-filelock==3.6.0
-gitdb==4.0.9
-GitPython==3.1.27
-huggingface-hub==0.5.1
-idna==3.3
-importlib-metadata==4.11.3
-importlib-resources==5.7.1
-ipykernel==6.13.0
-ipython==8.3.0
-ipython-genutils==0.2.0
-ipywidgets==7.7.0
-jedi==0.18.1
-Jinja2==3.1.2
-joblib==1.1.0
-jsonschema==4.5.1
-jupyter-client==7.3.1
-jupyter-core==4.10.0
-jupyterlab-pygments==0.2.2
-jupyterlab-widgets==1.1.0
-MarkupSafe==2.1.1
-matplotlib-inline==0.1.3
-mistune==0.8.4
-nbclient==0.6.3
-nbconvert==6.5.0
-nbformat==5.4.0
-nest-asyncio==1.5.5
-notebook==6.4.11
-numpy==1.22.3
-packaging==21.3
-pandas==1.4.2
-pandocfilters==1.5.0
-parso==0.8.3
-pexpect==4.8.0
-pickleshare==0.7.5
-Pillow==9.1.0
-prometheus-client==0.14.1
-prompt-toolkit==3.0.29
-protobuf==3.20.1
-psutil==5.9.0
-ptyprocess==0.7.0
-pure-eval==0.2.2
-pyarrow==8.0.0
-pycparser==2.21
-pydeck==0.7.1
-Pygments==2.12.0
-Pympler==1.0.1
-pyparsing==3.0.8
-pyrsistent==0.18.1
-python-dateutil==2.8.2
-pytz==2022.1
-pytz-deprecation-shim==0.1.0.post0
-PyYAML==6.0
-pyzmq==22.3.0
-regex==2022.4.24
-requests==2.27.1
-sacremoses==0.0.53
-semver==2.13.0
-Send2Trash==1.8.0
-six==1.16.0
-smmap==5.0.0
-soupsieve==2.3.2.post1
-stack-data==0.2.0
-streamlit==1.9.0
-terminado==0.13.3
-tinycss2==1.1.1
-tokenizers==0.12.1
-toml==0.10.2
-toolz==0.11.2
-torch==1.11.0
-tornado==6.1
-tqdm==4.64.0
-traitlets==5.1.1
-transformers==4.18.0
-typing_extensions==4.2.0
-tzdata==2022.1
-tzlocal==4.2
-urllib3==1.26.9
-validators==0.19.0
-wcwidth==0.2.5
-webencodings==0.5.1
-widgetsnbextension==3.6.0
-zipp==3.8.0
worker.py
ADDED
@@ -0,0 +1,134 @@
+import streamlit as st
+import tokenizers
+from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
+import numpy as np
+import torch
+from nltk.tokenize import sent_tokenize
+
+
+class WorkerClassifier:
+    def __init__(
+        self, worker_model_dir, zero_shot_model_type="facebook/bart-large-mnli"
+    ):
+        self.zero_shot = None
+        self.zero_shot_model_type = zero_shot_model_type
+        self.worker_model_dir = worker_model_dir
+        self.id2label = {
+            0: "lauren",
+            1: "betty",
+            2: "doris",
+            3: "hailey",
+        }
+        self.label2id = {v: k for k, v in self.id2label.items()}
+
+    def init_models(self):
+        self.ner = self.init_anonymizer()
+        self.zero_shot = self.init_zero_shot()
+        self.worker_model = self.init_worker_model()
+        self.worker_tokenizer = self.init_worker_tokenizer()
+
+    @st.cache(
+        hash_funcs={
+            torch.nn.parameter.Parameter: lambda _: None,
+            tokenizers.Tokenizer: lambda _: None,
+            tokenizers.AddedToken: lambda _: None,
+        },
+        allow_output_mutation=True,
+    )
+    def init_worker_tokenizer(self):
+        return AutoTokenizer.from_pretrained(self.worker_model_dir)
+
+    @st.cache(
+        hash_funcs={
+            torch.nn.parameter.Parameter: lambda _: None,
+            tokenizers.Tokenizer: lambda _: None,
+            tokenizers.AddedToken: lambda _: None,
+        },
+        allow_output_mutation=True,
+    )
+    def init_worker_model(self):
+        return AutoModelForSequenceClassification.from_pretrained(
+            self.worker_model_dir, problem_type="multi_label_classification"
+        )
+
+    def predict_worker(self, text, threshold=0.5):
+        encoding = self.worker_tokenizer(text, return_tensors="pt")
+        outputs = self.worker_model(**encoding)
+
+        logits = outputs["logits"]
+        # apply sigmoid + threshold
+        sigmoid = torch.nn.Sigmoid()
+        probs = sigmoid(logits.squeeze().cpu())
+        predictions = np.zeros(probs.shape)
+        predictions[np.where(probs >= threshold)] = 1
+        # turn predicted id's into actual label names
+        predicted_labels = [
+            [self.id2label[idx], probs[idx].detach().item()]
+            for idx, label in enumerate(predictions)
+            if label == 1.0
+        ]
+        return predicted_labels
+
+    @st.cache(allow_output_mutation=True)
+    def init_anonymizer(self):
+        return pipeline(task="ner")
+
+    def anonymize(self, text: str):
+        new_sentences = []
+        sentences = sent_tokenize(text)
+        for sent in sentences:
+            result = self.ner(sent, aggregation_strategy="simple")
+            for r in reversed(result):
+                if r["entity_group"] == "PER":
+                    sent = sent[: r["start"]] + "PERSON" + sent[r["end"] :]
+            new_sentences.append(sent)
+
+        return " ".join(new_sentences)
+
+    @st.cache(
+        hash_funcs={
+            tokenizers.Tokenizer: lambda _: None,
+            tokenizers.AddedToken: lambda _: None,
+            torch.nn.parameter.Parameter: lambda parameter: parameter.data.numpy(),
+        },
+        allow_output_mutation=True,
+    )
+    def init_zero_shot(self):
+        return pipeline(
+            task="zero-shot-classification", model=self.zero_shot_model_type
+        )
+
+    def get_personality_sentences(self, text):
+        new_sentences = []
+        sentences = sent_tokenize(text)
+
+        for sent in sentences:
+            if self.personality_sent_classifier(sent):
+                new_sentences.append(sent)
+        return " ".join(new_sentences)
+
+    def personality_sent_classifier(self, text, threshold=0.8):
+        candidate_labels = ["describing a personality trait."]
+        hypothesis_template = "This example is {}"
+
+        output = self.zero_shot(
+            text,
+            candidate_labels=candidate_labels,
+            hypothesis_template=hypothesis_template,
+        )
+        # print(f'{text} with score {output["scores"][0]}\n')
+        if output["scores"][0] > threshold:
+            return True
+        return False
+
+    def predict(self, text):
+        # first extract sentences that are relevant to personalities
+        text = self.get_personality_sentences(text)
+        extracted_text = text
+
+        # next anonymize the sentences
+        text = self.anonymize(text)
+
+        # classify text
+        text = self.predict_worker(text)
+        return extracted_text, text
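Putting the new pieces together, a minimal sketch of the worker.py pipeline outside Streamlit (assumes NLTK's punkt data is available, which sent_tokenize requires; the example text and printed values are illustrative):

import nltk
nltk.download("punkt")  # sent_tokenize in worker.py needs this tokenizer data

from worker import WorkerClassifier

clf = WorkerClassifier("./models/roberta-large")
clf.init_models()  # loads the NER, zero-shot, and fine-tuned RoBERTa models

extracted, labels = clf.predict(
    "Lauren is a very warm and kind person. She has ten years of experience."
)
print(extracted)  # sentences the zero-shot filter kept as personality-related
print(labels)     # e.g. [['lauren', 0.91]] -- [label, probability] pairs over the 0.5 threshold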