salexashenko committed on
Commit
3a555af
1 Parent(s): 7cd979d
Files changed (2)
  1. app.py +290 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,290 @@
import nltk

from spacy.cli import download
download("en_core_web_sm")
nltk.download('stopwords')

from nltk.corpus import stopwords

# NLTK's English stopwords, extended with high-frequency conversational words
# so that keyword/entity extraction skips dialogue filler.
en_stopwords = set(list(stopwords.words('english')) +
['summary', 'synopsis', 'overview', 'list', 'good', 'will', 'why', 'talk', 'long', 'above', 'looks', 'face', 'men', 'years', 'can', 'both', 'have', 'keep', 'yeah', 'said', 'bring', 'done', 'was', 'when', 'ask', 'now', 'very', 'kind', 'they', 'told', 'tell', 'ever', 'kill', 'hold', 'that', 'below', 'bit', 'knew', 'haven', 'few', 'place', 'could', 'says', 'huh', 'job', 'also', 'ain', 'may', 'heart', 'boy', 'with', 'over', 'son', 'else', 'found', 'see', 'any', 'phone', 'hasn', 'saw', 'these', 'maybe', 'into', 'thing', 'mom', 'god', 'old', 'aren', 'mustn', 'out', 'about', 'guy', 'each', 'most', 'like', 'then', 'wasn', 'being', 'all', 'door', 'look', 'run', 'sorry', 'again', 'won', 'man', 'gone', 'them', 'ago', 'doesn', 'gonna', 'girl', 'feel', 'work', 'much', 'hope', 'never', 'woman', 'went', 'lot', 'what', 'start', 'only', 'play', 'too', 'dad', 'going', 'yours', 'wrong', 'fine', 'made', 'one', 'want', 'isn', 'our', 'true', 'room', 'wanna', 'are', 'idea', 'sure', 'find', 'same', 'doing', 'off', 'put', 'turn', 'come', 'house', 'think', 'meet', 'hers', 'gotta', 'nor', 'away', 'leave', 'car', 'used', 'happy', 'the', 'care', 'seen', 'she', 'not', 'were', 'ours', 'their', 'first', 'world', 'lost', 'make', 'big', 'left', 'miss', 'shan', 'did', 'thank', 'ready', 'those', 'give', 'next', 'came', 'who', 'mind', 'does', 'right', 'her', 'let', 'didn', 'open', 'has', 'show', 'wife', 'yet', 'got', 'know', 'whole', 'some', 'such', 'alone', 'baby', 'him', 'nice', 'bad', 'move', 'new', 'dead', 'three', 'weren', 'whom', 'well', 'get', 'which', 'end', 'you', 'than', 'while', 'last', 'once', 'sir', 'from', 'need', 'wait', 'days', 'how', 'don', 'heard', 'own', 'hear', 'where', 'hey', 'okay', 'just', 'until', 'your', 'there', 'this', 'more', 'been', 'his', 'under', 'mean', 'might', 'here', 'its', 'but', 'stay', 'yes', 'guess', 'even', 'guys', 'hard', 'hadn', 'live', 'stop', 'took', 'still', 'other', 'since', 'every', 'needn', 'way', 'name', 'two', 'back', 'and', 'hello', 'head', 'use', 'must', 'for', 'life', 'die', 'day', 'down', 'wants', 'after', 'say', 'try', 'had', 'night']
)

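# Illustrative sketch (hypothetical helper, not used elsewhere in this file):
# how the extended stopword set might be applied to pull keywords out of text.
def _example_keywords(text):
    return [w for w in text.lower().split() if w.strip(".,!?") not in en_stopwords]

# _example_keywords("Tell me how the BM25 index is created")
# -> ['bm25', 'index', 'created']
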
import multiprocessing
import os
from whoosh.analysis import StemmingAnalyzer
from whoosh.index import create_in
from whoosh.fields import Schema, ID, TEXT
import whoosh.index as whoosh_index
import tqdm

def get_content_ext(content, bm25_field):
    # Hook for field-specific preprocessing; currently a pass-through.
    return content

def yield_line_by_line(file):
    with open(file) as infile:
        for line in infile:
            yield line

def recreate_bm25_idx(content_data_store, bm25_field="search", idx_dir=".", auto_create_bm25_idx=False, idxs=None, use_tqdm=True):
    if type(content_data_store) is str:
        content_data_store = yield_line_by_line(content_data_store)
    schema = Schema(id=ID(stored=True), content=TEXT(analyzer=StemmingAnalyzer()))
    # TODO: determine how to clear out the whoosh index besides rm -rf _M* MAIN*
    os.makedirs(f"{idx_dir}/bm25_{bm25_field}", exist_ok=True)
    need_reindex = auto_create_bm25_idx or not os.path.exists(f"{idx_dir}/bm25_{bm25_field}/_MAIN_1.toc")  # CHECK IF THIS IS RIGHT
    if not need_reindex:
        whoosh_ix = whoosh_index.open_dir(f"{idx_dir}/bm25_{bm25_field}")
    else:
        whoosh_ix = create_in(f"{idx_dir}/bm25_{bm25_field}", schema)
    writer = whoosh_ix.writer(multisegment=True, limitmb=1024, procs=multiprocessing.cpu_count())
    # writer = self.whoosh_ix.writer(multisegment=True, procs=multiprocessing.cpu_count())
    if hasattr(content_data_store, 'tell'):
        pos = content_data_store.tell()
        content_data_store.seek(0, 0)
    if idxs is not None:
        idx_text_pairs = [(idx, content_data_store[idx]) for idx in idxs]
        data_iterator = tqdm.tqdm(idx_text_pairs) if use_tqdm else idx_text_pairs
    else:
        data_iterator = tqdm.tqdm(enumerate(content_data_store)) if use_tqdm else enumerate(content_data_store)
    # TODO:
    # self.indexer.reset_bm25_idx(0)
    # data_iterator = self.indexer.process_bm25_field(content_data_store, **kwargs)
    for idx, content in data_iterator:
        content = get_content_ext(content, bm25_field)
        if not content:
            continue
        writer.add_document(id=str(idx), content=content)
    writer.commit()
    return whoosh_ix  # was `return whoosh_index`, which is the module, not the index

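# Illustrative sketch (not called anywhere): querying the index built by
# recreate_bm25_idx. The file name "docs.txt" is hypothetical; Whoosh scores
# with BM25F by default.
def _example_bm25_query():
    from whoosh.qparser import QueryParser
    ix = recreate_bm25_idx("docs.txt", bm25_field="search", auto_create_bm25_idx=True)
    with ix.searcher() as searcher:
        query = QueryParser("content", ix.schema).parse("castor beans")
        for hit in searcher.search(query, limit=5):
            print(hit["id"], hit.score)
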
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM

# All three models are loaded in fp16 onto a single CUDA device.
# Safety model: a fine-tuned T5-Base that produces a "rule of thumb" style
# safety statement for a prompt, used below to prefix risky answers.
safety_tokenizer = AutoTokenizer.from_pretrained("salexashenko/T5-Base-ROT-epoch-2-train-loss-1.3495-val-loss-1.4164", use_auth_token=True)
safety_model = AutoModelForSeq2SeqLM.from_pretrained("salexashenko/T5-Base-ROT-epoch-2-train-loss-1.3495-val-loss-1.4164", use_auth_token=True).half().cuda().eval()

# Main chat model: Galactica 1.3B fine-tuned for conversation.
blackcat_tokenizer = AutoTokenizer.from_pretrained("theblackcat102/galactica-1.3b-conversation-finetuned")
blackcat_model = AutoModelForCausalLM.from_pretrained("theblackcat102/galactica-1.3b-conversation-finetuned").half().cuda().eval()

# t5-small summarizes web search snippets into background text.
t5_tokenizer = AutoTokenizer.from_pretrained("t5-small")
t5_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", torch_dtype=torch.half).eval().cuda()

def run_model(input_string, model, tokenizer, device='cuda', **generator_args):
    with torch.no_grad():
        input_ids = tokenizer(input_string, padding=True, return_tensors="pt")
        input_ids = input_ids.to(device)
        # BatchEncoding behaves like a dict, so generation kwargs can be folded
        # into it and the whole thing unpacked into generate().
        input_ids['no_repeat_ngram_size'] = 4
        for key, val in generator_args.items():
            input_ids[key] = val
        res = model.generate(**input_ids)
        # Collapse doubled punctuation artifacts in the decoded text.
        return [ret.replace("..", ".").replace(".-", ".").replace("..", ".").replace("--", "-").replace("--", "-")
                for ret in tokenizer.batch_decode(res, skip_special_tokens=True)]

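# Illustrative sketch: run_model with the T5 pair loaded above; "summarize:"
# is the standard t5-small task prefix. Any generate() kwarg can be passed
# through generator_args, as done with max_length here.
def _example_summarize():
    return run_model("summarize: " + "Whoosh is a pure-Python search library. " * 5,
                     t5_model, t5_tokenizer, max_length=32)
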
def run_python_and_return(s):
    # WARNING: exec() of model-generated code is unsandboxed; only acceptable
    # in a trusted demo environment.
    try:
        ret = {'__ret': None}
        exec(s, ret)
        return ret['__ret']
    except Exception:
        return ''

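# Illustrative sketch: the executed snippet is expected to assign its result
# to __ret, which is exactly what the <work> rewrite in generate_with_safety
# below produces.
def _example_exec():
    return run_python_and_return("__ret = 2 + 2")  # -> 4
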
import wikipedia
import spacy
from wikipedia import DisambiguationError
from duckduckgo_search import ddg
from collections import Counter

nlp = spacy.load('en_core_web_sm')

def duck_duck_and_wikipedia_search(query, num_terms=4, max_docs=10):
    ret = []
    # Search DuckDuckGo first, then expand to Wikipedia pages for the most
    # common named entities found in the snippets.
    data = ddg(
        query,
        region="us-en",
        safesearch="moderate",)
    data2 = [(a['title'] + ". " + a['body']).replace("?", ".").strip("?!.") for a in data]
    ret.append(data2)
    doc = nlp(" ".join(data2))
    query0 = [a[0].strip("!.,;") for a in Counter([e.text for e in doc.ents if e.label_ != 'CARDINAL']).most_common(num_terms)]
    print(query0)
    for query2 in query0:
        search = wikipedia.search(query2)
        for s in search[:max(1, int(max_docs / num_terms))]:
            try:
                page = wikipedia.WikipediaPage(s)
            except Exception:
                continue
            x = ["=" + x1 if "==" in x1 else x1 for x1 in page.content.split("\n=")]
            ret.append(x)
            if len(ret) > max_docs:
                return ret

    return ret

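# Illustrative sketch (network access required; not used by the Gradio app below).
def _example_search():
    docs = duck_duck_and_wikipedia_search("castor bean toxicity", num_terms=2, max_docs=4)
    # docs[0] is the list of DuckDuckGo snippets; later elements are
    # section-split Wikipedia pages for entities found in those snippets.
    return docs
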
def generate_with_safety(para, model, tokenizer, do_safety=True, do_execute_work=False, backtrack_on_mismatched_work_answers=False, return_answer_only=True, do_search=False, max_length=512, do_self_contrastive=True, contrastive_guidance_embedding=None, max_return_sequences=4, ret=None, do_sample=True, do_beam=False, device="cuda", target_lang=None):
    global safety_model, safety_tokenizer, t5_model, t5_tokenizer
    if backtrack_on_mismatched_work_answers: do_execute_work = True  # TODO: the backtracking inference
    background = ""
    para = para.strip()

    if do_search:
        data = ddg(
            para,
            region="us-en",
            safesearch="moderate",)
        data2 = [a['body'].replace("?", ".").strip("?!., ") for a in data]
        # There is a Google paper suggesting that summaries of the search
        # results work better than the raw results; need to find that paper.
        # Also need a simple ngram filter to drop bad summaries and fall back
        # to the raw search results.
        # TODO: store the reference URL so generated text can cite it; use
        # ngram overlap (ROUGE score).
        background = ". ".join([s.replace("?", ".").lstrip(" ?,!.").rstrip(" ,") for s in run_model(data2[:5], t5_model, t5_tokenizer, max_length=512)])
    # TODO: inject background knowledge into the instruction,
    # e.g. "give me instructions on how to eat castor beans".
    background_lower = background.lower()
    is_wrong = is_dangerous = False
    # TODO: replace with a multi-task classifier using the safety pipeline.
    if "immoral" in background_lower or "illegal" in background_lower:
        if "not immoral" not in background_lower and "not illegal" not in background_lower: is_wrong = True
    if "lethal" in background_lower or "dangerous" in background_lower or " poison" in background_lower:
        if "not lethal" not in background_lower and "not dangerous" not in background_lower and "not poison" not in background_lower: is_dangerous = True
    # print(is_wrong, is_dangerous)
    safety_prefix = ""
    if do_safety:
        para2 = para.strip(".?:-")
        if is_dangerous:
            para2 += " which is dangerous"
        elif is_wrong:
            para2 += " which is wrong"
        safety_prefix = run_model(para2, safety_model, safety_tokenizer)[0].strip('"\' ')
        if "wrong" in safety_prefix or "not right" in safety_prefix:
            safety_prefix = f"As a chatbot, I cannot recommend this. {safety_prefix}"
    if background:
        # Could probably do a rankgen match instead of keyword tests on "who", "what", "where", etc.
        if para.split()[0].lower() not in {"who", "what", "when", "where", "how", "why", "does", "do", "can", "could", "would", "is", "are", "will", "might", "find", "write", "give"} and not para.endswith("?"):
            para = f"Background: {background}. <question> Complete this sentence: {para} <answer> "
        else:
            para = f"Background: {background}. <question> {para} <answer> "
    if safety_prefix:
        if "<answer>" not in para:
            para += "<answer> " + safety_prefix + " "
        else:
            para += safety_prefix + " "
    # Track the visible prompt length so it can be stripped from decoded text.
    len_para = len(para)
    if "<question>" in para:
        len_para -= len("<question>")
    if "<answer>" in para:
        len_para -= len("<answer>")
    if safety_prefix:  # was `if safety_model:`, which subtracted even when no prefix was added
        len_para -= len(safety_prefix + " ")
    if "<answer>" not in para:
        para += "<answer>"
    print(para)
    input_ids = tokenizer.encode(para, return_tensors='pt')
    input_ids = input_ids.to(device)
    if ret is None: ret = {}
    with torch.no_grad():
        if do_sample:
            # Top-p / top-k random sampling: more diverse generations, but of lower quality.
            # Note: penalty_alpha targets contrastive search and, depending on the
            # transformers version, may be ignored when do_sample=True.
            outputs = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                no_repeat_ngram_size=4,
                do_sample=True,
                top_p=0.95,
                penalty_alpha=0.6 if do_self_contrastive else None,
                top_k=10,
                num_return_sequences=max(1, int(max_return_sequences / 2)) if do_beam else max_return_sequences
            )

            for i in range(len(outputs)):  # could use batch_decode, unless we want to do something special here
                query = tokenizer.decode(outputs[i], skip_special_tokens=True)
                if return_answer_only:
                    query = query[len_para:].lstrip(".? \n\t")
                ret[query] = 1  # dict used as an ordered set

        if do_beam:
            # Beam search: better-quality generations, but with less diversity.
            outputs = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                num_beams=max(int(max_return_sequences / 2) if do_sample else max_return_sequences, 5),
                no_repeat_ngram_size=4,
                penalty_alpha=0.6 if do_self_contrastive else None,
                num_return_sequences=max(1, int(max_return_sequences / 2)) if do_sample else max_return_sequences,
                early_stopping=True
            )

            for i in range(len(outputs)):
                query = tokenizer.decode(outputs[i], skip_special_tokens=True)
                if return_answer_only:
                    query = query[len_para:].lstrip(".? \n\t")
                ret[query] = 1

    # Take care of the <work> tokens: execute the generated code.
    # TODO: backtrack when the code doesn't return the same answer as the
    # answer in the generated text.
    if do_execute_work:  # Galactica-specific
        for query in list(ret.keys()):
            if "<work>" in query:
                query2 = ""
                for query_split in query.split("<work>"):
                    if "```" in query_split:
                        # Rewrite the file-writing idiom into a __ret assignment
                        # so run_python_and_return can capture the result.
                        query_split = query_split.replace("""with open("output.txt", "w") as file:\n file.write""", "__ret=")
                        code = query_split.split("</work>")[0].split("```")[1].split("```")[0]
                        query_split1, query_split2 = query_split.split("""<<read: "output.txt">>\n\n""")
                        old_answer2 = old_answer = query_split2.split("\n")[0]
                        work_answer = run_python_and_return(code)
                        if work_answer is not None:
                            try:
                                old_answer2 = float(old_answer)
                                work_answer = float(work_answer)
                            except Exception:
                                pass
                            if old_answer2 != work_answer:
                                query_split2 = query_split2.replace(old_answer, str(work_answer))
                            query_split = query_split1 + "Computed Answer:" + query_split2
                    if query2:
                        query2 = query2 + "<work>" + query_split
                    else:
                        query2 = query_split
                if query2 != query:
                    del ret[query]
                    ret[query2] = 1

    return list(ret.keys())

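# Illustrative constant (hypothetical output, not produced by the app) showing
# the Galactica <work> block shape that do_execute_work rewrites above: the
# code inside the ``` fences is executed, and its result is checked against
# the answer following the <<read: "output.txt">> marker.
_EXAMPLE_WORK_OUTPUT = (
    'The total is computed below. <work>```'
    'with open("output.txt", "w") as file:\n file.write(str(2 + 2))'
    '```</work><<read: "output.txt">>\n\n4'
)
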
import gradio as gr


def query_model(do_safety, do_search, text):
    results = generate_with_safety(text, blackcat_model, blackcat_tokenizer, do_safety=do_safety, do_search=do_search)
    # The interface declares four text outputs, so pad/truncate to exactly four.
    return (results + ["", "", "", ""])[:4]


demo = gr.Interface(
    query_model,
    [
        gr.Checkbox(label="Safety"),
        gr.Checkbox(label="Search"),
        gr.Textbox(
            label="Prompt",
            lines=5,
            value="Teach me how to take over the world.",
        ),
    ],
    ["text", "text", "text", "text"],
)

if __name__ == "__main__":
    demo.launch(auth=('user', 'supersecurepassword'), auth_message="Enter your username and password", share=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
transformers
wikipedia
duckduckgo_search
whoosh
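# Also imported by app.py, so likely needed here as well:
nltk
spacy
torch
gradio
tqdm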