import os
import json
import random
import string

import requests
import torch
import gradio as gr
from PIL import Image
from diffusers import StableDiffusionPipeline
from fpdf import FPDF

from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt
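
# PDF layout for the "export as PDF" feature. `title` and `art` are assigned
# externally (see generate_pdf) before output() is called.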
class PDF(FPDF):
    def header(self):
        # Arial bold 15
        self.set_font('Arial', 'B', 15)
        # Calculate width of title and position
        w = self.get_string_width(self.title) + 6
        self.set_x((210 - w) / 2)
        # Colors of frame, background and text
        self.set_draw_color(255, 255, 255)
        self.set_fill_color(255, 255, 255)
        # self.set_text_color(220, 50, 50)
        # Thickness of frame (1 mm)
        self.set_line_width(1)
        # Title
        self.cell(w, 9, self.title, 1, 1, 'C', 1)
        # Line break
        self.ln(10)
        if self.art is not None:
            self.image(self.art, x=self.w / 2.0 - 25, w=50)
            self.ln(10)

    def footer(self):
        # Position at 1.5 cm from bottom
        self.set_y(-15)
        # Arial italic 8
        self.set_font('Arial', 'I', 8)
        # Text color in gray
        self.set_text_color(128)
        # Page number
        self.cell(0, 10, 'Page ' + str(self.page_no()), 0, 0, 'C')

    def chapter_title(self, num, label):
        # Arial 12
        self.set_font('Arial', '', 12)
        # Background color
        self.set_fill_color(200, 220, 255)
        # Title
        self.cell(0, 6, 'Chapter %d : %s' % (num, label), 0, 1, 'L', 1)
        # Line break
        self.ln(4)

    def chapter_body(self, content):
        # Times 12
        self.set_font('Times', '', 12)
        # Output justified text
        self.multi_cell(0, 5, content)
        # Line break
        self.ln()
        # Mention in italics
        self.set_font('', 'I')

    def print_chapter(self, content):
        self.add_page()
        self.chapter_body(content)
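
# Prompt formatter for the Llama-2 chat template: an optional <<SYS>> system
# block followed by [INST] ... [/INST] turns.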
class LLaMA2ChatPromptFmt(PromptFmt):
    # Both methods must be classmethods: PPManager invokes them on the class
    # itself (fmt.ctx(...), fmt.prompt(...)), so without the decorator `cls`
    # would swallow the first argument.
    @classmethod
    def ctx(cls, context):
        if context is None or context == "":
            return ""
        else:
            return f"""<<SYS>>
{context}
<</SYS>>
"""

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        ping = pingpong.ping[:truncate_size]
        pong = "" if pingpong.pong is None else pingpong.pong[:truncate_size]
        return f"""[INST] {ping} [/INST] {pong}"""
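
# Conversation manager that renders the whole ping-pong history into a single
# Llama-2 chat prompt string.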
class LLaMA2ChatPPManager(PPManager):
    def build_prompts(self, from_idx: int = 0, to_idx: int = -1, fmt: PromptFmt = LLaMA2ChatPromptFmt, truncate_size: int = None):
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        results = fmt.ctx(self.ctx)
        for pingpong in self.pingpongs[from_idx:to_idx]:
            results += fmt.prompt(pingpong, truncate_size=truncate_size)
        return results
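
# Same history, rendered as (user, bot) pairs for the Gradio Chatbot widget.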
class GradioLLaMA2ChatPPManager(LLaMA2ChatPPManager):
    def build_uis(self, from_idx: int = 0, to_idx: int = -1, fmt: UIFmt = GradioChatUIFmt):
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        results = []
        for pingpong in self.pingpongs[from_idx:to_idx]:
            results.append(fmt.ui(pingpong))
        return results
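
# Configuration: the Inference API token comes from the HF_TOKEN secret, text
# generation uses Llama-2-70b-chat, and cover art is generated locally with a
# distilled Stable Diffusion checkpoint loaded in half precision.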
TOKEN = os.getenv('HF_TOKEN')
MODEL_ID = 'meta-llama/Llama-2-70b-chat-hf'

pipe = StableDiffusionPipeline.from_pretrained(
    "nota-ai/bk-sdm-small", torch_dtype=torch.float16
)
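
# Custom CSS; the class names below are attached to components via
# elem_classes in the UI definition.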
STYLES = """
.left-panel {
  min-width: min(290px, 100%) !important;
}
.small-big {
  font-size: 12pt !important;
}
.small-big-textarea > label > textarea {
  font-size: 12pt !important;
}
.highlighted-text {
  background: yellow;
  overflow-wrap: break-word;
}
.no-gap {
  gap: 0px !important;
}
.group-border {
  padding: 10px;
  border-width: 1px;
  border-radius: 10px;
  border-color: gray;
  border-style: dashed;
}
.control-label-font {
  font-size: 13pt !important;
}
.control-button {
  background: none !important;
  border-color: #69ade2 !important;
  border-width: 2px !important;
  color: #69ade2 !important;
}
.center {
  text-align: center;
}
.right {
  text-align: right;
}
.no-label {
  padding: 0px !important;
}
.no-label > label > span {
  display: none;
}
.no-label-chatbot {
  border: none !important;
  box-shadow: none !important;
  height: 520px !important;
}
.no-label-chatbot > div > div:nth-child(1) {
  display: none;
}
.no-label-image > div:nth-child(2) {
  display: none;
}
.left-margin-30 {
  padding-left: 30px !important;
}
.left {
  text-align: left !important;
}
.alt-button {
  color: gray !important;
  border-width: 1px !important;
  background: none !important;
  border-color: gray !important;
  text-align: justify !important;
}
.white-text {
  color: #000 !important;
}
"""

# Small helpers: random IDs for temp files, plus PPManager factories carrying
# the system prompts used for story writing.
def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
    return ''.join(random.choice(chars) for _ in range(size))

def get_new_ppm(ping):
    ppm = LLaMA2ChatPPManager()
    ppm.ctx = """\
You are a helpful, respectful and honest writing helper. Always write stories that suit the query.
You DO NOT give explanations, just stories. For instance, do not say such things as "Sure! Here's a short paragraph to start a short story:\""""
    ppm.add_pingpong(PingPong(ping, ''))
    return ppm

def get_new_ppm_for_chat():
    return GradioLLaMA2ChatPPManager()
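
# Thin wrapper around the Hugging Face Inference API text-generation endpoint.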
def gen_text(prompt, hf_model='meta-llama/Llama-2-70b-chat-hf', hf_token=None, parameters=None):
    if hf_token is None:
        raise ValueError("Hugging Face Token is not set")

    if parameters is None:
        parameters = {
            'max_new_tokens': 512,
            'do_sample': True,
            'return_full_text': False,
            'temperature': 1.0,
            'top_k': 50,
            # 'top_p': 1.0,
            'repetition_penalty': 1.2
        }

    url = f'https://api-inference.huggingface.co/models/{hf_model}'
    headers = {
        'Authorization': f'Bearer {hf_token}',
        'Content-type': 'application/json'
    }
    data = {
        'inputs': prompt,
        'stream': False,
        'options': {
            'use_cache': False,
        },
        'parameters': parameters
    }

    r = requests.post(
        url,
        headers=headers,
        data=json.dumps(data)
    )
    # Check the status code directly rather than the reason phrase, which is
    # not guaranteed to be 'OK' for every 200 response.
    if r.status_code != 200:
        raise ValueError(f"Response other than 200 (got {r.status_code})")

    return json.loads(r.content.decode("utf-8"))[0]['generated_text']
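
# Two-phase cover-art flow: with an empty prompt box, ask the LLM to write a
# poster caption; once a prompt is present, run Stable Diffusion on it.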
def gen_art(editor, cover_art_image, gen_cover_art_prompt):
    if gen_cover_art_prompt.strip() == "":
        ppm = get_new_ppm(f"""describe the story below as a movie poster. give me the caption ONLY.
--------------------------------
{editor}""")
        cover_art_prompt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
        return [
            cover_art_image,
            cover_art_prompt
        ]
    else:
        global pipe
        pipe = pipe.to("cuda")
        return [
            pipe(gen_cover_art_prompt).images[0],
            gen_cover_art_prompt
        ]
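
# Render the story (and optional cover art) into a downloadable PDF.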
def generate_pdf(title, editor, concept_art):
    tmp_filename = id_generator()

    if concept_art is not None:
        im = Image.fromarray(concept_art)
        im.save(f"{tmp_filename}.png")

    pdf = PDF()
    pdf.title = "Untitled" if title.strip() == "" else title
    pdf.art = None if concept_art is None else f"{tmp_filename}.png"
    pdf.print_chapter(editor)
    pdf.output(f'{tmp_filename}.pdf', 'F')

    return (
        gr.update(value=f'{tmp_filename}.pdf', visible=True),
        " "
    )
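
# Capture the text selection (value and character range) from the editor.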
def select(editor, evt: gr.SelectData):
    return [
        evt.value,
        evt.index[0],
        evt.index[1]
    ]
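
# Core text generation: build a prompt from the story so far (or a kick-off
# prompt for an empty editor) and append the model's continuation.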
def get_gen_txt(title, editor, prompt, only_gen_text=False):
    if editor.strip() == '':
        ppm = get_new_ppm(f'Write a short paragraph to start a short story titled "{title}" for me')
    else:
        ppm = get_new_ppm(f"""{prompt}
--------------------------------
{editor}""")

    try:
        txt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
        if only_gen_text:
            return txt + "\n\n"
        else:
            return editor + txt + "\n\n"
    except ValueError as e:
        print(f"something went wrong - {e}")
        return editor

def gen_txt(title, editor, prompt):
    return [
        get_gen_txt(title, editor, "Write the next paragraph based on the following stories so far." if prompt.strip() == "" else prompt),
        0,
        gr.update(interactive=True),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(interactive=True),
        gr.update(interactive=True),
    ]
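
# Chat over the story: the editor content is injected into the system prompt
# so the assistant can answer questions about what has been written so far.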
def chat_gen(editor, chat_txt, chatbot, ppm, regen=False):
    ppm.ctx = f"""\
You are a helpful, respectful and honest assistant.
You must consider multi-turn conversations.
Answer questions based on the story written so far, shown below
----------------
{editor}
"""

    if regen:
        # Replay the last user message instead of taking a new one.
        last_pingpong = ppm.pop_pingpong()
        chat_txt = last_pingpong.ping

    ppm.add_pingpong(PingPong(chat_txt, ''))

    try:
        txt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
        ppm.add_pong(txt)
    except ValueError as e:
        print(f"something went wrong - {e}")

    return [
        "",
        ppm.build_uis(),
        ppm
    ]

def chat(editor, chat_txt, chatbot, ppm):
    return chat_gen(editor, chat_txt, chatbot, ppm, regen=False)

def regen_chat(editor, chat_txt, chatbot, ppm):
    return chat_gen(editor, chat_txt, chatbot, ppm, regen=True)
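
# Replace the selected span with LLM-rewritten text of the chosen granularity
# (word, sentence, phrase, or paragraph).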
def get_new_ppm_for_range():
    ppm = LLaMA2ChatPPManager()
    ppm.ctx = """\
You are a helpful, respectful and honest writing helper. Always write text that suits the query.
You DO NOT give explanations, just stories. DO NOT say such things as 'Sure! Here's a short paragraph to start a short story:' or 'Sure, here is a revised version of ....:'
"""
    return ppm

def replace_sel(editor, replace_type, selected_text, sel_index_from, sel_index_to):
    ppm = get_new_ppm_for_range()

    ping = f"""replace {selected_text} with a single {replace_type} based on the story below
----------------
{editor}
"""
    ppm.add_pingpong(PingPong(ping, ''))

    # Fall back to the original selection if generation fails, so `txt` is
    # always defined below (it was unbound on the exception path before).
    txt = selected_text
    try:
        txt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
        ppm.add_pong(txt)
    except ValueError as e:
        print(f"something went wrong - {e}")

    return [
        f"{editor[:sel_index_from]} {txt} {editor[sel_index_to:]}",
        "",
        0,
        0
    ]
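
# Generate up to three alternative continuations; each lands on one of the
# alternative buttons so the user can pick one.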
def gen_alt(title, editor, num_enabled_alts, alt_btn1, alt_btn2, alt_btn3):
    if num_enabled_alts >= 3:
        # All three slots are filled; leave everything unchanged (the original
        # returned None here, which mismatches the eleven outputs).
        return [
            num_enabled_alts,
            gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(), gr.update(),
            " ",
            gr.update(interactive=True),
            gr.update(interactive=True),
        ]

    # Renamed from `gen_txt` to avoid shadowing the gen_txt() function above.
    alt_txt = get_gen_txt(title, editor, "Write the next paragraph based on the following stories so far.", only_gen_text=True)
    return [
        min(num_enabled_alts + 1, 3),
        gr.update(interactive=num_enabled_alts < 2),
        gr.update(visible=num_enabled_alts >= 0),
        gr.update(value=alt_txt if num_enabled_alts == 0 else alt_btn1),
        gr.update(visible=num_enabled_alts >= 1),
        gr.update(value=alt_txt if num_enabled_alts == 1 else alt_btn2),
        gr.update(visible=num_enabled_alts >= 2),
        gr.update(value=alt_txt if num_enabled_alts == 2 else alt_btn3),
        " ",
        gr.update(interactive=True),
        gr.update(interactive=True),
    ]
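
# Append the chosen alternative to the editor and hide the alternative rows.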
def fill_with_gen(alt_txt, editor):
    return [
        editor + alt_txt,
        0,
        gr.update(interactive=True),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False)
    ]
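
# UI definition: a two-column layout with the editor on the left and the
# Control / Chatting / Exporting tabs on the right.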
with gr.Blocks(css=STYLES) as demo:
    num_enabled_alts = gr.State(0)
    sel_index_from = gr.State(0)
    sel_index_to = gr.State(0)
    chat_history = gr.State(get_new_ppm_for_chat())

    gr.Markdown("# Co-writing with AI", elem_classes=['center'])
    gr.Markdown(
        "This application is designed for you to collaborate with an LLM to co-write stories. It is inspired by the [Wordcraft project](https://wordcraft-writers-workshop.appspot.com/) from Google's PAIR and Magenta teams. "
        "This application is built on [Gradio](https://www.gradio.app), and the underlying text generation is powered by the [Hugging Face Inference API](https://huggingface.co/inference-api). The text generation model may "
        "change over time, but [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) is selected for now.",
        elem_classes=['center', 'small-big'])

    progress_bar = gr.Textbox(elem_classes=['no-label'])

    with gr.Row():
        with gr.Column(scale=2):
            editor = gr.Textbox(lines=32, max_lines=32, elem_classes=['no-label', 'small-big-textarea'])
            word_counter = gr.Markdown("0 words", elem_classes=['right'])

        with gr.Column(scale=1):
            with gr.Tab("Control"):
                with gr.Column(elem_classes=['group-border']):
                    gr.Markdown('### title')
                    title = gr.Textbox("pokemon training story", elem_classes=['no-label'])

                with gr.Column(elem_classes=['group-border']):
                    with gr.Column():
                        gr.Markdown("For instant generation and concatenation, use the `generate text` button. "
                                    "Want to explore alternative choices? Use the `generate alternatives` button.")

                        with gr.Accordion("longer guideline", open=False):
                            gr.Markdown("The `generate text` button generates continued text and attaches it to the end. "
                                        "The `generate alternatives` button, on the other hand, generates up to three "
                                        "alternative texts and lets you choose one of them. In both cases, **Write the next paragraph based on "
                                        "the following stories so far.** is the default prompt. If you want to try your own "
                                        "prompt, enter it in the textbox below.")

                        prompt = gr.Textbox(placeholder="design your own prompt", elem_classes=['no-label'])

                        with gr.Row():
                            gen_btn = gr.Button("generate text", elem_classes=['control-label-font', 'control-button'])
                            gen_alt_btn = gr.Button("generate alternatives", elem_classes=['control-label-font', 'control-button'])

                    with gr.Column():
                        with gr.Row(visible=False) as first_alt:
                            gr.Markdown("↳", scale=1, elem_classes=['wrap'])
                            alt_btn1 = gr.Button("Alternative 1", elem_classes=['alt-button'], scale=8)

                        with gr.Row(visible=False) as second_alt:
                            gr.Markdown("↳", scale=1, elem_classes=['wrap'])
                            alt_btn2 = gr.Button("Alternative 2", elem_classes=['alt-button'], scale=8)

                        with gr.Row(visible=False) as third_alt:
                            gr.Markdown("↳", scale=1, elem_classes=['wrap'])
                            alt_btn3 = gr.Button("Alternative 3", elem_classes=['alt-button'], scale=8)

                with gr.Column(elem_classes=['group-border']):
                    with gr.Row():
                        selected_text = gr.Markdown("Selected text will be displayed in this area", elem_classes=['highlighted-text'])

                    with gr.Row():
                        with gr.Column(elem_classes=['no-gap']):
                            replace_sel_btn = gr.Button("replace selection", elem_classes=['control-label-font', 'control-button'])
                            replace_type = gr.Dropdown(choices=['word', 'sentence', 'phrase', 'paragraph'], value='sentence', interactive=True, elem_classes=['no-label'])

            with gr.Tab("Chatting"):
                chatbot = gr.Chatbot([], elem_classes=['no-label-chatbot'])
                chat_txt = gr.Textbox(placeholder="enter question", elem_classes=['no-label'])

                with gr.Row():
                    clear_btn = gr.Button("clear", elem_classes=['control-label-font', 'control-button'])
                    regen_btn = gr.Button("regenerate", elem_classes=['control-label-font', 'control-button'])

            with gr.Tab("Exporting"):
                cover_art = gr.Image(interactive=False, elem_classes=['no-label-image'])
                gen_cover_art_prompt = gr.Textbox(lines=5, max_lines=5, elem_classes=['no-label'])
                # toggle between "generate prompt for cover art" and "generate cover art"
                gen_cover_art_btn = gr.Button("generate prompt for cover art", elem_classes=['control-label-font', 'control-button'])
                gen_pdf_btn = gr.Button("export as PDF", elem_classes=['control-label-font', 'control-button'])
                pdf_file = gr.File(visible=False)
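
    # Event wiring. The PDF export and cover-art buttons call back into
    # Python; the two .change() handlers further below run client-side JS
    # only (fn=None with _js).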
    gen_pdf_btn.click(
        generate_pdf,
        inputs=[title, editor, cover_art],
        outputs=[pdf_file, progress_bar]
    )

    gen_cover_art_btn.click(
        gen_art,
        inputs=[editor, cover_art, gen_cover_art_prompt],
        outputs=[cover_art, gen_cover_art_prompt]
    )

    gen_cover_art_prompt.change(
        fn=None,
        inputs=[gen_cover_art_prompt],
        outputs=[gen_cover_art_btn],
        _js="(t) => t.trim() == '' ? 'generate prompt for cover art' : 'generate cover art'"
    )

    editor.change(
        fn=None,
        inputs=[editor],
        outputs=[word_counter, selected_text],
        # Raw string so \s is not treated as a (deprecated) Python escape;
        # count words only in non-empty text and keep the "N words" format.
        _js=r"(e) => [e.trim() === '' ? '0 words' : e.trim().split(/\s+/).length + ' words', '']"
    )

    editor.select(
        fn=select,
        inputs=[editor],
        outputs=[selected_text, sel_index_from, sel_index_to],
        show_progress='minimal'
    )

    gen_btn.click(
        # Three output components, so return exactly three updates (the
        # original returned four, mismatching the outputs list).
        lambda: (
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
        ),
        inputs=None,
        outputs=[gen_btn, gen_alt_btn, replace_sel_btn]
    ).then(
        fn=gen_txt,
        inputs=[title, editor, prompt],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt, gen_btn, replace_sel_btn]
    )

    gen_alt_btn.click(
        lambda: (
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=False),
        ),
        inputs=None,
        outputs=[gen_btn, gen_alt_btn, replace_sel_btn]
    ).then(
        fn=gen_alt,
        inputs=[title, editor, num_enabled_alts, alt_btn1, alt_btn2, alt_btn3],
        outputs=[num_enabled_alts, gen_alt_btn, first_alt, alt_btn1, second_alt, alt_btn2, third_alt, alt_btn3, progress_bar, gen_btn, replace_sel_btn],
    )

    alt_btn1.click(
        fn=fill_with_gen,
        inputs=[alt_btn1, editor],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
    )

    alt_btn2.click(
        fn=fill_with_gen,
        inputs=[alt_btn2, editor],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
    )

    alt_btn3.click(
        fn=fill_with_gen,
        inputs=[alt_btn3, editor],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
    )

    replace_sel_btn.click(
        fn=replace_sel,
        inputs=[editor, replace_type, selected_text, sel_index_from, sel_index_to],
        outputs=[editor, selected_text, sel_index_from, sel_index_to],
        show_progress='minimal'
    )

    chat_txt.submit(
        fn=chat,
        inputs=[editor, chat_txt, chatbot, chat_history],
        outputs=[chat_txt, chatbot, chat_history]
    )

    regen_btn.click(
        fn=regen_chat,
        inputs=[editor, chat_txt, chatbot, chat_history],
        outputs=[chat_txt, chatbot, chat_history]
    )
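
    # The "clear" button was left unwired in the original. A minimal handler,
    # assuming the intent is to reset the conversation: empty the Chatbot and
    # the input box, and start a fresh history manager.
    clear_btn.click(
        lambda: ([], "", get_new_ppm_for_chat()),
        inputs=None,
        outputs=[chatbot, chat_txt, chat_history]
    )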

demo.launch()