AI-book / app.py
Omnibus's picture
Update app.py
2fd7c54 verified
raw
history blame
20.6 kB
from huggingface_hub import InferenceClient, HfApi, upload_file
import datetime
import gradio as gr
import random
import prompts
import json
import uuid
import os
# Credentials / destination for the JSON artifacts uploaded by the handlers.
token = os.environ.get("HF_TOKEN")
username = "omnibus"
dataset_name = "tmp"
# Use the same token the upload_file(...) calls use instead of an empty string,
# so this client is not silently unauthenticated.
api = HfApi(token=token)
VERBOSE = False

# Module-level mutable state. These are process-wide globals, so they are
# shared by every request handled by this process.
history = []
hist_out = []
# Single-slot "boxes": handlers read and overwrite summary[0] / main_point[0].
summary = [""]
main_point = [""]

# Model ids offered in the UI dropdown; load_models() indexes into this list.
models = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mixtral-8x7B-Instruct-v0.2",
    "google/gemma-7b",
    "google/gemma-7b-it",
    "google/gemma-2b",
    "google/gemma-2b-it",
    "meta-llama/Llama-2-7b-chat-hf",
    "codellama/CodeLlama-70b-Instruct-hf",
    "openchat/openchat-3.5-0106",
    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
]

# Holds the currently selected InferenceClient at index 0 (set by load_models).
client_z = []
def load_models(inp):
    """Make the model at position ``inp`` of ``models`` the active one.

    The previously selected client is removed from ``client_z`` and a new
    ``InferenceClient`` for the chosen model id takes its place.

    Args:
        inp: integer index into ``models`` (the dropdown uses ``type='index'``).

    Returns:
        A ``gr.update`` that relabels the bound component with the model id.
    """
    if VERBOSE == True:
        print(type(inp))
        print(inp)
        print(models[inp])
    client_z.clear()
    selected_model = models[inp]
    client_z.append(InferenceClient(selected_model))
    return gr.update(label=selected_model)
def format_prompt(message, history):
    """Build an ``[INST]``-style instruction prompt from a chat history.

    Each ``(user_prompt, bot_response)`` pair becomes
    ``"[INST] {user} [/INST] {bot}</s> "`` after a leading ``"<s>"``, and the
    new ``message`` is appended as a final open ``[INST]`` turn.

    ``None`` entries are rendered as empty strings rather than the literal
    text ``"None"``. The comment/reply handlers store turns as
    ``(output, None)``, so this case does occur. A ``None`` history is treated
    as empty.

    Args:
        message: the new instruction text.
        history: iterable of 2-item ``(user, bot)`` sequences, or ``None``.

    Returns:
        The assembled prompt string.
    """
    parts = ["<s>"]
    for user_prompt, bot_response in history or ():
        user_text = "" if user_prompt is None else user_prompt
        bot_text = "" if bot_response is None else bot_response
        parts.append(f"[INST] {user_text} [/INST]")
        parts.append(f" {bot_text}</s> ")
    parts.append(f"[INST] {message} [/INST]")
    # join() instead of repeated += keeps this linear in the history length.
    return "".join(parts)
# Names of the prompt templates; agents[0] is used as the default agent_name.
agents = [
    "COMMENTER",
    "BLOG_POSTER",
    "REPLY_TO_COMMENTER",
    "COMPRESS_HISTORY_PROMPT",
]

# Module-level generation settings. The handler functions take their own
# per-call parameters; generate() reads max_new_tokens2 as the blog-post budget.
temperature = 0.9
max_new_tokens = 256
max_new_tokens2 = 4000
top_p = 0.95
# No trailing comma here: `1.0,` would make this the tuple (1.0,).
repetition_penalty = 1.0
def compress_history(formatted_prompt, max_new_tokens=1048):
    """Stream a completion for an already-formatted history-compression prompt.

    The caller builds ``formatted_prompt`` itself (from
    ``prompts.COMPRESS_HISTORY_PROMPT``). It is sent unchanged to the currently
    selected client ``client_z[0]``.

    Args:
        formatted_prompt: the complete prompt text to send.
        max_new_tokens: generation budget (defaults to the previously
            hard-coded 1048).

    Returns:
        The generated text, i.e. the concatenation of all streamed tokens.
    """
    print("###############\nRUNNING COMPRESS HISTORY\n###############\n")
    generate_kwargs = dict(
        temperature=0.9,
        max_new_tokens=max_new_tokens,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        seed=random.randint(1, 1111111111111111),
    )
    client = client_z[0]
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = "".join(response.token.text for response in stream)
    print(output)
    print(main_point[0])
    return output
def comment_generate(prompt, history, post_check, full_conv, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1028, top_p=0.95, repetition_penalty=1.0,):
    """Generate a reader comment on the current blog post (Gradio click handler).

    Streams a completion from ``client_z[0]`` using the COMMENTER prompt
    (focused on ``main_point[0]``) plus ``prompt`` and the chat ``history``.
    It then:
      * appends ``(output, None)`` to ``history``;
      * increments ``post_check['comment']``;
      * appends ``(None, output, None)`` (the comment slot) to ``full_conv``
        and re-renders it with ``load_html``.

    ``agent_name`` and ``sys_prompt`` are accepted but unused.

    Returns:
        ``("", history, post_check, post_check, post_check, html)``, matching
        the wired outputs ``[textbox, chatbot, textbox, json, json, html]``.
    """
    # Gradio State values are None until a blog post has been generated;
    # fall back to empty containers instead of crashing on append / indexing.
    if history is None:
        history = []
    if post_check is None:
        post_check = {}
    if full_conv is None:
        full_conv = []
    print(post_check)
    print(f'full_conv::\n{full_conv}')
    print("###############\nRUNNING COMMENT GENERATOR\n###############\n")
    seed = random.randint(1, 1111111111111111)
    system_prompt = prompts.COMMENTER.format(focus=main_point[0])
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=seed,
    )
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    client = client_z[0]
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = "".join(response.token.text for response in stream)
    history.append((output, None))
    post_check['comment'] = post_check.get('comment', 0) + 1
    full_conv.append((None, output, None))
    html_out = load_html(full_conv, None)
    return "", history, post_check, post_check, post_check, html_out
def reply_generate(prompt, history, post_check, full_conv, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1028, top_p=0.95, repetition_penalty=1.0,):
    """Generate the blog author's reply to a comment (Gradio click handler).

    Streams a completion from ``client_z[0]`` using the REPLY_TO_COMMENTER
    prompt (focused on ``main_point[0]``) plus ``prompt`` and the chat
    ``history``. It then:
      * appends ``(output, None)`` to ``history``;
      * increments ``post_check['reply']`` (the post metadata created by
        ``generate`` carries separate 'comment' and 'reply' counters);
      * appends ``(None, None, output)`` to ``full_conv``, the reply slot that
        ``load_html`` renders with the ``resp2`` class, and re-renders it.

    ``agent_name`` and ``sys_prompt`` are accepted but unused.

    Returns:
        ``("", history, post_check, post_check, post_check, html)``, matching
        the wired outputs ``[textbox, chatbot, textbox, json, json, html]``.
    """
    # Gradio State values are None until a blog post has been generated.
    if history is None:
        history = []
    if post_check is None:
        post_check = {}
    if full_conv is None:
        full_conv = []
    print(post_check)
    print(f'full_conv::\n{full_conv}')
    print("###############\nRUNNING REPLY GENERATOR\n###############\n")
    seed = random.randint(1, 1111111111111111)
    system_prompt = prompts.REPLY_TO_COMMENTER.format(focus=main_point[0])
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=seed,
    )
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    client = client_z[0]
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = "".join(response.token.text for response in stream)
    history.append((output, None))
    # Count replies separately from comments.
    post_check['reply'] = post_check.get('reply', 0) + 1
    # Third slot = reply, so it renders as a reply rather than a comment.
    full_conv.append((None, None, output))
    html_out = load_html(full_conv, None)
    return "", history, post_check, post_check, post_check, html_out
def reply_generate_OG(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
    """Legacy reply generator: return the author's reply text only.

    Builds the REPLY_TO_COMMENTER prompt (focused on ``main_point[0]``),
    appends ``prompt`` and ``history``, streams it from ``client_z[0]`` and
    returns the concatenated text. Unlike ``reply_generate`` it does not
    modify ``history`` or render HTML. ``agent_name`` and ``sys_prompt`` are
    accepted but unused.
    """
    print("###############\nRUNNING BLOG POSTER REPLY\n###############\n")
    rng_seed = random.randint(1, 1111111111111111)
    reply_instructions = prompts.REPLY_TO_COMMENTER.format(focus=main_point[0])
    temp = float(temperature)
    if temp < 1e-2:
        temp = 1e-2
    nucleus = float(top_p)
    sampling = {
        "temperature": temp,
        "max_new_tokens": max_new_tokens,
        "top_p": nucleus,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": rng_seed,
    }
    full_prompt = format_prompt(f"{reply_instructions}, {prompt}", history)
    active_client = client_z[0]
    token_stream = active_client.text_generation(
        full_prompt,
        **sampling,
        stream=True,
        details=True,
        return_full_text=False,
    )
    return "".join(chunk.token.text for chunk in token_stream)
def create_valid_filename(invalid_filename: str) -> str:
    """Convert an arbitrary string into a conservative, filename-safe string.

    Runs of whitespace are collapsed into a single ``-``. Then every
    character that is not an ASCII letter, ASCII digit, ``_`` or ``-`` is
    dropped (e.g. ``:``, ``.``, ``/`` and non-ASCII letters disappear).

    >>> create_valid_filename("2024-01-01 12:00:00.5---My Title")
    '2024-01-01-1200005---My-Title'
    """
    dashed = "-".join(invalid_filename.split())
    # For ASCII characters, isalnum() is exactly [A-Za-z0-9].
    return "".join(
        ch for ch in dashed if ch.isascii() and (ch.isalnum() or ch in "_-")
    )
def load_html(inp=None, title=None):
    """Render a conversation into the ``index.html`` page template.

    Args:
        inp: iterable of ``(blog, comment, reply)`` 3-tuples, or ``None``.
            Each tuple becomes one ``<div class="div_box">``. Each non-empty
            slot becomes a ``<pre>`` with class ``bpost`` / ``resp1`` /
            ``resp2`` respectively.
        title: text substituted for ``$title``. ``None`` renders as empty.

    Both arguments default to ``None`` so the function can be called with no
    inputs (e.g. from a Gradio ``.then(...)`` with no inputs wired).

    Model output and titles are HTML-escaped. They are plain text inside
    ``<pre>``, so a stray ``<`` in generated text must not become markup.

    Returns:
        The template contents with ``$body`` and ``$title`` replaced.
    """
    # Imported locally under an alias because the module-level name ``html``
    # is rebound to a Gradio HTML component later in this file.
    import html as html_lib

    sections = []
    for blog, comm, repl in inp or ():
        sections.append('<div class="div_box">')
        for css_class, text in (("bpost", blog), ("resp1", comm), ("resp2", repl)):
            if text:
                sections.append(f'<pre class="{css_class}">{html_lib.escape(str(text))}</pre>')
        sections.append("</div>")
    with open('index.html', 'r', encoding='utf-8') as template_file:
        page = template_file.read()
    page = page.replace("$body", "".join(sections))
    page = page.replace("$title", html_lib.escape("" if title is None else str(title)))
    return page
def load_html_OG(inp, title):
    """Legacy renderer for conversations stored as ``(output, prompt)`` pairs.

    The first pair is rendered as a blog post (``bpost``) followed by a
    ``resp1`` block. Every later pair renders both parts with ``resp2``. The
    result replaces ``$body`` in ``index.html``, and ``title`` (formatted with
    ``f"{title}"``, so ``None`` becomes ``"None"``) replaces ``$title``.
    Text is inserted without escaping.
    """
    pieces = []
    for idx, (post_text, follow_up) in enumerate(inp or []):
        first_cls, second_cls = ("bpost", "resp1") if idx == 0 else ("resp2", "resp2")
        pieces.append(
            '<div class="div_box">\n'
            f'<pre class="{first_cls}">{post_text}</pre>\n'
            f'<pre class="{second_cls}">{follow_up}</pre>\n'
            '</div>'
        )
    body = "".join(pieces)
    with open('index.html', 'r') as template_file:
        page = template_file.read()
    page = page.replace("$body", body)
    page = page.replace("$title", f"{title}")
    return page
def _save_and_upload_json(records, local_path, path_in_repo):
    """Write *records* as indented JSON to *local_path*, upload it to the dataset repo, and return the JSON text."""
    text = json.dumps(records, indent=4)
    with open(local_path, 'w') as f:
        f.write(text)
    upload_file(
        path_or_fileobj=local_path,
        path_in_repo=path_in_repo,
        repo_id=f"{username}/{dataset_name}",
        repo_type="dataset",
        token=token,
    )
    return text


def generate(prompt, history, post_check, full_conv, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0):
    """Write a new blog post and stream it to the UI (Gradio generator handler).

    If ``post_check`` is truthy, a post already exists in this session. In
    that case nothing is generated ("passing blog") and nothing is yielded.

    Otherwise the BLOG_POSTER prompt plus ``prompt`` is streamed from
    ``client_z[0]`` with a ``max_new_tokens2`` budget. Each partial result is
    yielded as::

        ('', [(prompt, partial)], post_check, full_conv, summary[0],
         json_obj, json_hist, html_out)

    The post title is taken from the first output line containing "title"
    and a colon (the text after the first colon, stripped). The post is saved
    as JSON and uploaded to the ``{username}/{dataset_name}`` dataset under
    ``book1/``.

    If the formatted prompt is 40000 characters or longer, ``summary[0]`` is
    compressed via ``compress_history`` instead, and the summary JSON is
    uploaded.

    The final yield is ``(prompt, history, post_check, full_conv, summary[0],
    json_obj, json_hist, html_out)``, where ``post_check`` is the new post's
    metadata dict.

    ``agent_name``, ``sys_prompt`` and ``max_new_tokens`` are accepted for
    signature compatibility but unused.
    """
    html_out = ""
    uid = uuid.uuid4()
    current_time = str(datetime.datetime.now())
    title = ""
    filename = create_valid_filename(f'{current_time}---{title}')
    current_time = current_time.replace(":", "-").replace(".", "-")
    print(current_time)
    if post_check:
        print("passing blog")
        return
    print("writing blog")
    system_prompt = prompts.BLOG_POSTER
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    hist_out = []
    sum_out = []
    json_hist = {}
    json_obj = {}
    # A new blog post always starts a fresh conversation thread.
    full_conv = []
    post_check = {}
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens2,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(1, 1111111111111111),
    )
    if prompt.startswith(' \"'):
        prompt = prompt.strip(' \"')
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    if len(formatted_prompt) < 40000:
        print(len(formatted_prompt))
        client = client_z[0]
        stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        output = ""
        for response in stream:
            output += response.token.text
            yield '', [(prompt, output)], post_check, full_conv, summary[0], json_obj, json_hist, html_out
        # Use the FIRST "Title: ..." line. Split only on the first colon so
        # titles containing ':' survive, and strip the leading space.
        for line in output.split("\n"):
            if "title" in line.lower() and ":" in line:
                candidate = line.split(":", 1)[1].strip()
                if candidate:
                    title = candidate
                    print(f'title:: {title}')
                    break
        filename = create_valid_filename(f'{current_time}---{title}')
        out_json = {'user': "", 'datetime': current_time, 'title': title, 'blog': 1, 'comment': 0, 'reply': 0, "prompt": prompt, "output": output}
        hist_out.append(out_json)
        json_hist = _save_and_upload_json(hist_out, f"{uid}.json", f"book1/{filename}.json")
    else:
        # Prompt is too long: compress the running summary instead of posting.
        formatted_prompt = format_prompt(f"{prompts.COMPRESS_HISTORY_PROMPT.format(history=summary[0],focus=main_point[0])}, {summary[0]}", history)
        history = []
        output = compress_history(formatted_prompt)
        summary[0] = output
        sum_out.append({"summary": summary[0]})
        json_obj = _save_and_upload_json(sum_out, f"{uid}-sum.json", f"book1/{filename}-summary.json")
    full_conv.append((output, None, None))
    html_out = load_html(full_conv, title)
    post_check = {'user': "", 'datetime': current_time, 'title': title, 'blog': 1, 'comment': 0, 'reply': 0}
    yield prompt, history, post_check, full_conv, summary[0], json_obj, json_hist, html_out
def generate_OG(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0):
    """Legacy self-running blog/reply loop (superseded by ``generate``).

    On the first iteration it writes a blog post from ``prompt`` with the
    BLOG_POSTER prompt. Every later iteration uses the REPLY_TO_COMMENTER
    prompt with a "question" produced from the previous output. When the
    formatted prompt is 40000 characters or longer, ``summary[0]`` is
    compressed via ``compress_history`` instead. Each iteration saves JSON
    and uploads it to ``{username}/{dataset_name}`` under ``book1/``.

    Yields 6-tuples: streaming ``('', [(prompt, partial)], summary[0],
    json_obj, json_hist, html_out)`` and, per iteration,
    ``(prompt, history, summary[0], json_obj, json_hist, html_out)``.

    ``agent_name``, ``sys_prompt`` and ``max_new_tokens`` are unused (both
    branches use ``max_new_tokens2``).

    NOTE(review): not wired to any UI event in the visible source.
    NOTE(review): calls ``question_generate``, which is not defined in the
    visible source. That raises NameError unless it is defined elsewhere.
    NOTE(review): appends 2-tuples to ``full_conv``, but ``load_html`` unpacks
    3-tuples ``(blog, comment, reply)``, so that call would raise ValueError.
    ``load_html_OG`` is the renderer matching this 2-tuple shape.
    """
    html_out=""
    #main_point[0]=prompt
    #print(datetime.datetime.now())
    uid=uuid.uuid4()
    current_time = str(datetime.datetime.now())
    title=""
    filename=create_valid_filename(f'{current_time}---{title}')
    current_time=current_time.replace(":","-")
    current_time=current_time.replace(".","-")
    print (current_time)
    agent=prompts.BLOG_POSTER
    system_prompt=agent
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    hist_out=[]
    sum_out=[]
    json_hist={}
    json_obj={}
    full_conv=[]
    # 1 = first iteration (write the blog post); afterwards replies are generated.
    post_cnt=1
    # Loops forever; the loop has no break, so the trailing `return` is unreachable.
    while True:
        seed = random.randint(1,1111111111111111)
        if post_cnt==1:
            generate_kwargs = dict(
                temperature=temperature,
                max_new_tokens=max_new_tokens2,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                seed=seed,
            )
            if prompt.startswith(' \"'):
                prompt=prompt.strip(' \"')
            formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
            post_cnt+=1
        else:
            # Later iterations answer the generated "question" as the blog author.
            system_prompt=prompts.REPLY_TO_COMMENTER.format(focus=main_point[0])
            generate_kwargs = dict(
                temperature=temperature,
                max_new_tokens=max_new_tokens2,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                seed=seed,
            )
            if prompt.startswith(' \"'):
                prompt=prompt.strip(' \"')
            formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
            print("###############\nRUNNING REPLY TO COMMENTER\n###############\n")
            print (system_prompt)
        # Character-length guard: above 40000 chars the history is compressed instead.
        if len(formatted_prompt) < (40000):
            print(len(formatted_prompt))
            client=client_z[0]
            stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
            output = ""
            #if history:
            #    yield history
            for response in stream:
                output += response.token.text
                yield '', [(prompt,output)],summary[0],json_obj, json_hist,html_out
            if not title:
                # Keeps the LAST line containing "title" and ":", split on every colon.
                for line in output.split("\n"):
                    if "title" in line.lower() and ":" in line.lower():
                        title = line.split(":")[1]
                        print(f'title:: {title}')
                        filename=create_valid_filename(f'{current_time}---{title}')
            out_json = {"prompt":prompt,"output":output}
            # NOTE(review): question_generate is not defined in the visible source.
            prompt = question_generate(output, history)
            #output += prompt
            history.append((prompt,output))
            print ( f'Prompt:: {len(prompt)}')
            #print ( f'output:: {output}')
            print ( f'history:: {len(formatted_prompt)}')
            hist_out.append(out_json)
            with open(f'{uid}.json', 'w') as f:
                json_hist=json.dumps(hist_out, indent=4)
                f.write(json_hist)
                f.close()
            upload_file(
                path_or_fileobj =f"{uid}.json",
                path_in_repo = f"book1/{filename}.json",
                repo_id =f"{username}/{dataset_name}",
                repo_type = "dataset",
                token=token,
            )
        else:
            # Prompt too long: summarize into summary[0] and reset the chat history.
            formatted_prompt = format_prompt(f"{prompts.COMPRESS_HISTORY_PROMPT.format(history=summary[0],focus=main_point[0])}, {summary[0]}", history)
            #current_time = str(datetime.datetime.now().timestamp()).split(".",1)[0]
            #filename=f'{filename}-{current_time}'
            history = []
            output = compress_history(formatted_prompt)
            summary[0]=output
            sum_json = {"summary":summary[0]}
            sum_out.append(sum_json)
            with open(f'{uid}-sum.json', 'w') as f:
                json_obj=json.dumps(sum_out, indent=4)
                f.write(json_obj)
                f.close()
            upload_file(
                path_or_fileobj =f"{uid}-sum.json",
                path_in_repo = f"book1/{filename}-summary.json",
                repo_id =f"{username}/{dataset_name}",
                repo_type = "dataset",
                token=token,
            )
        # The generated question becomes the next iteration's prompt and the new focus.
        prompt = question_generate(output, history)
        main_point[0]=prompt
        # 2-tuple entries; see the NOTE in the docstring about load_html's 3-tuple unpack.
        full_conv.append((output,prompt))
        html_out=load_html(full_conv,title)
        yield prompt, history, summary[0],json_obj,json_hist,html_out
    # Unreachable: the while True loop above never breaks.
    return prompt, history, summary[0],json_obj,json_hist,html_out
with gr.Blocks() as app:
    # Per-session state: chat_handler holds the (blog, comment, reply) list
    # rendered by load_html. post_handler holds the current post's metadata
    # dict and stays falsy until generate() has written a post.
    chat_handler = gr.State()
    post_handler = gr.State()
    html = gr.HTML()
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    with gr.Row():
        submit_b = gr.Button("Blog Post")
        submit_c = gr.Button("Comment")
        submit_r = gr.Button("OP Reply")
    with gr.Row():
        stop_b = gr.Button("Stop")
        clear = gr.ClearButton([msg, chatbot])
    with gr.Row():
        m_choice = gr.Dropdown(label="Models", type='index', choices=list(models), value=models[0], interactive=True)
        # NOTE: generate() uses max_new_tokens2 for blog posts, so this slider
        # is not passed to any handler below.
        tokens = gr.Slider(label="Max new tokens", value=1600, minimum=0, maximum=8000, step=64, interactive=True, visible=True, info="The maximum number of tokens")
    # The first positional Textbox argument is its value; "Summary" is meant as the label.
    sumbox = gr.Textbox(label="Summary", max_lines=100)
    with gr.Column():
        sum_out_box = gr.JSON(label="Summaries")
        hist_out_box = gr.JSON(label="History")

    m_choice.change(load_models, m_choice, [chatbot])
    # load_html takes (inp, title); with no inputs wired, call it explicitly
    # so the initial page render does not fail on missing arguments.
    app.load(load_models, m_choice, [chatbot]).then(lambda: load_html(None, None), None, html)

    # generate(prompt, history, post_check, full_conv, ...). The extra inputs
    # previously wired here landed in the unused agent_name/sys_prompt params.
    blog_inputs = [msg, chatbot, post_handler, chat_handler]
    blog_outputs = [msg, chatbot, post_handler, chat_handler, sumbox, sum_out_box, hist_out_box, html]
    sub_b = submit_b.click(generate, blog_inputs, blog_outputs)
    sub_c = submit_c.click(comment_generate, [msg, chatbot, post_handler, chat_handler], [msg, chatbot, sumbox, sum_out_box, hist_out_box, html])
    sub_r = submit_r.click(reply_generate, [msg, chatbot, post_handler, chat_handler], [msg, chatbot, sumbox, sum_out_box, hist_out_box, html])
    sub_e = msg.submit(generate, blog_inputs, blog_outputs)
    stop_b.click(None, None, None, cancels=[sub_b, sub_e, sub_c, sub_r])
    # Clearing the chat also resets the session's post state. Otherwise
    # generate() sees a truthy post_check and refuses to write a new post.
    clear.click(lambda: (None, None, ""), None, [post_handler, chat_handler, html])
app.queue(default_concurrency_limit=20).launch()