joshcx committed
Commit e266c02
1 parent: 1bc534a

Cache is not working.
.DS_Store ADDED
Binary file (8.2 kB)
 
app.py CHANGED
@@ -1,4 +1,18 @@
 import streamlit as st
+from worker import WorkerClassifier
 
-x = st.slider("Select a value")
-st.write(x, "squared is", x * x)
+MODEL_DIR = "./models/roberta-large"
+worker_clf = WorkerClassifier(MODEL_DIR)
+worker_clf.init_models()
+
+text = st.text_input(
+    "Worker Profile Description", "This candidate is a very warm and kind..."
+)
+
+proc_input, output = worker_clf.predict(text)
+
+st.write("**Text used to classify worker profile:**")
+st.write(proc_input)
+st.write("**Predicted Worker Profile:**")
+for o in output:
+    st.write(o[0])
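
Note that `WorkerClassifier(MODEL_DIR)` and `init_models()` run at the top of the script, so they execute on every Streamlit rerun, and the `st.cache` decorators inside the class must hash the fresh instance each time. A minimal alternative sketch, assuming a Streamlit release that ships `st.experimental_singleton` (the `load_classifier` helper is illustrative, not part of this commit):

# Sketch only: cache one fully initialized classifier per process instead of
# caching each loader method. load_classifier is a hypothetical helper.
import streamlit as st

from worker import WorkerClassifier

MODEL_DIR = "./models/roberta-large"


@st.experimental_singleton
def load_classifier(model_dir):
    clf = WorkerClassifier(model_dir)
    clf.init_models()  # heavy model loading runs once; the instance is reused
    return clf


worker_clf = load_classifier(MODEL_DIR)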
models/.DS_Store ADDED
Binary file (6.15 kB)

models/roberta-large/.DS_Store ADDED
Binary file (6.15 kB)
 
models/roberta-large/config.json ADDED
@@ -0,0 +1,40 @@
+{
+  "_name_or_path": "roberta-large",
+  "architectures": [
+    "RobertaForSequenceClassification"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "bos_token_id": 0,
+  "classifier_dropout": null,
+  "eos_token_id": 2,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 1024,
+  "id2label": {
+    "0": "lauren",
+    "1": "betty",
+    "2": "doris",
+    "3": "hailey"
+  },
+  "initializer_range": 0.02,
+  "intermediate_size": 4096,
+  "label2id": {
+    "betty": 1,
+    "doris": 2,
+    "hailey": 3,
+    "lauren": 0
+  },
+  "layer_norm_eps": 1e-05,
+  "max_position_embeddings": 514,
+  "model_type": "roberta",
+  "num_attention_heads": 16,
+  "num_hidden_layers": 24,
+  "pad_token_id": 1,
+  "position_embedding_type": "absolute",
+  "problem_type": "multi_label_classification",
+  "torch_dtype": "float32",
+  "transformers_version": "4.18.0",
+  "type_vocab_size": 1,
+  "use_cache": true,
+  "vocab_size": 50265
+}
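
For context, `"problem_type": "multi_label_classification"` together with the four-entry `id2label` map means this checkpoint scores each profile independently (one sigmoid per label) rather than picking exactly one class. A minimal decoding sketch under that assumption (the input text and the 0.5 threshold are illustrative):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_dir = "./models/roberta-large"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir)

inputs = tokenizer("This candidate is warm and kind.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape (1, 4): one logit per profile

probs = torch.sigmoid(logits.squeeze())  # independent probability per label
predicted = [model.config.id2label[i] for i, p in enumerate(probs) if p >= 0.5]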
models/roberta-large/merges.txt ADDED
The diff for this file is too large to render.
 
models/roberta-large/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b975f3dc57dbc684675ad653ee20c79a9f27a099be183d81a710d62ea3c98e35
+size 1421592557
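
These three lines are a Git LFS pointer, not the weights themselves: the actual ~1.4 GB binary lives in LFS storage, identified by the SHA-256 object ID and byte size above. A small sketch of reading such a pointer (the call at the bottom is illustrative):

# Parse a Git LFS pointer file into its key/value fields (sketch).
def read_lfs_pointer(path):
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

# read_lfs_pointer("models/roberta-large/pytorch_model.bin")
# -> {"version": "https://git-lfs.github.com/spec/v1", "oid": "sha256:b975...", "size": "1421592557"}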
models/roberta-large/runs/.DS_Store ADDED
Binary file (6.15 kB)

models/roberta-large/runs/May10_08-20-25_aa60e833fd05/.DS_Store ADDED
Binary file (6.15 kB)
 
models/roberta-large/runs/May10_08-20-25_aa60e833fd05/1652170830.6680446/events.out.tfevents.1652170830.aa60e833fd05.33.1 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec8e28066ba7d13217a65b823ec6480a476684a1896182b144e8644e7b9315dc
+size 4805
models/roberta-large/runs/May10_08-20-25_aa60e833fd05/events.out.tfevents.1652170830.aa60e833fd05.33.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c19271afb20b228c47c6d9eb82445f5f9fb624e837fff6218a9a24bc1c01a6e
+size 7723
models/roberta-large/special_tokens_map.json ADDED
@@ -0,0 +1 @@
+{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
models/roberta-large/tokenizer.json ADDED
The diff for this file is too large to render.
 
models/roberta-large/tokenizer_config.json ADDED
@@ -0,0 +1 @@
+{"errors": "replace", "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": "<mask>", "add_prefix_space": false, "trim_offsets": true, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "roberta-large", "tokenizer_class": "RobertaTokenizer"}
models/roberta-large/training_args.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e02403a15cc97a74296579258c269102ed5b7ef6097019ab4dee44236402d973
+size 3055
models/roberta-large/vocab.json ADDED
The diff for this file is too large to render.
 
requirements.txt CHANGED
@@ -1,104 +0,0 @@
-altair==4.2.0
-appnope==0.1.3
-argon2-cffi==21.3.0
-argon2-cffi-bindings==21.2.0
-asttokens==2.0.5
-attrs==21.4.0
-backcall==0.2.0
-backports.zoneinfo==0.2.1
-beautifulsoup4==4.11.1
-bleach==5.0.0
-blinker==1.4
-cachetools==5.0.0
-certifi==2021.10.8
-cffi==1.15.0
-charset-normalizer==2.0.12
-click==8.0.4
-debugpy==1.6.0
-decorator==5.1.1
-defusedxml==0.7.1
-entrypoints==0.4
-executing==0.8.3
-fastjsonschema==2.15.3
-filelock==3.6.0
-gitdb==4.0.9
-GitPython==3.1.27
-huggingface-hub==0.5.1
-idna==3.3
-importlib-metadata==4.11.3
-importlib-resources==5.7.1
-ipykernel==6.13.0
-ipython==8.3.0
-ipython-genutils==0.2.0
-ipywidgets==7.7.0
-jedi==0.18.1
-Jinja2==3.1.2
-joblib==1.1.0
-jsonschema==4.5.1
-jupyter-client==7.3.1
-jupyter-core==4.10.0
-jupyterlab-pygments==0.2.2
-jupyterlab-widgets==1.1.0
-MarkupSafe==2.1.1
-matplotlib-inline==0.1.3
-mistune==0.8.4
-nbclient==0.6.3
-nbconvert==6.5.0
-nbformat==5.4.0
-nest-asyncio==1.5.5
-notebook==6.4.11
-numpy==1.22.3
-packaging==21.3
-pandas==1.4.2
-pandocfilters==1.5.0
-parso==0.8.3
-pexpect==4.8.0
-pickleshare==0.7.5
-Pillow==9.1.0
-prometheus-client==0.14.1
-prompt-toolkit==3.0.29
-protobuf==3.20.1
-psutil==5.9.0
-ptyprocess==0.7.0
-pure-eval==0.2.2
-pyarrow==8.0.0
-pycparser==2.21
-pydeck==0.7.1
-Pygments==2.12.0
-Pympler==1.0.1
-pyparsing==3.0.8
-pyrsistent==0.18.1
-python-dateutil==2.8.2
-pytz==2022.1
-pytz-deprecation-shim==0.1.0.post0
-PyYAML==6.0
-pyzmq==22.3.0
-regex==2022.4.24
-requests==2.27.1
-sacremoses==0.0.53
-semver==2.13.0
-Send2Trash==1.8.0
-six==1.16.0
-smmap==5.0.0
-soupsieve==2.3.2.post1
-stack-data==0.2.0
-streamlit==1.9.0
-terminado==0.13.3
-tinycss2==1.1.1
-tokenizers==0.12.1
-toml==0.10.2
-toolz==0.11.2
-torch==1.11.0
-tornado==6.1
-tqdm==4.64.0
-traitlets==5.1.1
-transformers==4.18.0
-typing_extensions==4.2.0
-tzdata==2022.1
-tzlocal==4.2
-urllib3==1.26.9
-validators==0.19.0
-wcwidth==0.2.5
-webencodings==0.5.1
-widgetsnbextension==3.6.0
-zipp==3.8.0

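This commit empties requirements.txt even though the new app.py and worker.py import streamlit, torch, transformers, tokenizers, numpy, and nltk. If the Space installs dependencies from requirements.txt, a minimal replacement might look like the sketch below (an assumption, not part of this commit; the pins mirror the versions deleted above, and nltk is unpinned because no prior pin exists):

# Hypothetical minimal requirements.txt for the new code.
streamlit==1.9.0
torch==1.11.0
transformers==4.18.0
tokenizers==0.12.1
numpy==1.22.3
nltk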
worker.py ADDED
@@ -0,0 +1,134 @@
+import streamlit as st
+import tokenizers
+from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
+import numpy as np
+import torch
+from nltk.tokenize import sent_tokenize
+
+
+class WorkerClassifier:
+    def __init__(
+        self, worker_model_dir, zero_shot_model_type="facebook/bart-large-mnli"
+    ):
+        self.zero_shot = None
+        self.zero_shot_model_type = zero_shot_model_type
+        self.worker_model_dir = worker_model_dir
+        self.id2label = {
+            0: "lauren",
+            1: "betty",
+            2: "doris",
+            3: "hailey",
+        }
+        self.label2id = {v: k for k, v in self.id2label.items()}
+
+    def init_models(self):
+        self.ner = self.init_anonymizer()
+        self.zero_shot = self.init_zero_shot()
+        self.worker_model = self.init_worker_model()
+        self.worker_tokenizer = self.init_worker_tokenizer()
+
+    @st.cache(
+        hash_funcs={
+            torch.nn.parameter.Parameter: lambda _: None,
+            tokenizers.Tokenizer: lambda _: None,
+            tokenizers.AddedToken: lambda _: None,
+        },
+        allow_output_mutation=True,
+    )
+    def init_worker_tokenizer(self):
+        return AutoTokenizer.from_pretrained(self.worker_model_dir)
+
+    @st.cache(
+        hash_funcs={
+            torch.nn.parameter.Parameter: lambda _: None,
+            tokenizers.Tokenizer: lambda _: None,
+            tokenizers.AddedToken: lambda _: None,
+        },
+        allow_output_mutation=True,
+    )
+    def init_worker_model(self):
+        return AutoModelForSequenceClassification.from_pretrained(
+            self.worker_model_dir, problem_type="multi_label_classification"
+        )
+
+    def predict_worker(self, text, threshold=0.5):
+        encoding = self.worker_tokenizer(text, return_tensors="pt")
+        outputs = self.worker_model(**encoding)
+
+        logits = outputs["logits"]
+        # apply sigmoid + threshold
+        sigmoid = torch.nn.Sigmoid()
+        probs = sigmoid(logits.squeeze().cpu())
+        predictions = np.zeros(probs.shape)
+        predictions[np.where(probs >= threshold)] = 1
+        # turn predicted id's into actual label names
+        predicted_labels = [
+            [self.id2label[idx], probs[idx].detach().item()]
+            for idx, label in enumerate(predictions)
+            if label == 1.0
+        ]
+        return predicted_labels
+
+    @st.cache(allow_output_mutation=True)
+    def init_anonymizer(self):
+        return pipeline(task="ner")
+
+    def anonymize(self, text: str):
+        new_sentences = []
+        sentences = sent_tokenize(text)
+        for sent in sentences:
+            result = self.ner(sent, aggregation_strategy="simple")
+            for r in reversed(result):
+                if r["entity_group"] == "PER":
+                    sent = sent[: r["start"]] + "PERSON" + sent[r["end"] :]
+            new_sentences.append(sent)
+
+        return " ".join(new_sentences)
+
+    @st.cache(
+        hash_funcs={
+            tokenizers.Tokenizer: lambda _: None,
+            tokenizers.AddedToken: lambda _: None,
+            torch.nn.parameter.Parameter: lambda parameter: parameter.data.numpy(),
+        },
+        allow_output_mutation=True,
+    )
+    def init_zero_shot(self):
+        return pipeline(
+            task="zero-shot-classification", model=self.zero_shot_model_type
+        )
+
+    def get_personality_sentences(self, text):
+        new_sentences = []
+        sentences = sent_tokenize(text)
+
+        for sent in sentences:
+            if self.personality_sent_classifier(sent):
+                new_sentences.append(sent)
+        return " ".join(new_sentences)
+
+    def personality_sent_classifier(self, text, threshold=0.8):
+        candidate_labels = ["describing a personality trait."]
+        hypothesis_template = "This example is {}"
+
+        output = self.zero_shot(
+            text,
+            candidate_labels=candidate_labels,
+            hypothesis_template=hypothesis_template,
+        )
+        # print(f'{text} with score {output["scores"][0]}\n')
+        if output["scores"][0] > threshold:
+            return True
+        return False
+
+    def predict(self, text):
+        # first extract sentences that are relevant to personalities
+        text = self.get_personality_sentences(text)
+        extracted_text = text
+
+        # next anonymize the sentences
+        text = self.anonymize(text)
+
+        # classify text
+        text = self.predict_worker(text)
+        return extracted_text, text
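
A plausible reason the caching above misses, given the commit message: `st.cache` is applied to bound methods, so Streamlit hashes `self` as an implicit argument. The instance's attributes change between calls (for example, `self.ner` goes from unset to a pipeline object not covered by `hash_funcs`), which can make the hash fail or differ on every rerun. A minimal sketch of one fix under that assumption (not necessarily the author's eventual solution): move the heavy loading into module-level cached functions whose only argument is a string, so no classifier instance is ever hashed:

# Sketch only; assumes the same imports as worker.py. Module-level cached
# loaders take only hashable string arguments, so no hash_funcs are needed
# for the inputs, and allow_output_mutation=True skips hashing the outputs.
@st.cache(allow_output_mutation=True)
def load_worker_tokenizer(model_dir):
    return AutoTokenizer.from_pretrained(model_dir)


@st.cache(allow_output_mutation=True)
def load_worker_model(model_dir):
    return AutoModelForSequenceClassification.from_pretrained(
        model_dir, problem_type="multi_label_classification"
    )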