"""Gradio app that has an LLM write a blog post and then converse with itself.

A "blog poster" agent writes a post from the user's prompt, a "commenter"
agent asks a follow-up question, and the poster replies.  When the formatted
prompt grows too long the conversation is compressed into a summary.  Each
turn is written to a local JSON file and uploaded to a Hugging Face dataset.
"""
import datetime
import json
import os
import random
import string
import uuid

import gradio as gr
from huggingface_hub import InferenceClient, HfApi, upload_file

import prompts

token = os.environ.get("HF_TOKEN")
username = "omnibus"
dataset_name = "tmp"
api = HfApi(token="")
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Module-level state.  `summary` and `main_point` are one-element lists used
# as mutable cells; being module-level they are shared by every call to
# `generate`.
history = []
hist_out = []
summary = [""]
main_point = [""]

agents = [
    "COMMENTER",
    "BLOG_POSTER",
    "REPLY_TO_COMMENTER",
    "COMPRESS_HISTORY_PROMPT",
]

temperature = 0.9
max_new_tokens = 256
max_new_tokens2 = 10480
top_p = 0.95
# Was written as `1.0,` (a one-element tuple) because of a stray trailing comma.
repetition_penalty = 1.0

# Characters that survive `create_valid_filename`.
_FILENAME_CHARS = frozenset(string.ascii_letters + string.digits + "_-")

# Formatted prompts at or above this many characters trigger history compression.
_MAX_PROMPT_CHARS = 40000


def format_prompt(message, history):
    """Build an ``[INST] ... [/INST]`` prompt from past turns plus *message*.

    Args:
        message: The new user message to append as the final instruction.
        history: Iterable of ``(user_prompt, bot_response)`` pairs.

    Returns:
        The concatenated prompt string.
    """
    parts = [
        f"[INST] {user_prompt} [/INST] {bot_response} "
        for user_prompt, bot_response in history
    ]
    parts.append(f"[INST] {message} [/INST]")
    return "".join(parts)


def _clamp_temperature(value):
    """Return *value* as a float, raised to at least 1e-2."""
    return max(float(value), 1e-2)


def _stream_tokens(formatted_prompt, *, temperature, max_new_tokens, top_p,
                   repetition_penalty):
    """Yield generated token texts for *formatted_prompt*, one at a time.

    A fresh random seed is drawn for every call.
    """
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(1, 1111111111111111),
    )
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    for response in stream:
        yield response.token.text


def _run_agent(system_prompt, prompt, history, *, temperature, max_new_tokens,
               top_p, repetition_penalty):
    """Prefix *prompt* with *system_prompt*, run the model, return the full text."""
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    return "".join(
        _stream_tokens(
            formatted_prompt,
            temperature=_clamp_temperature(temperature),
            max_new_tokens=max_new_tokens,
            top_p=float(top_p),
            repetition_penalty=repetition_penalty,
        )
    )


def compress_history(formatted_prompt):
    """Run the model on an already formatted compression prompt.

    Args:
        formatted_prompt: Complete prompt text; it is sent to the model as-is.

    Returns:
        The generated text (also printed, followed by ``main_point[0]``).
    """
    print("###############\nRUNNING COMPRESS HISTORY\n###############\n")
    output = "".join(
        _stream_tokens(
            formatted_prompt,
            temperature=0.9,
            max_new_tokens=10480,
            top_p=0.95,
            repetition_penalty=1.0,
        )
    )
    print(output)
    print(main_point[0])
    return output


def question_generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1028, top_p=0.95, repetition_penalty=1.0,):
    """Generate a commenter question about *prompt*.

    The system prompt is ``prompts.COMMENTER`` formatted with the current
    ``main_point[0]`` as its focus.

    Args:
        prompt: Text the commenter should react to.
        history: Iterable of ``(user_prompt, bot_response)`` pairs.
        agent_name: Unused; kept for interface compatibility.
        sys_prompt: Unused; kept for interface compatibility.
        temperature: Sampling temperature (raised to at least 1e-2).
        max_new_tokens: Generation length limit.
        top_p: Nucleus sampling parameter.
        repetition_penalty: Repetition penalty passed to the model.

    Returns:
        The generated text.
    """
    print("###############\nRUNNING QUESTION GENERATOR\n###############\n")
    return _run_agent(
        prompts.COMMENTER.format(focus=main_point[0]),
        prompt,
        history,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )


def blog_poster_reply(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
    """Generate the blog poster's reply to a comment.

    The system prompt is ``prompts.REPLY_TO_COMMENTER`` formatted with the
    current ``main_point[0]`` as its focus.  Parameters mirror
    `question_generate`; ``agent_name`` and ``sys_prompt`` are unused.

    Returns:
        The generated text.
    """
    print("###############\nRUNNING BLOG POSTER REPLY\n###############\n")
    return _run_agent(
        prompts.REPLY_TO_COMMENTER.format(focus=main_point[0]),
        prompt,
        history,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )


def create_valid_filename(invalid_filename: str) -> str:
    """Converts invalid characters in a string to be suitable for a filename.

    Runs of whitespace become a single ``-`` (leading/trailing whitespace is
    dropped); every remaining character that is not an ASCII letter, digit,
    ``_`` or ``-`` is removed.
    """
    hyphenated = '-'.join(invalid_filename.split())
    return ''.join(char for char in hyphenated if char in _FILENAME_CHARS)


def load_html(inp=None, title=""):
    """Render ``index.html`` with the conversation and title filled in.

    Args:
        inp: Iterable of ``(output, prompt)`` pairs, or ``None``/empty for an
            empty body.  Defaults to ``None`` so the function can be used as
            an argument-less page-load handler.
        title: Text substituted for the ``$title`` placeholder.

    Returns:
        The contents of ``index.html`` with ``$body`` and ``$title`` replaced.
    """
    body = ""
    if inp:
        body = "".join(f"""
{outp}
{prom}
""" for outp, prom in inp)
    with open('index.html', 'r', encoding="utf-8") as h:
        html = h.read()
    html = html.replace("$body", f"{body}")
    html = html.replace("$title", f"{title}")
    return html


def _extract_title(text):
    """Return the text after the first colon of the last line mentioning "title".

    Returns an empty string when no line contains both "title" and a colon.
    """
    title = ""
    for line in text.split("\n"):
        if "title" in line.lower() and ":" in line:
            # Split only once so titles that contain a colon are kept whole.
            title = line.split(":", 1)[1].strip()
    return title


def _save_and_upload(records, local_path, path_in_repo):
    """Write *records* as JSON to *local_path*, upload it, return the JSON text."""
    payload = json.dumps(records, indent=4)
    with open(local_path, 'w') as f:
        f.write(payload)
    upload_file(
        path_or_fileobj=local_path,
        path_in_repo=path_in_repo,
        repo_id=f"{username}/{dataset_name}",
        repo_type="dataset",
        token=token,
    )
    return payload


def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0,):
    """Run the endless post / comment / reply loop as a generator.

    The first turn uses ``prompts.BLOG_POSTER`` as the system prompt; later
    turns use ``prompts.REPLY_TO_COMMENTER``.  After each model output a
    commenter question is generated and becomes the next prompt.  When the
    formatted prompt reaches ``_MAX_PROMPT_CHARS`` characters the history is
    dropped and replaced by a model-written summary instead.

    The loop never ends on its own; the caller stops it by no longer
    consuming the generator.

    Args:
        prompt: Initial user prompt.
        history: Existing ``(user, bot)`` pairs, or ``None``.  It is copied,
            so the caller's list is not modified.
        agent_name: Unused; kept for interface compatibility.
        sys_prompt: Unused; kept for interface compatibility.
        temperature: Sampling temperature (raised to at least 1e-2).
        max_new_tokens: Unused; the module-level ``max_new_tokens2`` limit is
            applied instead.
        top_p: Nucleus sampling parameter.
        repetition_penalty: Repetition penalty passed to the model.

    Yields:
        6-tuples ``(textbox value, chat pairs, summary, summaries JSON,
        history JSON, rendered HTML)``.
    """
    history = list(history) if history else []
    html_out = ""
    uid = uuid.uuid4()

    # Sanitise the timestamp first so every filename uses the same format.
    current_time = str(datetime.datetime.now()).replace(":", "-").replace(".", "-")
    print(current_time)
    title = ""
    filename = create_valid_filename(f'{current_time}---{title}')

    system_prompt = prompts.BLOG_POSTER
    temperature = _clamp_temperature(temperature)
    top_p = float(top_p)

    hist_out = []
    sum_out = []
    json_hist = {}
    json_obj = {}
    full_conv = []
    first_post = True

    while True:
        if first_post:
            first_post = False
        else:
            system_prompt = prompts.REPLY_TO_COMMENTER.format(focus=main_point[0])
            print("###############\nRUNNING REPLY TO COMMENTER\n###############\n")
            print(system_prompt)

        if prompt.startswith(' \"'):
            prompt = prompt.strip(' \"')
        formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)

        if len(formatted_prompt) < _MAX_PROMPT_CHARS:
            print(len(formatted_prompt))
            output = ""
            # Stream partial output to the UI while the model is generating.
            for token_text in _stream_tokens(
                formatted_prompt,
                temperature=temperature,
                max_new_tokens=max_new_tokens2,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            ):
                output += token_text
                yield '', [(prompt, output)], summary[0], json_obj, json_hist, html_out

            if not title:
                title = _extract_title(output)
                print(f'title:: {title}')
                filename = create_valid_filename(f'{current_time}---{title}')

            out_json = {"prompt": prompt, "output": output}
            prompt = question_generate(output, history)
            history.append((prompt, output))
            print(f'Prompt:: {len(prompt)}')
            print(f'history:: {len(formatted_prompt)}')
            hist_out.append(out_json)
            json_hist = _save_and_upload(
                hist_out, f"{uid}.json", f"book1/{filename}.json"
            )
        else:
            # Prompt too long: summarise, then restart from an empty history.
            compress_prompt = prompts.COMPRESS_HISTORY_PROMPT.format(
                history=summary[0], focus=main_point[0]
            )
            formatted_prompt = format_prompt(
                f"{compress_prompt}, {summary[0]}", history
            )
            history = []
            output = compress_history(formatted_prompt)
            summary[0] = output
            sum_out.append({"summary": summary[0]})
            json_obj = _save_and_upload(
                sum_out, f"{uid}-sum.json", f"book1/{filename}-summary.json"
            )
            prompt = question_generate(output, history)
            main_point[0] = prompt

        full_conv.append((output, prompt))
        html_out = load_html(full_conv, title)
        yield prompt, history, summary[0], json_obj, json_hist, html_out


with gr.Blocks() as app:
    html = gr.HTML()
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    with gr.Row():
        submit_b = gr.Button()
        stop_b = gr.Button("Stop")
        clear = gr.ClearButton([msg, chatbot])
    sumbox = gr.Textbox("Summary", max_lines=100)
    with gr.Column():
        sum_out_box = gr.JSON(label="Summaries")
        hist_out_box = gr.JSON(label="History")

    sub_b = submit_b.click(generate, [msg, chatbot], [msg, chatbot, sumbox, sum_out_box, hist_out_box, html])
    sub_e = msg.submit(generate, [msg, chatbot], [msg, chatbot, sumbox, sum_out_box, hist_out_box, html])
    # The generation loop is endless; "Stop" cancels both triggers.
    stop_b.click(None, None, None, cancels=[sub_b, sub_e])
    # Called with no inputs, so `load_html` runs with its default arguments.
    app.load(load_html, None, html)
app.queue(default_concurrency_limit=20).launch()