kruntuid committed
Commit d22ab1f
1 Parent(s): 4d415e1

ysda next commit

Files changed (3)
  1. GD_download.py +31 -0
  2. app.py +91 -2
  3. requirements.txt +3 -0
GD_download.py ADDED
@@ -0,0 +1,31 @@
+# taken from this StackOverflow answer: https://stackoverflow.com/a/39225039
+import requests
+
+def download_file_from_google_drive(id, destination):
+    URL = "https://docs.google.com/uc?export=download"
+
+    session = requests.Session()
+
+    response = session.get(URL, params={'id': id}, stream=True)
+    token = get_confirm_token(response)
+
+    if token:
+        params = {'id': id, 'confirm': token}
+        response = session.get(URL, params=params, stream=True)
+
+    save_response_content(response, destination)
+
+def get_confirm_token(response):
+    for key, value in response.cookies.items():
+        if key.startswith('download_warning'):
+            return value
+
+    return None
+
+def save_response_content(response, destination):
+    CHUNK_SIZE = 32768
+
+    with open(destination, "wb") as f:
+        for chunk in response.iter_content(CHUNK_SIZE):
+            if chunk:  # filter out keep-alive new chunks
+                f.write(chunk)
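For orientation, this helper is what app.py calls to pull the two model checkpoints from Google Drive. A minimal usage sketch, with a placeholder file ID and destination (the real checkpoint IDs appear in app.py below):

from GD_download import download_file_from_google_drive

# "YOUR_FILE_ID" is hypothetical; use the ID from a Drive share link.
download_file_from_google_drive("YOUR_FILE_ID", "checkpoint.model")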
app.py CHANGED
@@ -1,4 +1,93 @@
 import streamlit as st
+from pathlib import Path
+import torch
+from transformers import BertTokenizer
+
 
-x = st.slider('Select a value')
-st.write(x, 'squared is', x * x)
+@st.cache
+def get_tokenizer():
+    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+    return tokenizer
+
+@st.cache
+def load_model_bert_mlm_positive():
+    f_checkpoint = Path("bert_mlm_positive.model")
+
+    if not f_checkpoint.exists():
+        with st.spinner("Downloading bert_mlm_positive... this may take a while! \n Don't stop it!"):
+            from GD_download import download_file_from_google_drive
+            cloud_model_location = "12Gvgv6zaOLJ8oyYXVB5_GYNEfvsjudG_"
+            download_file_from_google_drive(cloud_model_location, f_checkpoint)
+
+    model = torch.load(f_checkpoint, map_location=torch.device('cpu'))
+    model.eval()
+    return model
+
+@st.cache
+def load_model_model_seq_classify():
+    f_checkpoint = Path("model_seq_classify.model")
+
+    if not f_checkpoint.exists():
+        with st.spinner("Downloading model_seq_classify... this may take a while! \n Don't stop it!"):
+            from GD_download import download_file_from_google_drive
+            cloud_model_location = "13DwlCIM6aYc4WeOCIRqdGy-U0LGc8f0B"
+            download_file_from_google_drive(cloud_model_location, f_checkpoint)
+
+    model = torch.load(f_checkpoint, map_location=torch.device('cpu'))
+    model.eval()
+    return model
+
+
+def get_replacements_beamsearch(tokenizer, bert_mlm_positive, seq_classify_model, sentence: str, num_candidates=3):
+    sentence_ix = tokenizer(sentence, return_tensors='pt')
+
+    tokens = [tokenizer.decode([t]) for t in sentence_ix['input_ids'].cpu().numpy()[0]]
+
+    length = len(sentence_ix['input_ids'][0])
+
+    current = [(tokens, 0)]
+    for ix in range(1, length - 1):
+
+        new_current = []
+        for item in current:
+            sent = " ".join(item[0][1:-1])
+            # keep the unmodified sequence as a candidate too
+            new_current.append(item)
+
+            sent_ix = tokenizer(sent, return_tensors='pt')
+            logits_positive = bert_mlm_positive(**sent_ix).logits
+            probs_positive = logits_positive.softmax(dim=-1)[0, ix]
+            indices = torch.argsort(probs_positive, descending=True)
+
+            for cand_ix in range(num_candidates):
+                token_id = indices[cand_ix]
+                new_seq = item[0].copy()
+                new_seq[ix] = tokenizer.decode([token_id])
+
+                logits = seq_classify_model(**tokenizer(" ".join(new_seq[1:-1]), return_tensors='pt')).logits
+                prob = logits.softmax(dim=-1)[0][1]
+
+                new_current.append((new_seq, prob))
+
+        current = sorted(new_current, key=lambda x: -x[1])[:num_candidates]
+
+    return [" ".join(item[0][1:-1]) for item in current]
+
+
+
+negative_phrase = st.text_input("Input negative phrase")
+num_candidates = st.slider("Number of candidates", min_value=1, max_value=5)
+
+
+if negative_phrase:
+    bert_mlm_positive = load_model_bert_mlm_positive()
+    model_seq_classify = load_model_model_seq_classify()
+
+    ret = get_replacements_beamsearch(get_tokenizer(), bert_mlm_positive,
+                                      model_seq_classify, negative_phrase, num_candidates=num_candidates)
+
+    st.caption("Output positive phrases:")
+    for phrase in ret:
+        st.caption(phrase)
+
+
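To sanity-check the new beam-search helper outside the Streamlit UI, a rough sketch along these lines should work once both checkpoints above are present in the working directory (the sample phrase is illustrative; importing app also executes its top-level Streamlit calls, which is harmless in a plain Python session):

import torch
from transformers import BertTokenizer
from app import get_replacements_beamsearch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Load the two fine-tuned models the same way app.py does.
bert_mlm_positive = torch.load("bert_mlm_positive.model", map_location=torch.device('cpu')).eval()
seq_classify = torch.load("model_seq_classify.model", map_location=torch.device('cpu')).eval()

# Ask for the three most "positive" rewrites of a negative phrase.
print(get_replacements_beamsearch(tokenizer, bert_mlm_positive, seq_classify,
                                  "the food was terrible", num_candidates=3))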
 
 
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ streamlit==1.2.0
+ torch==1.10.0
+ transformers==4.11.3
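With these pins, the usual local workflow applies: pip install -r requirements.txt, then streamlit run app.py.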