Spaces: Runtime error

salexashenko committed · commit 3a555af
Parent(s): 7cd979d
init

Files changed:
- app.py +290 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,290 @@
import nltk

from spacy.cli import download
download("en_core_web_sm")
nltk.download('stopwords')

from nltk.corpus import stopwords

en_stopwords = set(list(stopwords.words('english')) +
['summary', 'synopsis', 'overview', 'list', 'good', 'will', 'why', 'talk', 'long', 'above', 'looks', 'face', 'men', 'years', 'can', 'both', 'have', 'keep', 'yeah', 'said', 'bring', 'done', 'was', 'when', 'ask', 'now', 'very', 'kind', 'they', 'told', 'tell', 'ever', 'kill', 'hold', 'that', 'below', 'bit', 'knew', 'haven', 'few', 'place', 'could', 'says', 'huh', 'job', 'also', 'ain', 'may', 'heart', 'boy', 'with', 'over', 'son', 'else', 'found', 'see', 'any', 'phone', 'hasn', 'saw', 'these', 'maybe', 'into', 'thing', 'mom', 'god', 'old', 'aren', 'mustn', 'out', 'about', 'guy', 'each', 'most', 'like', 'then', 'wasn', 'being', 'all', 'door', 'look', 'run', 'sorry', 'again', 'won', 'man', 'gone', 'them', 'ago', 'doesn', 'gonna', 'girl', 'feel', 'work', 'much', 'hope', 'never', 'woman', 'went', 'lot', 'what', 'start', 'only', 'play', 'too', 'dad', 'going', 'yours', 'wrong', 'fine', 'made', 'one', 'want', 'isn', 'our', 'true', 'room', 'wanna', 'are', 'idea', 'sure', 'find', 'same', 'doing', 'off', 'put', 'turn', 'come', 'house', 'think', 'meet', 'hers', 'gotta', 'nor', 'away', 'leave', 'car', 'used', 'happy', 'the', 'care', 'seen', 'she', 'not', 'were', 'ours', 'their', 'first', 'world', 'lost', 'make', 'big', 'left', 'miss', 'shan', 'did', 'thank', 'ready', 'those', 'give', 'next', 'came', 'who', 'mind', 'does', 'right', 'her', 'let', 'didn', 'open', 'has', 'show', 'wife', 'yet', 'got', 'know', 'whole', 'some', 'such', 'alone', 'baby', 'him', 'nice', 'bad', 'move', 'new', 'dead', 'three', 'weren', 'whom', 'well', 'get', 'which', 'end', 'you', 'than', 'while', 'last', 'once', 'sir', 'from', 'need', 'wait', 'days', 'how', 'don', 'heard', 'own', 'hear', 'where', 'hey', 'okay', 'just', 'until', 'your', 'there', 'this', 'more', 'been', 'his', 'under', 'mean', 'might', 'here', 'its', 'but', 'stay', 'yes', 'guess', 'even', 'guys', 'hard', 'hadn', 'live', 'stop', 'took', 'still', 'other', 'since', 'every', 'needn', 'way', 'name', 'two', 'back', 'and', 'hello', 'head', 'use', 'must', 'for', 'life', 'die', 'day', 'down', 'wants', 'after', 'say', 'try', 'had', 'night']
)

import multiprocessing
import os
from whoosh.analysis import StemmingAnalyzer
from whoosh.index import create_in
from whoosh.fields import *
import whoosh.index as whoosh_index
import tqdm

def get_content_ext(content, bm25_field):
    return content

def yield_line_by_line(file):
    with open(file) as input:
        for l in input:
            yield l

def recreate_bm25_idx(content_data_store, bm25_field="search", idx_dir=".", auto_create_bm25_idx=False, idxs=None, use_tqdm=True):
    if type(content_data_store) is str:
        content_data_store = yield_line_by_line(content_data_store)
    schema = Schema(id=ID(stored=True), content=TEXT(analyzer=StemmingAnalyzer()))
    #TODO: determine how to clear out the whoosh index besides rm -rf _M* MAIN*
    os.system(f"mkdir -p {idx_dir}/bm25_{bm25_field}")
    need_reindex = auto_create_bm25_idx or not os.path.exists(f"{idx_dir}/bm25_{bm25_field}/_MAIN_1.toc") #CHECK IF THIS IS RIGHT
    if not need_reindex:
        whoosh_ix = whoosh_index.open_dir(f"{idx_dir}/bm25_{bm25_field}")
    else:
        whoosh_ix = create_in(f"{idx_dir}/bm25_{bm25_field}", schema)
        writer = whoosh_ix.writer(multisegment=True, limitmb=1024, procs=multiprocessing.cpu_count())
        #writer = self.whoosh_ix.writer(multisegment=True, procs=multiprocessing.cpu_count())
        if hasattr(content_data_store, 'tell'):
            pos = content_data_store.tell()
            content_data_store.seek(0, 0)
        if idxs is not None:
            idx_text_pairs = [(idx, content_data_store[idx]) for idx in idxs]
            if use_tqdm:
                data_iterator = tqdm.tqdm(idx_text_pairs)
            else:
                data_iterator = idx_text_pairs
        else:
            if use_tqdm:
                data_iterator = tqdm.tqdm(enumerate(content_data_store))
            else:
                data_iterator = enumerate(content_data_store)
        # TODO:
        #self.indexer.reset_bm25_idx(0)
        #data_iterator = self.indexer.process_bm25_field(content_data_store, **kwargs)
        for idx, content in data_iterator:
            content = get_content_ext(content, bm25_field)
            if not content: continue
            writer.add_document(id=str(idx), content=content)
        writer.commit()
    return whoosh_ix  # was "return whoosh_index", which returned the imported module instead of the index

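# Usage sketch (illustrative, not part of the original app): query the index returned above with
# Whoosh's QueryParser; "docs.txt" is a hypothetical one-document-per-line file.
#
#   from whoosh.qparser import QueryParser
#   ix = recreate_bm25_idx("docs.txt", bm25_field="search", idx_dir=".", auto_create_bm25_idx=True)
#   with ix.searcher() as searcher:
#       hits = searcher.search(QueryParser("content", ix.schema).parse("castor beans"), limit=5)
#       print([hit["id"] for hit in hits])
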
import torch
from torch import nn
from transformers import (AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer,
                          OPTForCausalLM, T5Tokenizer, T5PreTrainedModel, T5EncoderModel)

# Safety model (T5-Base-ROT): produces the safety prefix text used in generate_with_safety below.
safety_tokenizer = tokenizer = AutoTokenizer.from_pretrained("salexashenko/T5-Base-ROT-epoch-2-train-loss-1.3495-val-loss-1.4164", use_auth_token=True)
safety_model = model = AutoModelForSeq2SeqLM.from_pretrained("salexashenko/T5-Base-ROT-epoch-2-train-loss-1.3495-val-loss-1.4164", use_auth_token=True).half().cuda().eval()

# Main chat model: a conversation-finetuned Galactica 1.3B checkpoint.
blackcat_tokenizer = AutoTokenizer.from_pretrained("theblackcat102/galactica-1.3b-conversation-finetuned")
blackcat_model = AutoModelForCausalLM.from_pretrained("theblackcat102/galactica-1.3b-conversation-finetuned").half().cuda().eval()

# t5-small condenses web-search snippets into the background passage. Note torch is now imported
# above this line; in the original it was only imported after torch.half was already used here.
t5_tokenizer = AutoTokenizer.from_pretrained("t5-small")
t5_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", torch_dtype=torch.half).half().eval().cuda()
def run_model(input_string, model, tokenizer, device='cuda', **generator_args):
    # Tokenize, forward any extra generation kwargs to model.generate, and lightly clean up the decoded text.
    with torch.no_grad():
        input_ids = tokenizer(input_string, padding=True, return_tensors="pt")
        input_ids = input_ids.to(device)
        input_ids['no_repeat_ngram_size'] = 4
        for key, val in generator_args.items():
            input_ids[key] = val
        res = model.generate(**input_ids)
        return [ret.replace("..", ".").replace(".-", ".").replace("..", ".").replace("--", "-").replace("--", "-") for ret in tokenizer.batch_decode(res, skip_special_tokens=True)]


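# Usage sketch (illustrative, not part of the original app): run_model takes a string or a list of
# strings and passes any extra keyword arguments through to model.generate, e.g.
#   summaries = run_model(["first search snippet ...", "second search snippet ..."],
#                         t5_model, t5_tokenizer, max_length=64)
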
def run_python_and_return(s):
    # Execute generated Python in a scratch namespace; the snippet is expected to assign its
    # result to __ret. Any failure falls through to an empty string.
    try:
        ret = {'__ret': None}
        exec(s, ret)
        return ret['__ret']
    except:
        return ''


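# Example (illustrative): once generate_with_safety (below) rewrites a Galactica <work> block so
# the file write becomes an assignment, run_python_and_return("__ret = 2 * 21") returns 42, while
# malformed code simply yields '' because of the bare except.
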
import wikipedia, spacy
from wikipedia import DisambiguationError
from duckduckgo_search import ddg
from collections import Counter

nlp = spacy.load('en_core_web_sm')

def duck_duck_and_wikipedia_search(query, num_terms=4, max_docs=10):
    ret = []
    # Using DuckDuckGo search for the initial snippets.
    data = ddg(
        query,
        region="us-en",
        safesearch="moderate",)
    data2 = [(a['title'] + ". " + a['body']).replace("?", ".").strip("?!.") for a in data]
    ret.append(data2)
    # Use the most common named entities in the snippets as follow-up Wikipedia queries.
    doc = nlp(" ".join(data2))
    query0 = [a[0].strip("!.,;") for a in Counter([e.text for e in doc.ents if e.label_ != 'CARDINAL']).most_common(num_terms)]
    print(query0)
    for query2 in query0:
        search = wikipedia.search(query2)
        for s in search[:max(1, int(max_docs/num_terms))]:
            try:
                page = wikipedia.WikipediaPage(s)
            except:
                continue
            x = ["="+x1 if "==" in x1 else x1 for x1 in page.content.split("\n=")]
            ret.append(x)
            if len(ret) > max_docs: return ret

    return ret


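# Usage sketch (illustrative; this helper is defined but not wired into the Gradio demo below):
#   docs = duck_duck_and_wikipedia_search("castor bean toxicity", num_terms=4, max_docs=10)
#   # docs[0] holds the DuckDuckGo snippets; subsequent entries are lists of Wikipedia page sections.
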
def generate_with_safety(para, model, tokenizer, do_safety=True, do_execute_work=False, backtrack_on_mismatched_work_answers=False, return_answer_only=True, do_search=False, max_length=512, do_self_contrastive=True, contrative_guidance_embedding=None, max_return_sequences=4, ret=None, do_sample=True, do_beam=False, device="cuda", target_lang=None):
    global safety_model, safety_tokenizer, t5_model, t5_tokenizer
    if backtrack_on_mismatched_work_answers: do_execute_work = True #TODO: the backtracking inference
    background = ""
    para = para.strip()

    if do_search:
        data = ddg(
            para,
            region="us-en",
            safesearch="moderate",)
        data2 = [a['body'].replace("?", ".").strip("?!., ") for a in data]
        # There is a Google paper saying that a summary of the search results works better; need to find that paper.
        # Also need a simple ngram filter to discard bad summaries and fall back to the raw search results.
        # TODO: store the reference URL so generated text can refer back to it; use ngram overlap (ROUGE score).
        background = ". ".join([s.replace("?", ".").lstrip(" ?,!.").rstrip(" ,") for s in run_model(data2[:5], t5_model, t5_tokenizer, max_length=512)])
    #TODO: inject background knowledge into the instruction,
    # e.g. "give me instructions on how to eat castor beans"
    background_lower = background.lower()
    is_wrong = is_dangerous = False
    # Replace with a multi-task classifier using the safety pipeline.
    if "immoral" in background_lower or "illegal" in background_lower:
        if "not immoral" not in background_lower and "not illegal" not in background_lower: is_wrong = True
    if "lethal" in background_lower or "dangerous" in background_lower or " poison" in background_lower:
        if "not lethal" not in background_lower and "not dangerous" not in background_lower and "not poison" not in background_lower: is_dangerous = True
    #print (is_wrong, is_dangerous)
    safety_prefix = ""
    if do_safety:
        para2 = para.strip(".?:-")
        if is_dangerous:
            para2 += " which is dangerous"
        elif is_wrong:
            para2 += " which is wrong"
        safety_prefix = run_model(para2, safety_model, safety_tokenizer)[0].strip('"\' ')
        if "wrong" in safety_prefix or "not right" in safety_prefix:
            safety_prefix = f"As a chatbot, I cannot recommend this. {safety_prefix}"
    if background:
        # Could probably do a RankGen match instead of keyword matching on "who", "what", "where", etc.
        if para.split()[0].lower() not in {"who", "what", "when", "where", "how", "why", "does", "do", "can", "could", "would", "is", "are", "will", "might", "find", "write", "give"} and not para.endswith("?"):
            para = f"Background: {background}. <question> Complete this sentence: {para} <answer> "
        else:
            para = f"Background: {background}. <question> {para} <answer> "
    if safety_prefix:
        if "<answer>" not in para:
            para += "<answer> " + safety_prefix + " "
        else:
            para += safety_prefix + " "
    # Track how much of the prompt to strip from the decoded output when return_answer_only is set.
    len_para = len(para)
    if "<question>" in para:
        len_para -= len("<question>")
    if "<answer>" in para:
        len_para -= len("<answer>")
    if safety_prefix:  # was "if safety_model:", which is always true even when no prefix was appended
        len_para -= len(safety_prefix + " ")
    if "<answer>" not in para:
        para += "<answer>"
    print(para)
    input_ids = tokenizer.encode(para, return_tensors='pt')
    input_ids = input_ids.to(device)
    if ret is None: ret = {}
    with torch.no_grad():
        if do_sample:
            # Top-p / top-k random sampling: more diverse completions, but of lower quality.
            outputs = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                no_repeat_ngram_size=4,
                do_sample=True,
                top_p=0.95,
                penalty_alpha=0.6 if do_self_contrastive else None,
                top_k=10,
                num_return_sequences=max(1, int(max_return_sequences/2)) if do_beam else max_return_sequences
            )

            for i in range(len(outputs)):  # could use batch_decode, unless we want to do something special here
                query = tokenizer.decode(outputs[i], skip_special_tokens=True)
                if return_answer_only:
                    query = query[len_para:].lstrip(".? \n\t")
                ret[query] = 1

        if do_beam:
            # Beam search: better quality completions, but with less diversity.
            outputs = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                num_beams=max(int(max_return_sequences/2) if do_sample else max_return_sequences, 5),
                no_repeat_ngram_size=4,
                penalty_alpha=0.6 if do_self_contrastive else None,
                num_return_sequences=max(1, int(max_return_sequences/2)) if do_sample else max_return_sequences,
                early_stopping=True
            )

            for i in range(len(outputs)):  # could use batch_decode, unless we want to do something special here
                query = tokenizer.decode(outputs[i], skip_special_tokens=True)
                if return_answer_only:
                    query = query[len_para:].lstrip(".? \n\t")
                ret[query] = 1

    # Take care of the <work> tokens: execute the generated code.
    #TODO: do backtracking when the code doesn't return the same answer as the answer in the generated text.
    if do_execute_work:  # Galactica specific
        for query in list(ret.keys()):
            if "<work>" in query:
                query2 = ""
                for query_split in query.split("<work>"):
                    if "```" in query_split:
                        query_split = query_split.replace("""with open("output.txt", "w") as file:\n file.write""", "__ret=")
                        code = query_split.split("</work>")[0].split("```")[1].split("```")[0]
                        query_split1, query_split2 = query_split.split("""<<read: "output.txt">>\n\n""")
                        old_answer2 = old_answer = query_split.split("""<<read: "output.txt">>\n\n""")[1].split("\n")[0]
                        work_answer = run_python_and_return(code)
                        if work_answer is not None:
                            try:
                                float(old_answer)
                                old_answer2 = float(old_answer)
                                work_answer = float(work_answer)
                            except:
                                pass
                            if old_answer2 != work_answer:
                                query_split2 = query_split2.replace(old_answer, str(work_answer))  # str(): work_answer may be a float
                                query_split = query_split1 + "Computed Answer:" + query_split2
                    if query2:
                        query2 = query2 + "<work>" + query_split
                    else:
                        query2 = query_split
                if query2 != query:
                    del ret[query]
                    ret[query2] = 1

    return list(ret.keys())

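# Illustrative call (not part of the original app): beam search plus <work> execution, which only
# makes sense for the Galactica checkpoint loaded above:
#   answers = generate_with_safety("What is 12 * 34?", blackcat_model, blackcat_tokenizer,
#                                  do_safety=False, do_sample=False, do_beam=True, do_execute_work=True)
#   print(answers[0])
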
import gradio as gr


def query_model(do_safety, do_search, text):
    results = generate_with_safety(text, blackcat_model, blackcat_tokenizer, do_safety=do_safety, do_search=do_search)
    # The interface declares four text outputs, but generate_with_safety may return fewer
    # candidates (duplicates are collapsed), so pad/truncate to exactly four values.
    return (results + ["", "", "", ""])[:4]


demo = gr.Interface(
    query_model,
    [
        gr.Checkbox(label="Safety"),
        gr.Checkbox(label="Search"),
        gr.Textbox(
            label="Prompt",
            lines=5,
            value="Teach me how to take over the world.",
        ),
    ],
    ["text", "text", "text", "text"]
)

if __name__ == "__main__":
    demo.launch(auth=('user', 'supersecurepassword'), auth_message="Enter your username and password", share=True)
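# Local run sketch (assumption, not part of the original commit): with a CUDA GPU, access to the
# gated safety checkpoint, and the packages from requirements.txt plus torch, nltk, spacy, gradio,
# and tqdm (which app.py also imports), `python app.py` starts the interface and prompts for the
# username/password configured above.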
requirements.txt
ADDED
@@ -0,0 +1,4 @@
transformers
wikipedia
duckduckgo_search
whoosh