# salexashenko's picture
# use dumb auth
# 04fdfa6
import nltk
import torch
from spacy.cli import download
download("en_core_web_sm")
nltk.download("stopwords")
from nltk.corpus import stopwords
# Stopword vocabulary: NLTK's English stopword list extended with the
# additional words enumerated below.
en_stopwords = set(stopwords.words("english"))
en_stopwords.update(
    (
        "summary", "synopsis", "overview", "list", "good", "will",
        "why", "talk", "long", "above", "looks", "face",
        "men", "years", "can", "both", "have", "keep",
        "yeah", "said", "bring", "done", "was", "when",
        "ask", "now", "very", "kind", "they", "told",
        "tell", "ever", "kill", "hold", "that", "below",
        "bit", "knew", "haven", "few", "place", "could",
        "says", "huh", "job", "also", "ain", "may",
        "heart", "boy", "with", "over", "son", "else",
        "found", "see", "any", "phone", "hasn", "saw",
        "these", "maybe", "into", "thing", "mom", "god",
        "old", "aren", "mustn", "out", "about", "guy",
        "each", "most", "like", "then", "wasn", "being",
        "all", "door", "look", "run", "sorry", "again",
        "won", "man", "gone", "them", "ago", "doesn",
        "gonna", "girl", "feel", "work", "much", "hope",
        "never", "woman", "went", "lot", "what", "start",
        "only", "play", "too", "dad", "going", "yours",
        "wrong", "fine", "made", "one", "want", "isn",
        "our", "true", "room", "wanna", "are", "idea",
        "sure", "find", "same", "doing", "off", "put",
        "turn", "come", "house", "think", "meet", "hers",
        "gotta", "nor", "away", "leave", "car", "used",
        "happy", "the", "care", "seen", "she", "not",
        "were", "ours", "their", "first", "world", "lost",
        "make", "big", "left", "miss", "shan", "did",
        "thank", "ready", "those", "give", "next", "came",
        "who", "mind", "does", "right", "her", "let",
        "didn", "open", "has", "show", "wife", "yet",
        "got", "know", "whole", "some", "such", "alone",
        "baby", "him", "nice", "bad", "move", "new",
        "dead", "three", "weren", "whom", "well", "get",
        "which", "end", "you", "than", "while", "last",
        "once", "sir", "from", "need", "wait", "days",
        "how", "don", "heard", "own", "hear", "where",
        "hey", "okay", "just", "until", "your", "there",
        "this", "more", "been", "his", "under", "mean",
        "might", "here", "its", "but", "stay", "yes",
        "guess", "even", "guys", "hard", "hadn", "live",
        "stop", "took", "still", "other", "since", "every",
        "needn", "way", "name", "two", "back", "and",
        "hello", "head", "use", "must", "for", "life",
        "die", "day", "down", "wants", "after", "say",
        "try", "had", "night",
    )
)
import multiprocessing
import os
HF_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")
PASSWORD = os.getenv("PASSWORD")
import tqdm
import whoosh.index as whoosh_index
from whoosh.analysis import StemmingAnalyzer
from whoosh.fields import *
from whoosh.index import create_in
def get_content_ext(content, bm25_field):
    """Return the text to index for ``content``.

    This is currently an identity hook: ``bm25_field`` is accepted but not
    used, and ``content`` is handed back unchanged.
    """
    return content
def yield_line_by_line(file):
    """Lazily yield each line of the text file at path ``file``.

    Lines are yielded exactly as read (trailing newline included); the file
    is closed once the generator is exhausted or discarded.
    """
    with open(file) as handle:
        yield from handle
def recreate_bm25_idx(
    content_data_store,
    bm25_field="search",
    idx_dir=".",
    auto_create_bm25_idx=False,
    idxs=None,
    use_tqdm=True,
):
    """Open, or (re)build, the Whoosh index stored at ``{idx_dir}/bm25_{bm25_field}``.

    Args:
        content_data_store: Either a path to a text file (indexed one line per
            document) or an iterable of documents. When ``idxs`` is given it
            must also support ``content_data_store[idx]``.
        bm25_field: Name used for the index sub-directory and passed to
            ``get_content_ext``.
        idx_dir: Parent directory of the index directory.
        auto_create_bm25_idx: Force a rebuild even if an index already exists.
        idxs: Optional iterable of document ids to index; when ``None`` every
            item of ``content_data_store`` is indexed under its enumeration
            position.
        use_tqdm: Wrap the document iterator in a ``tqdm`` progress bar.

    Returns:
        The Whoosh index object. (The original code returned the
        ``whoosh.index`` *module* here, which looks like a typo for the local
        index variable.)

    NOTE(review): the source this was recovered from had lost its indentation;
    documents are written only when a rebuild is needed, which is the reading
    implied by ``need_reindex`` -- confirm against the upstream file.
    """
    if isinstance(content_data_store, str):
        content_data_store = yield_line_by_line(content_data_store)
    schema = Schema(id=ID(stored=True), content=TEXT(analyzer=StemmingAnalyzer()))
    index_path = f"{idx_dir}/bm25_{bm25_field}"
    # TODO determine how to clear out the whoosh index besides rm -rf _M* MAIN*
    # os.makedirs replaces the previous shell-interpolated `mkdir -p` call.
    os.makedirs(index_path, exist_ok=True)
    need_reindex = auto_create_bm25_idx or not os.path.exists(
        f"{index_path}/_MAIN_1.toc"
    )  # CHECK IF THIS IS RIGHT
    if not need_reindex:
        return whoosh_index.open_dir(index_path)

    whoosh_ix = create_in(index_path, schema)
    writer = whoosh_ix.writer(
        multisegment=True, limitmb=1024, procs=multiprocessing.cpu_count()
    )

    # File-like stores are rewound before indexing and put back afterwards
    # (the original saved the position but never restored it).
    restore_pos = None
    if hasattr(content_data_store, "tell"):
        restore_pos = content_data_store.tell()
        content_data_store.seek(0, 0)

    if idxs is not None:
        data_iterator = [(idx, content_data_store[idx]) for idx in idxs]
    else:
        data_iterator = enumerate(content_data_store)
    if use_tqdm:
        data_iterator = tqdm.tqdm(data_iterator)

    # TODO:
    # self.indexer.reset_bm25_idx(0)
    # data_iterator = self.indexer.process_bm25_field(content_data_store, **kwargs)
    for idx, content in data_iterator:
        content = get_content_ext(content, bm25_field)
        if not content:
            continue  # nothing to index for this document
        writer.add_document(id=str(idx), content=content)
    writer.commit()

    if restore_pos is not None:
        content_data_store.seek(restore_pos, 0)
    return whoosh_ix
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Seq2seq checkpoint used as the safety model. `tokenizer` and `model` are
# bound as module-level aliases of the safety tokenizer/model.
safety_tokenizer = AutoTokenizer.from_pretrained(
    "salexashenko/T5-Base-ROT-epoch-2-train-loss-1.3495-val-loss-1.4164",
    use_auth_token=HF_TOKEN,
)
tokenizer = safety_tokenizer
safety_model = AutoModelForSeq2SeqLM.from_pretrained(
    "salexashenko/T5-Base-ROT-epoch-2-train-loss-1.3495-val-loss-1.4164",
    use_auth_token=HF_TOKEN,
)
# Half precision, on the GPU, in inference mode.
safety_model = safety_model.half().cuda().eval()
model = safety_model
from transformers import AutoModelForCausalLM, AutoTokenizer
# Causal LM that produces the chat answers (passed to generate_with_safety
# by query_model).
blackcat_tokenizer = AutoTokenizer.from_pretrained(
    "theblackcat102/galactica-1.3b-conversation-finetuned"
)
blackcat_model = AutoModelForCausalLM.from_pretrained(
    "theblackcat102/galactica-1.3b-conversation-finetuned"
)
# Half precision, on the GPU, in inference mode.
blackcat_model = blackcat_model.half().cuda().eval()
# Small T5 model; generate_with_safety runs search-result snippets through
# it to build the background text.
t5_tokenizer = AutoTokenizer.from_pretrained("t5-small")
t5_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", torch_dtype=torch.half)
# Half precision, inference mode, then moved to the GPU.
t5_model = t5_model.half().eval().cuda()
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
OPTForCausalLM,
T5EncoderModel,
T5PreTrainedModel,
T5Tokenizer,
)
def run_model(input_string, model, tokenizer, device="cuda", **generator_args):
    """Generate text for ``input_string`` and return the cleaned-up decodings.

    Args:
        input_string: Text (or list of texts) handed to ``tokenizer``.
        model: Object exposing ``generate(**kwargs)``.
        tokenizer: Callable tokenizer that also exposes ``batch_decode``.
        device: Device the encoded inputs are moved to.
        **generator_args: Extra keyword arguments for ``model.generate``;
            they override both the encoded inputs and the default
            ``no_repeat_ngram_size=4`` on key clashes.

    Returns:
        One decoded string per generated sequence, with doubled dots/dashes
        and ".-" collapsed.
    """
    # Applied in order; the repeated pairs give a second pass over runs that
    # the first pass only shortened.
    cleanup_steps = (("..", "."), (".-", "."), ("..", "."), ("--", "-"), ("--", "-"))

    with torch.no_grad():
        encoded = tokenizer(input_string, padding=True, return_tensors="pt").to(device)
        call_kwargs = dict(encoded)
        call_kwargs["no_repeat_ngram_size"] = 4
        call_kwargs.update(generator_args)
        generated = model.generate(**call_kwargs)

    cleaned = []
    for text in tokenizer.batch_decode(generated, skip_special_tokens=True):
        for old, new in cleanup_steps:
            text = text.replace(old, new)
        cleaned.append(text)
    return cleaned
def run_python_and_return(s):
try:
ret = {"__ret": None}
exec(s, ret)
return ret["__ret"]
except:
return ""
from collections import Counter
import spacy
import wikipedia
from duckduckgo_search import ddg
nlp = spacy.load("en_core_web_sm")
def duck_duck_and_wikipedia_search(query, num_terms=4, max_docs=10):
    """Collect DuckDuckGo snippets plus Wikipedia pages for their top entities.

    Args:
        query: Search string sent to DuckDuckGo.
        num_terms: How many of the most frequent named entities (spaCy
            entities whose label is not ``CARDINAL``) found in the snippets
            are looked up on Wikipedia.
        max_docs: Once the result list holds more than this many entries the
            function returns early; it also bounds the pages fetched per
            entity to ``max(1, int(max_docs / num_terms))``.

    Returns:
        A list of lists of strings. The first entry is the list of
        ``"title. body"`` snippets; every further entry is one Wikipedia page
        split into sections on ``"\\n="``.
    """
    ret = []
    # using duckduckgo search
    # `or []`: guards against the search helper returning None when it finds nothing.
    data = ddg(
        query,
        region="us-en",
        safesearch="moderate",
    ) or []
    data2 = [
        (a["title"] + ". " + a["body"]).replace("?", ".").strip("?!.") for a in data
    ]
    ret.append(data2)

    doc = nlp(" ".join(data2))
    entity_counts = Counter(e.text for e in doc.ents if e.label_ != "CARDINAL")
    query0 = [term.strip("!.,;") for term, _ in entity_counts.most_common(num_terms)]
    print(query0)

    for query2 in query0:
        search = wikipedia.search(query2)
        for s in search[: max(1, int(max_docs / num_terms))]:
            try:
                # `content` is fetched lazily, so it can fail just like the
                # page lookup; keep both inside the best-effort guard.
                content = wikipedia.WikipediaPage(s).content
            except Exception:
                continue  # skip pages that cannot be resolved or loaded
            # Splitting on "\n=" eats one "=" of each heading; put it back.
            x = ["=" + x1 if "==" in x1 else x1 for x1 in content.split("\n=")]
            ret.append(x)
            if len(ret) > max_docs:
                return ret
    return ret
def _recompute_work_segment(segment):
    """Re-run the code of one ``<work>`` segment and patch in its result.

    Returns the segment rewritten as ``<before>Computed Answer:<after>`` when
    the embedded code ran and produced a value; otherwise returns the segment
    without inserting an answer.
    """
    read_marker = """<<read: "output.txt">>\n\n"""
    if "```" not in segment:
        return segment
    # The generated code writes its result to output.txt; capture it in
    # `__ret` instead so run_python_and_return can hand it back.
    # NOTE(review): confirm the whitespace after "\n" in this literal matches
    # what the model emits.
    segment = segment.replace(
        """with open("output.txt", "w") as file:\n file.write""",
        "__ret=",
    )
    code_parts = segment.split("</work>")[0].split("```")
    before, marker, after = segment.partition(read_marker)
    if len(code_parts) < 2 or not marker:
        # Malformed work block: nothing to execute or nowhere to put the answer.
        return segment

    old_answer = after.split("\n")[0]
    work_answer = run_python_and_return(code_parts[1])
    # None: the code never set __ret. "": run_python_and_return's failure
    # value. In both cases keep the model's own answer instead of erasing it.
    if work_answer is None or (isinstance(work_answer, str) and not work_answer):
        return segment

    work_text = str(work_answer)
    try:
        # Compare numerically when both sides are numbers ("4" vs 4.0).
        mismatch = float(old_answer) != float(work_answer)
    except (TypeError, ValueError, OverflowError):
        mismatch = old_answer != work_text
    if mismatch:
        # str.replace needs a string; passing the raw/float value raised TypeError.
        if old_answer:
            after = after.replace(old_answer, work_text)
        else:
            after = work_text + after
    return before + "Computed Answer:" + after


def generate_with_safety(
    para,
    model,
    tokenizer,
    do_safety=True,
    do_execute_work=False,
    backtrack_on_mismatched_work_answers=False,
    return_answer_only=True,
    do_search=False,
    max_length=512,
    do_self_contrastive=True,
    contrative_guidance_embedding=None,
    max_return_sequences=4,
    ret=None,
    do_sample=True,
    do_beam=False,
    device="cuda",
    target_lang=None,
):
    """Generate answers for ``para``, optionally with search context and a safety prefix.

    Args:
        para: The user prompt.
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer exposing ``encode`` and ``decode``.
        do_safety: Run the safety model on the prompt and prepend its output
            to the answer.
        do_execute_work: Execute code found in ``<work>`` sections of the
            generations and substitute the computed result.
        backtrack_on_mismatched_work_answers: Currently only forces
            ``do_execute_work`` on.
        return_answer_only: Strip the prompt from the decoded generations.
        do_search: Fetch DuckDuckGo results, run them through the T5 model and
            prepend the result as "Background:".
        max_length: ``max_length`` passed to ``model.generate``.
        do_self_contrastive: Pass ``penalty_alpha=0.6`` to ``generate``.
        contrative_guidance_embedding: Unused.
        max_return_sequences: Number of sequences requested; split in half
            between sampling and beam search when both are enabled.
        ret: Optional dict to accumulate results into (keys are the texts).
        do_sample: Generate with top-k/top-p sampling.
        do_beam: Generate with beam search.
        device: Device the prompt ids are moved to.
        target_lang: Unused.

    Returns:
        List of unique generated texts (the keys of ``ret``).

    NOTE(review): the source this was recovered from had lost its indentation,
    so block nesting was reconstructed from the logic -- in particular the
    "Computed Answer:" label is applied whenever the work code ran
    successfully. Confirm against the upstream file.
    """
    if backtrack_on_mismatched_work_answers:
        do_execute_work = True  # TODO the backtracking inference
    para = para.strip()

    background = ""
    if do_search:
        # `or []`: guards against the search helper returning None when it finds nothing.
        data = ddg(
            para,
            region="us-en",
            safesearch="moderate",
        ) or []
        data2 = [a["body"].replace("?", ".").strip("?!., ") for a in data]
        # there is a google paper that says using the summary of the search results is better. Need to look for that paper.
        # also need a simple ngram filter to get rid of bad summaries and use the actual search results as a backup
        # TODO: store reference URL so we can refer back to the URL in generated text. use ngram overlap (Roge score)
        if data2:
            background = ". ".join(
                s.replace("?", ".").lstrip(" ?,!.").rstrip(" ,")
                for s in run_model(data2[:5], t5_model, t5_tokenizer, max_length=512)
            )

    # TODO: inject background knowledge into the instruciton.
    # give me instructions on how to eat castor beans
    # Keyword heuristics on the background text.
    # replace with a multi task classifier using the safety pipeline
    background_lower = background.lower()
    is_wrong = (
        ("immoral" in background_lower or "illegal" in background_lower)
        and "not immoral" not in background_lower
        and "not illegal" not in background_lower
    )
    is_dangerous = (
        (
            "lethal" in background_lower
            or "dangerous" in background_lower
            or " poison" in background_lower
        )
        and "not lethal" not in background_lower
        and "not dangerous" not in background_lower
        and "not poison" not in background_lower
    )

    safety_prefix = ""
    if do_safety:
        para2 = para.strip(".?:-")
        if is_dangerous:
            para2 += " which is dangerous"
        elif is_wrong:
            para2 += " which is wrong"
        safety_prefix = run_model(para2, safety_model, safety_tokenizer)[0].strip(
            "\"' "
        )
        if "wrong" in safety_prefix or "not right" in safety_prefix:
            safety_prefix = f"As a chatbot, I cannot recommend this. {safety_prefix}"

    if background:
        # probably can do a rankgen match instead of keyword on "who", "what", "where", etc.
        words = para.split()
        first_word = words[0].lower() if words else ""  # empty prompt: no IndexError
        question_starters = {
            "who",
            "what",
            "when",
            "where",
            "how",
            "why",
            "does",
            "do",
            "can",
            "could",
            "would",
            "is",
            "are",
            "will",
            "might",
            "find",
            "write",
            "give",
        }
        if first_word not in question_starters and not para.endswith("?"):
            para = f"Background: {background}. <question> Complete this sentence: {para} <answer> "
        else:
            para = f"Background: {background}. <question> {para} <answer> "

    if safety_prefix:
        if "<answer>" not in para:
            para += "<answer> " + safety_prefix + " "
        else:
            para += safety_prefix + " "

    # Offset at which the answer starts in the decoded text. The marker
    # lengths are subtracted, presumably because decoding with
    # skip_special_tokens drops them -- TODO confirm for this tokenizer.
    len_para = len(para)
    if "<question>" in para:
        len_para -= len("<question>")
    if "<answer>" in para:
        len_para -= len("<answer>")
    if safety_prefix:
        # Keep the safety prefix as the start of the returned answer. (This
        # used to test `safety_model`, which is always truthy and shifted the
        # offset by one character even when no prefix had been appended.)
        len_para -= len(safety_prefix + " ")
    if "<answer>" not in para:
        para += "<answer>"
    print(para)

    input_ids = tokenizer.encode(para, return_tensors="pt")
    input_ids = input_ids.to(device)
    if ret is None:
        ret = {}

    penalty_alpha = 0.6 if do_self_contrastive else None
    # With both strategies on, each contributes half of the requested sequences.
    if do_sample and do_beam:
        per_mode = max(1, int(max_return_sequences / 2))
    else:
        per_mode = max_return_sequences

    def _collect(outputs):
        """Decode generated sequences into ``ret`` (dict keys de-duplicate)."""
        for output in outputs:
            text = tokenizer.decode(output, skip_special_tokens=True)
            if return_answer_only:
                text = text[len_para:].lstrip(".? \n\t")
            ret[text] = 1

    with torch.no_grad():
        if do_sample:
            # Here we use top_k / top_k random sampling. It generates more diverse queries, but of lower quality
            _collect(
                model.generate(
                    input_ids=input_ids,
                    max_length=max_length,
                    no_repeat_ngram_size=4,
                    do_sample=True,
                    top_p=0.95,
                    penalty_alpha=penalty_alpha,
                    top_k=10,
                    num_return_sequences=per_mode,
                )
            )
        if do_beam:
            # Here we use Beam-search. It generates better quality queries, but with less diversity
            _collect(
                model.generate(
                    input_ids=input_ids,
                    max_length=max_length,
                    num_beams=max(per_mode, 5),
                    no_repeat_ngram_size=4,
                    penalty_alpha=penalty_alpha,
                    num_return_sequences=per_mode,
                    early_stopping=True,
                )
            )

    # take care of the <work> tokens - let's execute the code
    # TODO: do backtracking when code doesn't return the same answer as the answer in the generated text.
    if do_execute_work:  # galactica specific
        for query in list(ret.keys()):
            if "<work>" not in query:
                continue
            # join() keeps every "<work>" tag, including a leading one.
            query2 = "<work>".join(
                _recompute_work_segment(part) for part in query.split("<work>")
            )
            if query2 != query:
                del ret[query]
                ret[query2] = 1
    return list(ret.keys())
import gradio as gr
def query_model(do_safety, do_search, text, access_code):
    """Gradio callback: check the access code, then generate answers for ``text``.

    Args:
        do_safety: Forwarded to ``generate_with_safety`` as ``do_safety``.
        do_search: Forwarded to ``generate_with_safety`` as ``do_search``.
        text: The user prompt.
        access_code: Must equal the ``PASSWORD`` setting.

    Returns:
        A list of exactly four strings, padded with ``""`` when fewer unique
        answers were generated.

    Raises:
        Exception: If the access code is wrong or no ``PASSWORD`` is configured.
    """
    import hmac

    num_outputs = 4  # the Interface declares four "text" output components

    supplied = access_code if isinstance(access_code, str) else ""
    # Constant-time comparison; an unset PASSWORD rejects every code.
    authorized = isinstance(PASSWORD, str) and hmac.compare_digest(
        supplied.encode("utf-8"), PASSWORD.encode("utf-8")
    )
    if not authorized:
        raise Exception("Incorrect access code")

    answers = generate_with_safety(
        text,
        blackcat_model,
        blackcat_tokenizer,
        do_safety=do_safety,
        do_search=do_search,
    )
    # Results are de-duplicated, so there may be fewer than the UI expects.
    answers = list(answers)[:num_outputs]
    return answers + [""] * (num_outputs - len(answers))
# Gradio UI: two toggles, the prompt and the access code in; four text boxes out.
demo = gr.Interface(
    fn=query_model,
    inputs=[
        gr.Checkbox(label="Safety"),
        gr.Checkbox(label="Search"),
        gr.Textbox(
            label="Prompt",
            lines=5,
            value="Teach me how to take over the world.",
        ),
        gr.Textbox(label="Access Code", lines=1, value=""),
    ],
    outputs=["text"] * 4,
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()