File size: 4,708 Bytes
44e043b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import streamlit as st
from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
import torch
import numpy as np

# Run every model on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Page configuration is issued before any other st.* call in this script.
st.set_page_config(layout="centered", page_title="INGSOC review poster")

# The returned models are large, mutable objects: cache them once per process and do not
# hash them on every rerun. Newer Streamlit provides st.cache_resource for exactly this;
# older releases only have st.cache, where allow_output_mutation=True skips the output hash.
_cache_models = getattr(st, 'cache_resource', None) or st.cache(allow_output_mutation=True)


@_cache_models
def load_models():
	"""
	Load the BERT tokenizer and the three task models onto ``device``.

	Each model is built from the ``bert-base-uncased`` architecture and then overwritten
	with the weights stored in the matching local ``pytorch_model.bin`` checkpoint; all
	models are switched to inference mode (``train(False)``).

	:return: (tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier)
	"""
	def _load(model_cls, checkpoint_path):
		# Build the architecture, move it to ``device``, switch to inference mode, then load weights.
		model = model_cls.from_pretrained('bert-base-uncased', return_dict=True).to(device).train(False)
		# NOTE(review): torch.load unpickles the checkpoint - keep these paths pointing at trusted files.
		model.load_state_dict(torch.load(checkpoint_path, map_location=device))
		return model

	tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
	bert_mlm_positive = _load(BertForMaskedLM, 'bert_mlm_positive/pytorch_model.bin')
	bert_mlm_negative = _load(BertForMaskedLM, 'bert_mlm_negative/pytorch_model.bin')
	bert_classifier = _load(BertForSequenceClassification, 'bert_classifier/pytorch_model.bin')
	return tokenizer, bert_mlm_positive, bert_mlm_negative, bert_classifier


# Fetch the tokenizer and the three checkpoint-loaded BERT models (cached by load_models).
(
	tokenizer,
	bert_mlm_positive,
	bert_mlm_negative,
	bert_classifier,
) = load_models()


def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
	"""
	Propose rewrites of ``sentence`` that swap its most "negative" tokens for words the
	positive MLM prefers.

	- split the sentence into tokens using the INGSOC-approved BERT tokenizer
	- score every non-special token by p_positive(token) / p_negative(token), where each
	  probability is what that MLM assigns to the token actually present at its position
	- take the :num_tokens: tokens with the LOWEST ratio (the ones the negative model
	  favours most relative to the positive one)
	- replace each of them with the :k_best: most likely words according to bert_mlm_positive

	:param sentence: text to rewrite
	:param num_tokens: how many token positions to try replacing
	:param k_best: how many replacement words to try at each position
	:param epsilon: smoothing added to both probabilities so the ratio never divides by zero
	:return: a list of decoded candidate strings (up to k_best * num_tokens), each differing
	         from the tokenized ``sentence`` in exactly one position; empty when
	         ``num_tokens`` or ``k_best`` is not positive
	"""
	# Non-positive counts would otherwise hit Python's ``[-0:]`` / ``[:-n]`` slicing quirks.
	if num_tokens <= 0 or k_best <= 0:
		return []

	encoded = tokenizer(sentence, return_tensors='pt')
	encoded = {key: value.to(device) for key, value in encoded.items()}
	token_ids = encoded['input_ids'][0]
	# Index tensors live on the same device as the probability tensors they index.
	positions = torch.arange(token_ids.shape[0], device=token_ids.device)

	# Inference only: no autograd graph is needed (saves memory and time).
	with torch.no_grad():
		positive_probs = bert_mlm_positive(**encoded).logits.softmax(dim=-1)[0]
		negative_probs = bert_mlm_negative(**encoded).logits.softmax(dim=-1)[0]

	positive_token_probs = positive_probs[positions, token_ids]
	negative_token_probs = negative_probs[positions, token_ids]
	scores = (positive_token_probs + epsilon) / (negative_token_probs + epsilon)

	# Skip the first and last positions (the tokenizer's special tokens); +1 maps the
	# slice indices back to positions in the full sequence. Plain ints avoid indexing a
	# numpy array with (possibly CUDA) tensors below.
	candidates = (torch.argsort(scores[1:-1])[:num_tokens] + 1).tolist()

	template = token_ids.cpu().numpy()
	new_variants = []
	for candidate in candidates:
		# Tail of the ascending argsort = the k_best most probable words (least to most likely).
		top_replaces = torch.argsort(positive_probs[candidate])[-k_best:].tolist()
		for replace in top_replaces:
			new_variant = template.copy()
			new_variant[candidate] = replace
			new_variants.append(new_variant)

	return [tokenizer.decode(variant[1:-1]) for variant in new_variants]


def get_positiveness(sentence):
	"""
	Score ``sentence`` with ``bert_classifier``.

	:param sentence: text to score
	:return: the raw (pre-softmax) logit of class 0 for the sentence, as a Python float.
	         NOTE(review): presumably class 0 is the "positive" label of the fine-tuned
	         classifier - confirm against how the checkpoint was trained.
	"""
	bert_classifier.eval()
	encoded = tokenizer(sentence, return_tensors='pt')
	model_inputs = dict((name, tensor.to(device)) for name, tensor in encoded.items())
	with torch.no_grad():
		output = bert_classifier(**model_inputs)
	return output.logits[0, 0].item()


def beam_rewrite(sentence, num_iterations=5, num_tokens=2, k_best=3, beam_size=8):
	"""
	Rewrite ``sentence`` with beam search so that ``get_positiveness`` is maximised.

	Every iteration expands each variant in the beam with ``get_replacements``. The old
	variants stay in the pool too, so extra iterations can never lower the best score.
	The ``beam_size`` highest-scoring distinct strings survive to the next iteration.

	:param sentence: the text to rewrite
	:param num_iterations: number of expansion rounds
	:param num_tokens: positions tried per variant per round (forwarded to get_replacements)
	:param k_best: replacement words tried per position (forwarded to get_replacements)
	:param beam_size: number of variants kept after each round; must be >= 1
	:return: the highest-scoring variant found (``sentence`` itself when num_iterations is 0)
	:raises ValueError: if ``beam_size`` < 1
	"""
	if beam_size < 1:
		raise ValueError('beam_size must be >= 1')

	cached_scores = {}

	def score(text):
		# Survivors reappear every round; score each distinct string only once.
		if text not in cached_scores:
			cached_scores[text] = get_positiveness(text)
		return cached_scores[text]

	variants = [sentence]
	for _ in range(num_iterations):
		# don't forget old variants to forget about num_iterations tuning
		pool = list(variants)
		for variant in variants:
			pool.extend(get_replacements(variant, num_tokens=num_tokens, k_best=k_best))
		# Drop duplicate strings (keeping first-seen order) so they don't waste beam slots.
		pool = list(dict.fromkeys(pool))
		# Stable sort, best first: ties keep their first-seen order.
		pool.sort(key=score, reverse=True)
		variants = pool[:beam_size]

	# ``variants`` is sorted best-first, so the winner is at index 0.
	return variants[0]


def process_review(review):
	"""
	Rewrite ``review`` using the beam-search settings held in ``st.session_state``.

	:param review: the user's original review text
	:return: the rewritten review produced by ``beam_rewrite``
	"""
	settings = {
		name: int(getattr(st.session_state, name))
		for name in ('num_iterations', 'beam_size', 'num_tokens', 'k_best')
	}
	return beam_rewrite(review, **settings)


st.markdown("# INGSOC-approved service for posting Your honest reviews!")
st.text_input("Your honest review: ", key='review')

# Reserve the spot for the rewritten review right under the input, but render the
# settings widgets first: process_review reads their values from st.session_state, so
# registering them earlier in the same run guarantees those keys are present when it is
# called, independent of how the previous run ended.
result_area = st.container()

with st.expander("Only for class A412C citizens"):
	st.number_input('Number of beam search iterations: ',
		min_value=1, max_value=20, value=5, key='num_iterations')

	st.number_input('Beam size: ',
		min_value=1, max_value=20, value=8, key='beam_size')

	st.number_input('Number of tokens tested each iteration: ',
		min_value=1, max_value=20, value=2, key='num_tokens')

	st.number_input('Number of best replacements tested each iteration: ',
		min_value=1, max_value=20, value=3, key='k_best')

if st.session_state.review:
	with result_area:
		with st.spinner('Wait for it...'):
			review = process_review(st.session_state.review)
			review = review.capitalize()
		st.markdown("### Here is Your honest review:")
		st.markdown(f'## "{review}"')