Spaces:
Runtime error
Runtime error
Commit
•
44e043b
0
Parent(s):
Application code
Browse files- .gitattributes +30 -0
- README.md +37 -0
- app.py +125 -0
- requirements.txt +4 -0
.gitattributes
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
bert_mlm_positive/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
|
29 |
+
bert_mlm_negative/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
|
30 |
+
bert_classifier/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Ingsoc_censor
|
3 |
+
emoji: 👀
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: gray
|
6 |
+
sdk: streamlit
|
7 |
+
app_file: app.py
|
8 |
+
pinned: false
|
9 |
+
---
|
10 |
+
|
11 |
+
# Configuration
|
12 |
+
|
13 |
+
`title`: _string_
|
14 |
+
Display title for the Space
|
15 |
+
|
16 |
+
`emoji`: _string_
|
17 |
+
Space emoji (emoji-only character allowed)
|
18 |
+
|
19 |
+
`colorFrom`: _string_
|
20 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
21 |
+
|
22 |
+
`colorTo`: _string_
|
23 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
24 |
+
|
25 |
+
`sdk`: _string_
|
26 |
+
Can be either `gradio` or `streamlit`
|
27 |
+
|
28 |
+
`sdk_version` : _string_
|
29 |
+
Only applicable for `streamlit` SDK.
|
30 |
+
See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
|
31 |
+
|
32 |
+
`app_file`: _string_
|
33 |
+
Path to your main application file (which contains either `gradio` or `streamlit` Python code).
|
34 |
+
Path is relative to the root of the repository.
|
35 |
+
|
36 |
+
`pinned`: _boolean_
|
37 |
+
Whether the Space stays on top of your list.
|
app.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
7 |
+
|
8 |
+
st.set_page_config(page_title="INGSOC review poster", layout="centered")
|
9 |
+
|
10 |
+
@st.cache
|
11 |
+
def load_models():
|
12 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
13 |
+
bert_mlm_positive = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)
|
14 |
+
bert_mlm_negative = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)
|
15 |
+
bert_classifier = BertForSequenceClassification.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)
|
16 |
+
bert_mlm_positive.load_state_dict(torch.load('bert_mlm_positive/pytorch_model.bin', map_location=device))
|
17 |
+
bert_mlm_negative.load_state_dict(torch.load('bert_mlm_negative/pytorch_model.bin', map_location=device))
|
18 |
+
bert_classifier.load_state_dict(torch.load('bert_classifier/pytorch_model.bin', map_location=device))
|
19 |
+
return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier
|
20 |
+
|
21 |
+
|
22 |
+
tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier = load_models()
|
23 |
+
|
24 |
+
|
25 |
+
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
|
26 |
+
"""
|
27 |
+
- split the sentence into tokens using the INGSOC-approved BERT tokenizer
|
28 |
+
- find :num_tokens: tokens with the highest ratio (see above)
|
29 |
+
- replace them with :k_best: words according to bert_mlm_positive
|
30 |
+
:return: a list of all possible strings (up to k_best * num_tokens)
|
31 |
+
"""
|
32 |
+
input = tokenizer(sentence, return_tensors='pt')
|
33 |
+
input = {key: value.to(device) for key, value in input.items()}
|
34 |
+
|
35 |
+
sent_len = input['input_ids'].shape[1]
|
36 |
+
|
37 |
+
positive_probs = bert_mlm_positive(**input).logits.softmax(dim=-1)[0]
|
38 |
+
negative_probs = bert_mlm_negative(**input).logits.softmax(dim=-1)[0]
|
39 |
+
|
40 |
+
positive_token_probs = positive_probs[torch.arange(sent_len), input['input_ids'][0]]
|
41 |
+
negative_token_probs = negative_probs[torch.arange(sent_len), input['input_ids'][0]]
|
42 |
+
|
43 |
+
scores = (positive_token_probs + epsilon) / (negative_token_probs + epsilon)
|
44 |
+
|
45 |
+
candidates = torch.argsort(scores[1:-1])[:num_tokens] + 1
|
46 |
+
|
47 |
+
new_variants = []
|
48 |
+
|
49 |
+
template = input['input_ids'][0].cpu().numpy()
|
50 |
+
|
51 |
+
for candidate in candidates:
|
52 |
+
top_replaces = torch.argsort(positive_probs[candidate])[-k_best:]
|
53 |
+
for replace in top_replaces:
|
54 |
+
new_variant = template.copy()
|
55 |
+
new_variant[candidate] = replace
|
56 |
+
new_variants.append(new_variant)
|
57 |
+
|
58 |
+
return [tokenizer.decode(variant[1:-1]) for variant in new_variants]
|
59 |
+
|
60 |
+
|
61 |
+
def get_positiveness(sentence):
|
62 |
+
bert_classifier.eval()
|
63 |
+
with torch.no_grad():
|
64 |
+
tokenized = tokenizer(sentence, return_tensors='pt')
|
65 |
+
tokenized = {key: value.to(device) for key, value in tokenized.items()}
|
66 |
+
res = bert_classifier(**tokenized)
|
67 |
+
return res.logits[0][0].item()
|
68 |
+
|
69 |
+
|
70 |
+
def beam_rewrite(sentence, num_iterations=5, num_tokens=2, k_best=3, beam_size=8):
|
71 |
+
variants = [sentence]
|
72 |
+
for _ in range(num_iterations):
|
73 |
+
suggestions = []
|
74 |
+
for variant in variants:
|
75 |
+
suggestions.extend(get_replacements(variant, num_tokens=num_tokens, k_best=k_best))
|
76 |
+
|
77 |
+
# don't forget old variants to forget about num_iterations tuning
|
78 |
+
variants.extend(suggestions)
|
79 |
+
scores = [get_positiveness(suggestion) for suggestion in variants]
|
80 |
+
scores = np.array(scores)
|
81 |
+
beam = np.argsort(scores)[-beam_size:]
|
82 |
+
|
83 |
+
new_variants = []
|
84 |
+
for ind in beam:
|
85 |
+
new_variants.append(variants[ind])
|
86 |
+
variants = new_variants
|
87 |
+
|
88 |
+
return variants[0]
|
89 |
+
|
90 |
+
|
91 |
+
def process_review(review):
|
92 |
+
num_iterations = int(st.session_state.num_iterations)
|
93 |
+
beam_size = int(st.session_state.beam_size)
|
94 |
+
num_tokens = int(st.session_state.num_tokens)
|
95 |
+
k_best = int(st.session_state.k_best)
|
96 |
+
return beam_rewrite(review,
|
97 |
+
num_iterations=num_iterations,
|
98 |
+
num_tokens=num_tokens,
|
99 |
+
k_best=k_best,
|
100 |
+
beam_size=beam_size)
|
101 |
+
|
102 |
+
|
103 |
+
st.markdown("# INGSOC-approved service for posting Your honest reviews!")
|
104 |
+
st.text_input("Your honest review: ", key='review')
|
105 |
+
|
106 |
+
if st.session_state.review:
|
107 |
+
with st.spinner('Wait for it...'):
|
108 |
+
review = process_review(st.session_state.review)
|
109 |
+
review = review.capitalize()
|
110 |
+
st.markdown("### Here is Your honest review:")
|
111 |
+
st.markdown(f'## "{review}"')
|
112 |
+
|
113 |
+
|
114 |
+
with st.expander("Only for class A412C citzens"):
|
115 |
+
st.number_input('Number of beam search iterations: ',
|
116 |
+
min_value=1, max_value=20, value=5, key='num_iterations')
|
117 |
+
|
118 |
+
st.number_input('Beam size: ',
|
119 |
+
min_value=1, max_value=20, value=8, key='beam_size')
|
120 |
+
|
121 |
+
st.number_input('Number of tokens tested each iteration: ',
|
122 |
+
min_value=1, max_value=20, value=2, key='num_tokens')
|
123 |
+
|
124 |
+
st.number_input('Number of best replacements tested each iteration: ',
|
125 |
+
min_value=1, max_value=20, value=3, key='k_best')
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
1 |
+
pytorch
|
2 |
+
transformers
|
3 |
+
numpy
|
4 |
+
streamlit
|