import nltk
import torch
from spacy.cli import download

download("en_core_web_sm")
nltk.download("stopwords")

from nltk.corpus import stopwords

# NLTK's English stopwords, extended with conversational filler words
# (dialogue verbs, kinship terms, interjections, etc.) that are noise for retrieval.
en_stopwords = set(
    list(stopwords.words("english"))
    + [
        "summary", "synopsis", "overview", "list", "good", "will", "why", "talk", "long", "above",
        "looks", "face", "men", "years", "can", "both", "have", "keep", "yeah", "said",
        "bring", "done", "was", "when", "ask", "now", "very", "kind", "they", "told",
        "tell", "ever", "kill", "hold", "that", "below", "bit", "knew", "haven", "few",
        "place", "could", "says", "huh", "job", "also", "ain", "may", "heart", "boy",
        "with", "over", "son", "else", "found", "see", "any", "phone", "hasn", "saw",
        "these", "maybe", "into", "thing", "mom", "god", "old", "aren", "mustn", "out",
        "about", "guy", "each", "most", "like", "then", "wasn", "being", "all", "door",
        "look", "run", "sorry", "again", "won", "man", "gone", "them", "ago", "doesn",
        "gonna", "girl", "feel", "work", "much", "hope", "never", "woman", "went", "lot",
        "what", "start", "only", "play", "too", "dad", "going", "yours", "wrong", "fine",
        "made", "one", "want", "isn", "our", "true", "room", "wanna", "are", "idea",
        "sure", "find", "same", "doing", "off", "put", "turn", "come", "house", "think",
        "meet", "hers", "gotta", "nor", "away", "leave", "car", "used", "happy", "the",
        "care", "seen", "she", "not", "were", "ours", "their", "first", "world", "lost",
        "make", "big", "left", "miss", "shan", "did", "thank", "ready", "those", "give",
        "next", "came", "who", "mind", "does", "right", "her", "let", "didn", "open",
        "has", "show", "wife", "yet", "got", "know", "whole", "some", "such", "alone",
        "baby", "him", "nice", "bad", "move", "new", "dead", "three", "weren", "whom",
        "well", "get", "which", "end", "you", "than", "while", "last", "once", "sir",
        "from", "need", "wait", "days", "how", "don", "heard", "own", "hear", "where",
        "hey", "okay", "just", "until", "your", "there", "this", "more", "been", "his",
        "under", "mean", "might", "here", "its", "but", "stay", "yes", "guess", "even",
        "guys", "hard", "hadn", "live", "stop", "took", "still", "other", "since", "every",
        "needn", "way", "name", "two", "back", "and", "hello", "head", "use", "must",
        "for", "life", "die", "day", "down", "wants", "after", "say", "try", "had", "night",
    ]
)

import multiprocessing
import os

HF_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")
PASSWORD = os.getenv("PASSWORD")

import tqdm
import whoosh.index as whoosh_index
from whoosh.analysis import StemmingAnalyzer
from whoosh.fields import ID, TEXT, Schema
from whoosh.index import create_in


def get_content_ext(content, bm25_field):
    # Hook for extracting the searchable field from a record; currently a pass-through.
    return content


def yield_line_by_line(file):
    with open(file) as input_file:
        for line in input_file:
            yield line


def recreate_bm25_idx(
    content_data_store,
    bm25_field="search",
    idx_dir=".",
    auto_create_bm25_idx=False,
    idxs=None,
    use_tqdm=True,
):
    # A string is treated as a path to a line-per-document file.
    if isinstance(content_data_store, str):
        content_data_store = yield_line_by_line(content_data_store)
    schema = Schema(id=ID(stored=True), content=TEXT(analyzer=StemmingAnalyzer()))
    # TODO: determine how to clear out the whoosh index besides rm -rf _M* MAIN*
    os.system(f"mkdir -p {idx_dir}/bm25_{bm25_field}")
    need_reindex = auto_create_bm25_idx or not os.path.exists(
        f"{idx_dir}/bm25_{bm25_field}/_MAIN_1.toc"
    )  # CHECK IF THIS IS RIGHT
    if not need_reindex:
        whoosh_ix = whoosh_index.open_dir(f"{idx_dir}/bm25_{bm25_field}")
    else:
        whoosh_ix = create_in(f"{idx_dir}/bm25_{bm25_field}", schema)
        writer = whoosh_ix.writer(
            multisegment=True, limitmb=1024, procs=multiprocessing.cpu_count()
        )
        # writer = whoosh_ix.writer(multisegment=True, procs=multiprocessing.cpu_count())
        # Rewind file-like stores so indexing starts from the beginning.
        if hasattr(content_data_store, "tell"):
            pos = content_data_store.tell()
            content_data_store.seek(0, 0)
        if idxs is not None:
            idx_text_pairs = [(idx, content_data_store[idx]) for idx in idxs]
            if use_tqdm:
                data_iterator = tqdm.tqdm(idx_text_pairs)
            else:
                data_iterator = idx_text_pairs
        else:
            if use_tqdm:
                data_iterator = tqdm.tqdm(enumerate(content_data_store))
            else:
                data_iterator = enumerate(content_data_store)
        # TODO:
        # self.indexer.reset_bm25_idx(0)
        # data_iterator = self.indexer.process_bm25_field(content_data_store, **kwargs)
        for idx, content in data_iterator:
            content = get_content_ext(content, bm25_field)
            if not content:
                continue
            writer.add_document(id=str(idx), content=content)
        writer.commit()
    return whoosh_ix
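# Usage sketch (hedged): "docs.txt" and the query below are illustrative, not part
# of the pipeline above. Whoosh scores with BM25F by default, so a plain text query
# over the "content" field yields BM25-ranked hits keyed by the stored document id.
def _bm25_usage_example():
    from whoosh.qparser import QueryParser

    ix = recreate_bm25_idx("docs.txt", bm25_field="search", auto_create_bm25_idx=True)
    with ix.searcher() as searcher:
        parsed = QueryParser("content", ix.schema).parse("castor beans")
        for hit in searcher.search(parsed, limit=5):
            print(hit["id"], hit.score)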
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# T5 fine-tuned to generate rules of thumb (ROTs), used here as safety prefixes.
safety_tokenizer = tokenizer = AutoTokenizer.from_pretrained(
    "salexashenko/T5-Base-ROT-epoch-2-train-loss-1.3495-val-loss-1.4164",
    use_auth_token=HF_TOKEN,
)
safety_model = model = (
    AutoModelForSeq2SeqLM.from_pretrained(
        "salexashenko/T5-Base-ROT-epoch-2-train-loss-1.3495-val-loss-1.4164",
        use_auth_token=HF_TOKEN,
    )
    .half()
    .cuda()
    .eval()
)

from transformers import AutoModelForCausalLM, AutoTokenizer

# Conversation-finetuned Galactica 1.3B: the main chat/generation model.
blackcat_tokenizer = AutoTokenizer.from_pretrained(
    "theblackcat102/galactica-1.3b-conversation-finetuned"
)
blackcat_model = (
    AutoModelForCausalLM.from_pretrained(
        "theblackcat102/galactica-1.3b-conversation-finetuned"
    )
    .half()
    .cuda()
    .eval()
)

# Small T5 used to condense web-search snippets into background text.
t5_tokenizer = AutoTokenizer.from_pretrained("t5-small")
t5_model = (
    AutoModelForSeq2SeqLM.from_pretrained("t5-small", torch_dtype=torch.half)
    .half()
    .eval()
    .cuda()
)

from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    OPTForCausalLM,
    T5EncoderModel,
    T5PreTrainedModel,
    T5Tokenizer,
)


def run_model(input_string, model, tokenizer, device="cuda", **generator_args):
    with torch.no_grad():
        input_ids = tokenizer(input_string, padding=True, return_tensors="pt")
        input_ids = input_ids.to(device)
        input_ids["no_repeat_ngram_size"] = 4
        for key, val in generator_args.items():
            input_ids[key] = val
        res = model.generate(**input_ids)
        # Clean up doubled-punctuation artifacts from generation.
        return [
            ret.replace("..", ".")
            .replace(".-", ".")
            .replace("..", ".")
            .replace("--", "-")
            .replace("--", "-")
            for ret in tokenizer.batch_decode(res, skip_special_tokens=True)
        ]


def run_python_and_return(s):
    # Execute generated Python and return whatever it assigned to __ret.
    try:
        ret = {"__ret": None}
        exec(s, ret)
        return ret["__ret"]
    except Exception:
        return ""


from collections import Counter

import spacy
import wikipedia
from duckduckgo_search import ddg

nlp = spacy.load("en_core_web_sm")


def duck_duck_and_wikipedia_search(query, num_terms=4, max_docs=10):
    ret = []
    # DuckDuckGo search for the raw query.
    data = ddg(
        query,
        region="us-en",
        safesearch="moderate",
    )
    data2 = [
        (a["title"] + ". " + a["body"]).replace("?", ".").strip("?!.") for a in data
    ]
    ret.append(data2)
    # Pull the most frequent named entities out of the snippets and use them as
    # follow-up Wikipedia queries.
    doc = nlp(" ".join(data2))
    query0 = [
        a[0].strip("!.,;")
        for a in Counter(
            [e.text for e in doc.ents if e.label_ != "CARDINAL"]
        ).most_common(num_terms)
    ]
    print(query0)
    for query2 in query0:
        search = wikipedia.search(query2)
        for s in search[: max(1, int(max_docs / num_terms))]:
            try:
                page = wikipedia.WikipediaPage(s)
            except Exception:
                continue
            x = ["=" + x1 if "==" in x1 else x1 for x1 in page.content.split("\n=")]
            ret.append(x)
            if len(ret) > max_docs:
                return ret
    return ret
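# Usage sketch (hedged): chaining the two helpers above. The query is illustrative;
# docs[0] holds the raw DuckDuckGo snippets, which run_model condenses with the
# t5-small pair loaded earlier (any tokenizer/model pair with generate() works).
def _search_and_summarize_example():
    docs = duck_duck_and_wikipedia_search(
        "castor bean toxicity", num_terms=2, max_docs=4
    )
    summaries = run_model(docs[0][:3], t5_model, t5_tokenizer, max_length=64)
    for summary in summaries:
        print(summary)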
" + a["body"]).replace("?", ".").strip("?!.") for a in data ] ret.append(data2) doc = nlp(" ".join(data2)) query0 = [ a[0].strip("!.,;") for a in Counter( [e.text for e in doc.ents if e.label_ != "CARDINAL"] ).most_common(num_terms) ] print(query0) for query2 in query0: search = wikipedia.search(query2) for s in search[: max(1, int(max_docs / num_terms))]: try: page = wikipedia.WikipediaPage(s) except: continue x = ["=" + x1 if "==" in x1 else x1 for x1 in page.content.split("\n=")] ret.append(x) if len(ret) > max_docs: return ret return ret def generate_with_safety( para, model, tokenizer, do_safety=True, do_execute_work=False, backtrack_on_mismatched_work_answers=False, return_answer_only=True, do_search=False, max_length=512, do_self_contrastive=True, contrative_guidance_embedding=None, max_return_sequences=4, ret=None, do_sample=True, do_beam=False, device="cuda", target_lang=None, ): global safety_model, safety_tokenizer, t5_model, t5_tokenizer if backtrack_on_mismatched_work_answers: do_execute_work = True # TODO the backtracking inference background = "" para = para.strip() if do_search: data = ddg( para, region="us-en", safesearch="moderate", ) data2 = [a["body"].replace("?", ".").strip("?!., ") for a in data] # there is a google paper that says using the summary of the search results is better. Need to look for that paper. # also need a simple ngram filter to get rid of bad summaries and use the actual search results as a backup # TODO: store reference URL so we can refer back to the URL in generated text. use ngram overlap (Roge score) background = ". ".join( [ s.replace("?", ".").lstrip(" ?,!.").rstrip(" ,") for s in run_model(data2[:5], t5_model, t5_tokenizer, max_length=512) ] ) # TODO: inject background knowledge into the instruciton. # give me instructions on how to eat castor beans background_lower = background.lower() is_wrong = is_dangerous = False # replace with a multi task classifier using the safety pipeline if "immoral" in background_lower or "illegal" in background_lower: if ( "not immoral" not in background_lower and "not illegal" not in background_lower ): is_wrong = True if ( "lethal" in background_lower or "dangerous" in background_lower or " poison" in background_lower ): if ( "not lethal" not in background_lower and "not dangerous" not in background_lower and "not poison" not in background_lower ): is_dangerous = True # print (is_wrong, is_dangerous) safety_prefix = "" if do_safety: para2 = para.strip(".?:-") if is_dangerous: para2 += " which is dangerous" elif is_wrong: para2 += " which is wrong" safety_prefix = run_model(para2, safety_model, safety_tokenizer)[0].strip( "\"' " ) if "wrong" in safety_prefix or "not right" in safety_prefix: safety_prefix = f"As a chatbot, I cannot recommend this. {safety_prefix}" if background: # probably can do a rankgen match instead of keyword on "who", "what", "where", etc. if para.split()[0].lower() not in { "who", "what", "when", "where", "how", "why", "does", "do", "can", "could", "would", "is", "are", "will", "might", "find", "write", "give", } and not para.endswith("?"): para = f"Background: {background}. Complete this sentence: {para} " else: para = f"Background: {background}. 
{para} " if safety_prefix: if "" not in para: para += " " + safety_prefix + " " else: para += safety_prefix + " " len_para = len(para) if "" in para: len_para -= len("") if "" in para: len_para -= len("") if safety_model: len_para -= len(safety_prefix + " ") if "" not in para: para += "" print(para) input_ids = tokenizer.encode(para, return_tensors="pt") input_ids = input_ids.to(device) if ret is None: ret = {} with torch.no_grad(): if do_sample: # Here we use top_k / top_k random sampling. It generates more diverse queries, but of lower quality outputs = model.generate( input_ids=input_ids, max_length=max_length, no_repeat_ngram_size=4, do_sample=True, top_p=0.95, penalty_alpha=0.6 if do_self_contrastive else None, top_k=10, num_return_sequences=max(1, int(max_return_sequences / 2)) if do_beam else max_return_sequences, ) for i in range( len(outputs) ): # can use batch_decode, unless we want to do something special here query = tokenizer.decode(outputs[i], skip_special_tokens=True) if return_answer_only: query = query[len_para:].lstrip(".? \n\t") ret[query] = 1 if do_beam: # Here we use Beam-search. It generates better quality queries, but with less diversity outputs = model.generate( input_ids=input_ids, max_length=max_length, num_beams=max( int(max_return_sequences / 2) if do_sample else max_return_sequences, 5, ), no_repeat_ngram_size=4, penalty_alpha=0.6 if do_self_contrastive else None, num_return_sequences=max(1, int(max_return_sequences / 2)) if do_sample else max_return_sequences, early_stopping=True, ) for i in range( len(outputs) ): # can use batch_decode, unless we want to do something special here query = tokenizer.decode(outputs[i], skip_special_tokens=True) if return_answer_only: query = query[len_para:].lstrip(".? \n\t") ret[query] = 1 # take care of the tokens - let's execute the code # TODO: do backtracking when code doesn't return the same answer as the answer in the generated text. if do_execute_work: # galactica specific for query in list(ret.keys()): if "" in query: query2 = "" for query_split in query.split(""): if "```" in query_split: query_split = query_split.replace( """with open("output.txt", "w") as file:\n file.write""", "__ret=", ) code = ( query_split.split("")[0] .split("```")[1] .split("```")[0] ) query_split1, query_split2 = query_split.split( """<>\n\n""" ) old_answer2 = old_answer = query_split.split( """<>\n\n""" )[1].split("\n")[0] work_answer = run_python_and_return(code) if work_answer is not None: try: float(old_answer) old_answer2 = float(old_answer) work_answer = float(work_answer) except: pass if old_answer2 != work_answer: query_split2 = query_split2.replace( old_answer, work_answer ) query_split = ( query_split1 + "Computed Answer:" + query_split2 ) if query2: query2 = query2 + "" + query_split else: query2 = query_split if query2 != query: del ret[query] ret[query2] = 1 return list(ret.keys()) import gradio as gr def query_model(do_safety, do_search, text, access_code): if access_code==PASSWORD: return generate_with_safety( text, blackcat_model, blackcat_tokenizer, do_safety=do_safety, do_search=do_search, ) else: raise Exception("Incorrect access code") demo = gr.Interface( query_model, [ gr.Checkbox(label="Safety"), gr.Checkbox(label="Search"), gr.Textbox( label="Prompt", lines=5, value="Teach me how to take over the world.", ), gr.Textbox(label="Access Code", lines=1, value="") ], ["text", "text", "text", "text"], ) if __name__ == "__main__": demo.launch()