CAGmllab committed on
Commit
30c6353
1 Parent(s): 682d7c8

Upload 7 files

Browse files
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from styleformer import Styleformer
2
+ import streamlit as st
3
+ import numpy as np
4
+ import json
5
+
6
class Demo:
    """Streamlit page that lets a user try Styleformer interactively.

    The sidebar selects the style and the inference settings; the main
    area takes the input sentence and shows the transferred output.
    """

    def __init__(self):
        """Configure the page and load the lookup tables and examples."""
        st.set_page_config(
            page_title="Styleformer Demo",
            initial_sidebar_state="expanded"
        )
        self.style_map = {
            #key : (name , style_num)
            'ctf': ('Casual to Formal', 0),
            'ftc': ('Formal to Casual', 1),
            'atp': ('Active to Passive', 2),
            'pta': ('Passive to Active', 3)
        }
        # Keys are the `inference_on` values that Styleformer.transfer
        # understands: -1 runs on CPU, 0..998 is used as a CUDA device
        # index, and anything else selects the (not yet supported)
        # quantized path, which currently falls back to CPU.
        self.inference_map = {
            -1: 'Regular model on CPU',
            0: 'Regular model on GPU',
            999: 'Quantized model on CPU'
        }
        # Example sentences per style key, shown in the "Choose an example" box.
        with open("streamlit_examples.json", encoding="utf-8") as f:
            self.examples = json.load(f)

    @st.cache(show_spinner=False, suppress_st_warning=True, allow_output_mutation=True)
    def load_sf(self, style=0):
        """Build the Styleformer for ``style``; wrapped in ``st.cache``."""
        sf = Styleformer(style = style)
        return sf

    def main(self):
        """Render the page and run one style transfer for the current input."""
        github_repo = 'https://github.com/PrithivirajDamodaran/Styleformer'
        st.title("Styleformer")
        st.write(f'GitHub Link - [{github_repo}]({github_repo})')
        st.write('A Neural Language Style Transfer framework to transfer natural language text smoothly between fine-grained language styles like formal/casual, active/passive, and many more')

        style_key = st.sidebar.selectbox(
            label='Choose Style',
            options=list(self.style_map.keys()),
            format_func=lambda x:self.style_map[x][0]
        )
        exp = st.sidebar.beta_expander('Knobs', expanded=True)
        with exp:
            inference_on = exp.selectbox(
                label='Inference on',
                options=list(self.inference_map.keys()),
                format_func=lambda x:self.inference_map[x]
            )
            quality_filter = exp.slider(
                label='Quality filter',
                min_value=0.5,
                max_value=0.99,
                value=0.95
            )
            max_candidates = exp.number_input(
                label='Max candidates',
                min_value=1,
                max_value=20,
                value=5
            )
        with st.spinner('Loading model..'):
            sf = self.load_sf(self.style_map[style_key][1])
        # The chosen example pre-fills the free-text box, so the user can
        # either keep it or type something else.
        input_text = st.selectbox(
            label="Choose an example",
            options=self.examples[style_key]
        )
        input_text = st.text_input(
            label="Input text",
            value=input_text
        )

        if input_text.strip():
            result = sf.transfer(input_text, inference_on=inference_on, quality_filter=quality_filter, max_candidates=max_candidates)
            st.markdown('#### Output:')
            st.write('')
            if result:
                st.success(result)
            else:
                st.info('No good quality transfers available !')
        else:
            st.warning("Please select/enter text to proceed")
83
+
84
+
85
+
86
# Script entry point: build the demo page and render it.
if __name__ == "__main__":
    Demo().main()
89
+
90
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers
2
+ sentencepiece
3
+ python-Levenshtein
4
+ fuzzywuzzy
streamlit_examples.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ctf": [
3
+ "I am quitting my job",
4
+ "Jimmy is on crack and can't trust him",
5
+ "What do guys do to show that they like a gal?",
6
+ "i loooooooooooooooooooooooove going to the movies.",
7
+ "That movie was fucking awesome",
8
+ "My mom is doing fine",
9
+ "That was funny LOL",
10
+ "It's piece of cake, we can do it",
11
+ "btw - ur avatar looks familiar",
12
+ "who gives a crap?",
13
+ "Howdy Lucy! been ages since we last met.",
14
+ "Dude, this car's dope!",
15
+ "She's my bestie from college",
16
+ "I kinda have a feeling that he has a crush on you.",
17
+ "OMG! It's finger-lickin' good."
18
+ ],
19
+ "ftc": [
20
+ "That really is quite impressive.",
21
+ "Would you please allow me to make a suggestion?",
22
+ "Good morning! How are you?",
23
+ "I would like to apologise for any inconvenience caused."
24
+ ],
25
+ "atp": [
26
+ "India won ICC Cricket World Cup 2011",
27
+ "Daya opened the door.",
28
+ "The cat killed the mouse",
29
+ "He has not completed the work.",
30
+ "I have made some cakes.",
31
+ "They are eating apples.",
32
+ "The wedding planner is making all the reservations.",
33
+ "PM declared nation-wide lockdown"
34
+ ],
35
+ "pta": [
36
+ "The lion was killed by the hunter.",
37
+ "He was given a book for his birthday.",
38
+ "The house will be cleaned by me every Saturday.",
39
+ "The Grand Canyon is visited by thousands of tourists every year.",
40
+ "All the reservations are being made by the wedding planner.",
41
+ "Money was generously donated to the homeless shelter by him"
42
+ ]
43
+ }
styleformer/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from styleformer.styleformer import Styleformer
2
+ from styleformer.adequacy import Adequacy
styleformer/adequacy.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class Adequacy():
    """Score candidate phrases against an input phrase with a classifier.

    The softmax probability of class index 1 for the tokenized
    (input phrase, candidate) pair is used as the adequacy score.
    """

    def __init__(self, model_tag='prithivida/parrot_adequacy_model'):
        """Load the classification model and tokenizer named by ``model_tag``."""
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
        self.adequacy_model = AutoModelForSequenceClassification.from_pretrained(model_tag)
        self.tokenizer = AutoTokenizer.from_pretrained(model_tag)

    def _adequacy_score(self, input_phrase, para_phrase, device):
        """Return the adequacy score of one (input, candidate) pair.

        The model is expected to already be on ``device``; the tokenized
        inputs are moved there so that both live on the same device.
        """
        x = self.tokenizer(input_phrase, para_phrase, return_tensors='pt', max_length=128, truncation=True)
        x = x.to(device)
        logits = self.adequacy_model(**x).logits
        probs = logits.softmax(dim=1)
        return probs[:, 1].item()

    def filter(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
        """Return the candidates scoring at least ``adequacy_threshold``.

        Args:
            input_phrase: The source phrase each candidate is compared with.
            para_phrases: Iterable of candidate phrases.
            adequacy_threshold: Minimum score (inclusive) a candidate needs.
            device: Device the model and the inputs are moved to.

        Returns:
            A list of the accepted candidates, in input order.
        """
        # Moving the model once is enough; it does not depend on the candidate.
        self.adequacy_model = self.adequacy_model.to(device)
        return [
            para_phrase
            for para_phrase in para_phrases
            if self._adequacy_score(input_phrase, para_phrase, device) >= adequacy_threshold
        ]

    def score(self, input_phrase, para_phrases, adequacy_threshold, device="cpu"):
        """Return the scores of candidates reaching ``adequacy_threshold``.

        Args:
            input_phrase: The source phrase each candidate is compared with.
            para_phrases: Iterable of candidate phrases.
            adequacy_threshold: Minimum score (inclusive) a candidate needs.
            device: Device the model and the inputs are moved to.

        Returns:
            A dict mapping each accepted candidate to its score, in input
            order.
        """
        self.adequacy_model = self.adequacy_model.to(device)
        adequacy_scores = {}
        for para_phrase in para_phrases:
            adequacy_score = self._adequacy_score(input_phrase, para_phrase, device)
            if adequacy_score >= adequacy_threshold:
                adequacy_scores[para_phrase] = adequacy_score
        return adequacy_scores
styleformer/demo.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from styleformer import Styleformer
2
+ import warnings
3
+ warnings.filterwarnings("ignore")
4
+ import torch
5
+
6
def set_seed(seed):
    """Seed PyTorch's random number generators with ``seed``.

    The default generator is always seeded; when CUDA is available the
    generators of all CUDA devices are seeded as well.
    """
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
10
+
11
set_seed(1234)

source_sentences = [
    "I am quitting my job",
    "Jimmy is on crack and can't trust him",
    "What do guys do to show that they like a gal?",
    "i loooooooooooooooooooooooove going to the movies.",
    "That movie was fucking awesome",
    "My mom is doing fine",
    "That was funny LOL",
    "It's piece of cake, we can do it",
    "btw - ur avatar looks familiar",
    "who gives a crap?",
    "Howdy Lucy! been ages since we last met.",
    "Dude, this car's dope!",
    "She's my bestie from college",
    "I kinda have a feeling that he has a crush on you.",
    "OMG! It's finger-lickin' good.",
]

# style = [0=Casual to Formal, 1=Formal to Casual, 2=Active to Passive, 3=Passive to Active etc..]
sf = Styleformer(style = 0)

# inference_on, as interpreted by Styleformer.transfer:
#   -1 = regular model on CPU, 0..998 = regular model on that CUDA device
#   index, anything else = quantized model (not supported yet, runs on CPU).
# Use the first GPU when there is one, otherwise stay on the CPU.
inference_on = 0 if torch.cuda.is_available() else -1

for source_sentence in source_sentences:
    target_sentence = sf.transfer(source_sentence, inference_on=inference_on, quality_filter=0.95, max_candidates=5)
    print("[Informal] ", source_sentence)
    if target_sentence is not None:
        print("[Formal] ",target_sentence)
    else:
        print("No good quality transfers available !")
    print("-" *100)
styleformer/styleformer.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class Styleformer():
    """Rewrite a sentence in another style with a seq2seq model.

    An instance serves exactly one style, chosen at construction time:
    0 = casual to formal, 1 = formal to casual, 2 = active to passive,
    3 = passive to active. Only the model of that style is loaded.
    """

    # Sampling settings shared by the ``generate`` call of every style.
    _GENERATION_KWARGS = dict(
        do_sample=True,
        max_length=32,
        top_k=50,
        top_p=0.95,
        early_stopping=True,
    )

    def __init__(
        self,
        style=0,
        ctf_model_tag="prithivida/informal_to_formal_styletransfer",
        ftc_model_tag="prithivida/formal_to_informal_styletransfer",
        atp_model_tag="prithivida/active_to_passive_styletransfer",
        pta_model_tag="prithivida/passive_to_active_styletransfer",
        adequacy_model_tag="prithivida/parrot_adequacy_model",
    ):
        """Load the tokenizer and model of ``style`` plus the adequacy scorer.

        Args:
            style: Style number (0-3, see the class docstring). Any other
                value loads nothing and leaves ``model_loaded`` False.
            ctf_model_tag, ftc_model_tag, atp_model_tag, pta_model_tag:
                Names passed to ``from_pretrained`` for the respective style.
            adequacy_model_tag: Name of the adequacy model; a falsy value
                disables adequacy scoring.
        """
        from transformers import AutoTokenizer
        from transformers import AutoModelForSeq2SeqLM
        from styleformer import Adequacy

        def load(model_tag):
            tokenizer = AutoTokenizer.from_pretrained(model_tag, use_auth_token=False)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_tag, use_auth_token=False)
            return tokenizer, model

        self.style = style
        self.adequacy = adequacy_model_tag and Adequacy(model_tag=adequacy_model_tag)
        self.model_loaded = False

        if self.style == 0:
            self.ctf_tokenizer, self.ctf_model = load(ctf_model_tag)
            print("Casual to Formal model loaded...")
            self.model_loaded = True
        elif self.style == 1:
            self.ftc_tokenizer, self.ftc_model = load(ftc_model_tag)
            print("Formal to Casual model loaded...")
            self.model_loaded = True
        elif self.style == 2:
            self.atp_tokenizer, self.atp_model = load(atp_model_tag)
            print("Active to Passive model loaded...")
            self.model_loaded = True
        elif self.style == 3:
            self.pta_tokenizer, self.pta_model = load(pta_model_tag)
            print("Passive to Active model loaded...")
            self.model_loaded = True
        else:
            print("Only CTF, FTC, ATP and PTA are supported in the pre-release...stay tuned")

    def transfer(self, input_sentence, inference_on=-1, quality_filter=0.95, max_candidates=5):
        """Transfer ``input_sentence`` into the style chosen at construction.

        Args:
            input_sentence: The sentence to rewrite.
            inference_on: -1 runs on CPU; 0..998 is used as a CUDA device
                index ("cuda:<n>"); any other value selects the quantized
                path, which is not supported yet and runs on CPU.
            quality_filter: Minimum adequacy score a candidate needs
                (styles 0 and 1 only).
            max_candidates: Number of candidates sampled (styles 0 and 1
                only; styles 2 and 3 always sample one).

        Returns:
            The transferred sentence, or None when no model is loaded or
            no candidate passes the quality filter.
        """
        if not self.model_loaded:
            print("Models aren't loaded for this style, please use the right style during init")
            return None

        if inference_on == -1:
            device = "cpu"
        elif 0 <= inference_on < 999:
            device = "cuda:" + str(inference_on)
        else:
            device = "cpu"
            print("Onnx + Quantisation is not supported in the pre-release...stay tuned.")

        if self.style == 0:
            return self._casual_to_formal(input_sentence, device, quality_filter, max_candidates)
        if self.style == 1:
            return self._formal_to_casual(input_sentence, device, quality_filter, max_candidates)
        if self.style == 2:
            return self._active_to_passive(input_sentence, device)
        if self.style == 3:
            return self._passive_to_active(input_sentence, device)
        return None

    def _generate(self, tokenizer, model, prompt, device, num_return_sequences):
        """Sample ``num_return_sequences`` outputs for ``prompt``.

        Returns the decoded, stripped texts in generation order.
        """
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        preds = model.generate(
            input_ids,
            num_return_sequences=num_return_sequences,
            **self._GENERATION_KWARGS)
        return [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]

    def _pick_best(self, src_sentence, candidates, device, quality_filter):
        """Return the candidate with the highest adequacy score, or None.

        Duplicates are dropped while keeping generation order, so ties are
        broken the same way on every run. When adequacy scoring is
        disabled, ``quality_filter`` cannot be applied and the first
        candidate is returned.
        """
        unique = list(dict.fromkeys(candidates))
        if not self.adequacy:
            return unique[0] if unique else None
        scored = self.adequacy.score(src_sentence, unique, quality_filter, device)
        if not scored:
            return None
        return max(scored, key=scored.get)

    def _formal_to_casual(self, input_sentence, device, quality_filter, max_candidates):
        """Return the best formal-to-casual rewrite, or None if none qualifies."""
        self.ftc_model = self.ftc_model.to(device)
        candidates = self._generate(
            self.ftc_tokenizer, self.ftc_model,
            "transfer Formal to Casual: " + input_sentence, device, max_candidates)
        return self._pick_best(input_sentence, candidates, device, quality_filter)

    def _casual_to_formal(self, input_sentence, device, quality_filter, max_candidates):
        """Return the best casual-to-formal rewrite, or None if none qualifies."""
        self.ctf_model = self.ctf_model.to(device)
        candidates = self._generate(
            self.ctf_tokenizer, self.ctf_model,
            "transfer Casual to Formal: " + input_sentence, device, max_candidates)
        return self._pick_best(input_sentence, candidates, device, quality_filter)

    def _active_to_passive(self, input_sentence, device):
        """Return a single sampled active-to-passive rewrite."""
        self.atp_model = self.atp_model.to(device)
        return self._generate(
            self.atp_tokenizer, self.atp_model,
            "transfer Active to Passive: " + input_sentence, device, 1)[0]

    def _passive_to_active(self, input_sentence, device):
        """Return a single sampled passive-to-active rewrite."""
        self.pta_model = self.pta_model.to(device)
        return self._generate(
            self.pta_tokenizer, self.pta_model,
            "transfer Passive to Active: " + input_sentence, device, 1)[0]
162
+
163
+