|
from huggingface_hub import InferenceClient, HfApi, upload_file |
|
import datetime |
|
import gradio as gr |
|
import random |
|
import prompts |
|
import json |
|
import uuid |
|
import os |
|
|
|
token=os.environ.get("HF_TOKEN") |
|
username="omnibus" |
|
dataset_name="tmp" |
|
api=HfApi(token="") |
|
VERBOSE=False |
|
|
|
history = [] |
|
hist_out= [] |
|
summary =[] |
|
main_point=[] |
|
summary.append("") |
|
main_point.append("") |
|
|
|
models=[ |
|
"google/gemma-7b", |
|
"google/gemma-7b-it", |
|
"google/gemma-2b", |
|
"google/gemma-2b-it", |
|
"meta-llama/Llama-2-7b-chat-hf", |
|
"codellama/CodeLlama-70b-Instruct-hf", |
|
"openchat/openchat-3.5-0106", |
|
"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", |
|
"mistralai/Mixtral-8x7B-Instruct-v0.1", |
|
"mistralai/Mixtral-8x7B-Instruct-v0.2", |
|
] |
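# Single-slot holder for the active InferenceClient; load_models() swaps its contents when the dropdown changes.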
|
|
|
client_z=[] |
|
|
|
def load_models(inp): |
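    """Point the shared InferenceClient at the selected model and relabel the chatbot accordingly."""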
|
if VERBOSE==True: |
|
print(type(inp)) |
|
print(inp) |
|
print(models[inp]) |
|
client_z.clear() |
|
client_z.append(InferenceClient(models[inp])) |
|
|
|
if "mistralai" in models[inp]: |
|
|
|
|
|
return gr.update(label=models[inp]) |
|
|
|
def format_prompt(message, history): |
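    """Build a Mixtral/Llama-style [INST] ... [/INST] prompt from the chat history plus the new message."""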
|
prompt = "<s>" |
|
for user_prompt, bot_response in history: |
|
prompt += f"[INST] {user_prompt} [/INST]" |
|
prompt += f" {bot_response}</s> " |
|
prompt += f"[INST] {message} [/INST]" |
|
return prompt |
|
|
|
agents =[ |
|
"COMMENTER", |
|
"BLOG_POSTER", |
|
"REPLY_TO_COMMENTER", |
|
"COMPRESS_HISTORY_PROMPT" |
|
] |
|
|
|
temperature=0.9 |
|
max_new_tokens=256 |
|
max_new_tokens2=4000 |
|
top_p=0.95 |
|
repetition_penalty=1.0
|
|
|
def compress_history(formatted_prompt): |
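    """Stream a condensed summary of the conversation so far from the active model."""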
|
print("###############\nRUNNING COMPRESS HISTORY\n###############\n") |
|
seed = random.randint(1,1111111111111111) |
|
agent=prompts.COMPRESS_HISTORY_PROMPT.format(history=summary[0],focus=main_point[0]) |
|
|
|
system_prompt=agent |
|
temperature = 0.9 |
|
if temperature < 1e-2: |
|
temperature = 1e-2 |
|
|
|
generate_kwargs = dict( |
|
temperature=temperature, |
|
max_new_tokens=1048, |
|
top_p=0.95, |
|
repetition_penalty=1.0, |
|
do_sample=True, |
|
seed=seed, |
|
) |
|
|
|
|
|
|
client=client_z[0] |
|
stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False) |
|
output = "" |
|
|
|
for response in stream: |
|
output += response.token.text |
|
|
|
print(output) |
|
print(main_point[0]) |
|
return output |
|
|
|
|
|
def question_generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1028, top_p=0.95, repetition_penalty=1.0,): |
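    """Generate a reader comment/question about the latest output, steered by the COMMENTER prompt."""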
|
|
|
print("###############\nRUNNING QUESTION GENERATOR\n###############\n") |
|
seed = random.randint(1,1111111111111111) |
|
agent=prompts.COMMENTER.format(focus=main_point[0]) |
|
system_prompt=agent |
|
temperature = float(temperature) |
|
if temperature < 1e-2: |
|
temperature = 1e-2 |
|
top_p = float(top_p) |
|
|
|
generate_kwargs = dict( |
|
temperature=temperature, |
|
max_new_tokens=max_new_tokens, |
|
top_p=top_p, |
|
repetition_penalty=repetition_penalty, |
|
do_sample=True, |
|
seed=seed, |
|
) |
|
|
|
formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history) |
|
client=client_z[0] |
|
|
|
stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False) |
|
|
|
output = "" |
|
|
|
for response in stream: |
|
output += response.token.text |
|
|
|
|
|
return output |
|
|
|
def blog_poster_reply(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,): |
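    """Generate the blog author's reply to a commenter, steered by the REPLY_TO_COMMENTER prompt."""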
|
|
|
print("###############\nRUNNING BLOG POSTER REPLY\n###############\n") |
|
seed = random.randint(1,1111111111111111) |
|
agent=prompts.REPLY_TO_COMMENTER.format(focus=main_point[0]) |
|
system_prompt=agent |
|
temperature = float(temperature) |
|
if temperature < 1e-2: |
|
temperature = 1e-2 |
|
top_p = float(top_p) |
|
|
|
generate_kwargs = dict( |
|
temperature=temperature, |
|
max_new_tokens=max_new_tokens, |
|
top_p=top_p, |
|
repetition_penalty=repetition_penalty, |
|
do_sample=True, |
|
seed=seed, |
|
) |
|
|
|
formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history) |
|
client=client_z[0] |
|
stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False) |
|
output = "" |
|
|
|
for response in stream: |
|
output += response.token.text |
|
|
|
|
|
return output |
|
|
|
|
|
|
|
def create_valid_filename(invalid_filename: str) -> str: |
|
"""Converts invalid characters in a string to be suitable for a filename.""" |
|
    valid_chars = '-'.join(invalid_filename.split())
|
allowed_chars = ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', |
|
'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', |
|
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', |
|
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', |
|
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '_', '-') |
|
return ''.join(char for char in valid_chars if char in allowed_chars) |
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_html(inp=None, title=""):
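    """Render the accumulated (post, comment) pairs into the index.html template via its $body and $title placeholders."""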
|
ht="" |
|
if inp: |
|
for i,ea in enumerate(inp): |
|
|
|
outp,prom=ea |
|
|
|
|
|
if i == 0: |
|
ht+=f"""<div class="div_box"> |
|
<pre class="bpost">{outp}</pre> |
|
<pre class="resp1">{prom}</pre> |
|
</div>""" |
|
else: |
|
ht+=f"""<div class="div_box"> |
|
<pre class="resp2">{outp}</pre> |
|
<pre class="resp2">{prom}</pre> |
|
</div>""" |
|
    with open('index.html','r') as h:
        html = h.read()
    html = html.replace("$body", f"{ht}")
    html = html.replace("$title", f"{title}")
    return html
|
|
|
|
|
|
|
def generate(prompt, history, max_new_tokens=1048, agent_name=agents[0], sys_prompt="", temperature=0.9, top_p=0.95, repetition_penalty=1.0):
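    """Drive the blog loop: write the post, then alternate commenter questions and author replies,
    compressing history when the prompt grows too long and uploading transcripts to the Hub dataset."""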
|
html_out="" |
|
|
|
|
|
uid=uuid.uuid4() |
|
current_time = str(datetime.datetime.now()) |
|
title="" |
|
filename=create_valid_filename(f'{current_time}---{title}') |
|
|
|
current_time=current_time.replace(":","-") |
|
current_time=current_time.replace(".","-") |
|
print (current_time) |
|
agent=prompts.BLOG_POSTER |
|
system_prompt=agent |
|
temperature = float(temperature) |
|
if temperature < 1e-2: |
|
temperature = 1e-2 |
|
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)
|
hist_out=[] |
|
sum_out=[] |
|
json_hist={} |
|
json_obj={} |
|
full_conv=[] |
|
post_cnt=1 |
|
while True: |
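        # Each pass: generate the post (or a reply), stream it to the UI, log and upload the transcript,
        # then ask the COMMENTER agent for the next question.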
|
seed = random.randint(1,1111111111111111) |
|
if post_cnt==1: |
|
generate_kwargs = dict( |
|
temperature=temperature, |
|
                max_new_tokens=max_new_tokens,
|
top_p=top_p, |
|
repetition_penalty=repetition_penalty, |
|
do_sample=True, |
|
seed=seed, |
|
) |
|
if prompt.startswith(' \"'): |
|
prompt=prompt.strip(' \"') |
|
|
|
formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history) |
|
|
|
post_cnt+=1 |
|
else: |
|
system_prompt=prompts.REPLY_TO_COMMENTER.format(focus=main_point[0]) |
|
|
|
generate_kwargs = dict( |
|
temperature=temperature, |
|
                max_new_tokens=max_new_tokens,
|
top_p=top_p, |
|
repetition_penalty=repetition_penalty, |
|
do_sample=True, |
|
seed=seed, |
|
) |
|
if prompt.startswith(' \"'): |
|
prompt=prompt.strip(' \"') |
|
|
|
formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history) |
|
print("###############\nRUNNING REPLY TO COMMENTER\n###############\n") |
|
print (system_prompt) |
|
if len(formatted_prompt) < (40000): |
|
print(len(formatted_prompt)) |
|
|
|
client=client_z[0] |
|
stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False) |
|
output = "" |
|
|
|
|
|
|
|
for response in stream: |
|
output += response.token.text |
|
yield '', [(prompt,output)],summary[0],json_obj, json_hist,html_out |
|
|
|
if not title: |
|
for line in output.split("\n"): |
|
if "title" in line.lower() and ":" in line.lower(): |
|
title = line.split(":")[1] |
|
print(f'title:: {title}') |
|
filename=create_valid_filename(f'{current_time}---{title}') |
|
|
|
out_json = {"prompt":prompt,"output":output} |
|
|
|
prompt = question_generate(output, history) |
|
|
|
history.append((prompt,output)) |
|
print ( f'Prompt:: {len(prompt)}') |
|
|
|
print ( f'history:: {len(formatted_prompt)}') |
|
hist_out.append(out_json) |
|
|
|
|
|
            with open(f'{uid}.json', 'w') as f:
                json_hist = json.dumps(hist_out, indent=4)
                f.write(json_hist)
|
|
|
upload_file( |
|
path_or_fileobj =f"{uid}.json", |
|
path_in_repo = f"book1/{filename}.json", |
|
repo_id =f"{username}/{dataset_name}", |
|
repo_type = "dataset", |
|
token=token, |
|
) |
|
else: |
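            # Prompt too long: fold the conversation into the running summary and reset the visible history.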
|
formatted_prompt = format_prompt(f"{prompts.COMPRESS_HISTORY_PROMPT.format(history=summary[0],focus=main_point[0])}, {summary[0]}", history) |
|
|
|
|
|
|
|
history = [] |
|
output = compress_history(formatted_prompt) |
|
summary[0]=output |
|
sum_json = {"summary":summary[0]} |
|
sum_out.append(sum_json) |
|
            with open(f'{uid}-sum.json', 'w') as f:
                json_obj = json.dumps(sum_out, indent=4)
                f.write(json_obj)
|
upload_file( |
|
path_or_fileobj =f"{uid}-sum.json", |
|
path_in_repo = f"book1/{filename}-summary.json", |
|
repo_id =f"{username}/{dataset_name}", |
|
repo_type = "dataset", |
|
token=token, |
|
) |
|
|
|
|
|
prompt = question_generate(output, history) |
|
main_point[0]=prompt |
|
full_conv.append((output,prompt)) |
|
|
|
|
|
html_out=load_html(full_conv,title) |
|
yield prompt, history, summary[0],json_obj,json_hist,html_out |
|
return prompt, history, summary[0],json_obj,json_hist,html_out |
|
|
|
|
|
|
|
|
|
with gr.Blocks() as app: |
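    # Layout: rendered HTML transcript, chatbot + prompt box, submit/stop/clear controls, model picker,
    # max-token slider, and JSON panes for the summary and history logs.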
|
html = gr.HTML() |
|
|
|
chatbot=gr.Chatbot() |
|
msg = gr.Textbox() |
|
with gr.Row(): |
|
submit_b = gr.Button() |
|
stop_b = gr.Button("Stop") |
|
clear = gr.ClearButton([msg, chatbot]) |
|
with gr.Row(): |
|
m_choice=gr.Dropdown(label="Models",type='index',choices=[c for c in models],value=models[0],interactive=True) |
|
tokens = gr.Slider(label="Max new tokens",value=1600,minimum=0,maximum=8000,step=64,interactive=True, visible=True,info="The maximum number of tokens") |
|
|
|
    sumbox=gr.Textbox(label="Summary", max_lines=100)
|
with gr.Column(): |
|
sum_out_box=gr.JSON(label="Summaries") |
|
hist_out_box=gr.JSON(label="History") |
|
|
|
|
|
m_choice.change(load_models,m_choice,[chatbot]) |
|
app.load(load_models,m_choice,[chatbot]).then(load_html,None,html) |
|
|
|
sub_b = submit_b.click(generate, [msg,chatbot,tokens],[msg,chatbot,sumbox,sum_out_box,hist_out_box,html]) |
|
sub_e = msg.submit(generate, [msg, chatbot,tokens], [msg, chatbot,sumbox,sum_out_box,hist_out_box,html]) |
|
stop_b.click(None,None,None, cancels=[sub_b,sub_e]) |
|
|
|
app.queue(default_concurrency_limit=20).launch() |