system HF staff commited on
Commit
44e043b
0 Parent(s):

Application code

Browse files
Files changed (4) hide show
  1. .gitattributes +30 -0
  2. README.md +37 -0
  3. app.py +125 -0
  4. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ bert_mlm_positive/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
29
+ bert_mlm_negative/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
30
+ bert_classifier/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Ingsoc_censor
3
+ emoji: 👀
4
+ colorFrom: yellow
5
+ colorTo: gray
6
+ sdk: streamlit
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
10
+
11
+ # Configuration
12
+
13
+ `title`: _string_
14
+ Display title for the Space
15
+
16
+ `emoji`: _string_
17
+ Space emoji (emoji-only character allowed)
18
+
19
+ `colorFrom`: _string_
20
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
21
+
22
+ `colorTo`: _string_
23
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
24
+
25
+ `sdk`: _string_
26
+ Can be either `gradio` or `streamlit`
27
+
28
+ `sdk_version` : _string_
29
+ Only applicable for `streamlit` SDK.
30
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
31
+
32
+ `app_file`: _string_
33
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
34
+ Path is relative to the root of the repository.
35
+
36
+ `pinned`: _boolean_
37
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
import torch
import numpy as np

# Use the first GPU when one is available; otherwise run everything on CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

st.set_page_config(page_title="INGSOC review poster", layout="centered")
9
+
10
# allow_output_mutation=True is required here: st.cache otherwise hashes the
# returned torch models/tokenizer on every access, which is slow and raises
# UnhashableTypeError / CachedObjectMutationWarning for these objects.
@st.cache(allow_output_mutation=True)
def load_models():
    """Load the tokenizer and the three fine-tuned BERT models once per session.

    Returns:
        (tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier)
        with every model moved to ``device`` and switched to eval mode.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Instantiate the base architectures first, then overwrite their weights
    # with the fine-tuned checkpoints stored in the repo (tracked via git-lfs).
    bert_mlm_positive = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)
    bert_mlm_negative = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)
    bert_classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)
    bert_mlm_positive.load_state_dict(torch.load('bert_mlm_positive/pytorch_model.bin', map_location=device))
    bert_mlm_negative.load_state_dict(torch.load('bert_mlm_negative/pytorch_model.bin', map_location=device))
    bert_classifier.load_state_dict(torch.load('bert_classifier/pytorch_model.bin', map_location=device))
    return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier
20
+
21
+
22
# Load everything once at import time; st.cache keeps it warm across reruns.
tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()
23
+
24
+
25
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Suggest rewrites of ``sentence`` with its most "negative" tokens replaced.

    - split the sentence into tokens using the INGSOC-approved BERT tokenizer
    - score every token by (P_positive + eps) / (P_negative + eps) and pick the
      :num_tokens: tokens with the LOWEST ratio (the most negative-sounding
      ones — the original docstring said "highest", contradicting the code)
    - replace each of them with its :k_best: top words from bert_mlm_positive
    :return: a list of all candidate strings (up to k_best * num_tokens)
    """
    # NOTE: renamed from `input`, which shadowed the builtin.
    inputs = tokenizer(sentence, return_tensors='pt')
    inputs = {key: value.to(device) for key, value in inputs.items()}

    sent_len = inputs['input_ids'].shape[1]

    # Inference only — avoid building autograd graphs for both forward passes.
    with torch.no_grad():
        positive_probs = bert_mlm_positive(**inputs).logits.softmax(dim=-1)[0]
        negative_probs = bert_mlm_negative(**inputs).logits.softmax(dim=-1)[0]

    token_ids = inputs['input_ids'][0]
    positive_token_probs = positive_probs[torch.arange(sent_len), token_ids]
    negative_token_probs = negative_probs[torch.arange(sent_len), token_ids]

    # Smoothed likelihood ratio; low values mark tokens the negative model
    # prefers much more strongly than the positive one.
    scores = (positive_token_probs + epsilon) / (negative_token_probs + epsilon)

    # Skip [CLS]/[SEP] (positions 0 and -1); +1 restores absolute positions.
    candidates = torch.argsort(scores[1:-1])[:num_tokens] + 1

    new_variants = []
    template = token_ids.cpu().numpy()

    for candidate in candidates:
        # k_best most likely replacements at this position per the positive MLM.
        top_replaces = torch.argsort(positive_probs[candidate])[-k_best:]
        for replacement in top_replaces:
            new_variant = template.copy()
            new_variant[candidate] = replacement
            new_variants.append(new_variant)

    # Strip the special tokens before decoding back to text.
    return [tokenizer.decode(variant[1:-1]) for variant in new_variants]
59
+
60
+
61
def get_positiveness(sentence):
    """Score ``sentence`` with the INGSOC classifier (higher = more positive)."""
    bert_classifier.eval()
    with torch.no_grad():
        encoded = tokenizer(sentence, return_tensors='pt')
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        output = bert_classifier(**encoded)
        return output.logits[0][0].item()
68
+
69
+
70
def beam_rewrite(sentence, num_iterations=5, num_tokens=2, k_best=3, beam_size=8):
    """Iteratively rewrite ``sentence`` towards positivity with beam search.

    Each iteration expands every surviving variant via get_replacements and
    keeps the ``beam_size`` variants the classifier scores as most positive.

    :return: the single highest-scoring variant found.
    """
    variants = [sentence]
    for _ in range(num_iterations):
        suggestions = []
        for variant in variants:
            suggestions.extend(get_replacements(variant, num_tokens=num_tokens, k_best=k_best))

        # Keep the previous generation in the candidate pool so extra
        # iterations can never make the result worse (no need to tune
        # num_iterations carefully).
        variants.extend(suggestions)
        scores = np.array([get_positiveness(suggestion) for suggestion in variants])
        # argsort is ascending, so the best-scoring index is the LAST one.
        beam = np.argsort(scores)[-beam_size:]

        variants = [variants[ind] for ind in beam]

    # BUGFIX: the beam is ordered ascending by score, so the best variant is
    # the last element; the original returned variants[0] — the WORST variant
    # that survived the beam.
    return variants[-1]
89
+
90
+
91
def process_review(review):
    """Rewrite ``review`` using the tuning parameters stored in session state.

    The tuning widgets are declared *after* the call site in the script, so a
    key may not exist yet on the run that first processes a review; fall back
    to the widgets' default values instead of raising AttributeError.
    """
    num_iterations = int(st.session_state.get('num_iterations', 5))
    beam_size = int(st.session_state.get('beam_size', 8))
    num_tokens = int(st.session_state.get('num_tokens', 2))
    k_best = int(st.session_state.get('k_best', 3))
    return beam_rewrite(review,
                        num_iterations=num_iterations,
                        num_tokens=num_tokens,
                        k_best=k_best,
                        beam_size=beam_size)
101
+
102
+
103
st.markdown("# INGSOC-approved service for posting Your honest reviews!")
st.text_input("Your honest review: ", key='review')

if st.session_state.review:
    with st.spinner('Wait for it...'):
        review = process_review(st.session_state.review)
    review = review.capitalize()
    st.markdown("### Here is Your honest review:")
    st.markdown(f'## "{review}"')


# Advanced beam-search tuning; widget keys feed process_review via session state.
# BUGFIX: corrected user-facing typo "citzens" -> "citizens".
with st.expander("Only for class A412C citizens"):
    st.number_input('Number of beam search iterations: ',
                    min_value=1, max_value=20, value=5, key='num_iterations')

    st.number_input('Beam size: ',
                    min_value=1, max_value=20, value=8, key='beam_size')

    st.number_input('Number of tokens tested each iteration: ',
                    min_value=1, max_value=20, value=2, key='num_tokens')

    st.number_input('Number of best replacements tested each iteration: ',
                    min_value=1, max_value=20, value=3, key='k_best')
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ numpy
4
+ streamlit