# test_space/app.py
import base64
import pathlib
import re
import time
from io import BytesIO
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageChops, ImageDraw
from fastai.callback.core import Callback
from fastai.learner import *
from fastai.torch_core import TitledStr
from html2image import Html2Image
# from min_dalle import MinDalle
from torch import tensor, Tensor, float16, float32
from torch.distributions import Transform
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
# These utility functions need to be defined in __main__ (or whichever module they were created in) because fastai's load_learner loads them from that module, see:
# https://docs.fast.ai/learner.html#load_learner
from transformers import GPT2TokenizerFast
import os
AUTH_TOKEN = os.environ.get('AUTH_TOKEN')
# update requirements.txt with:
# C:\Users\Grant\PycharmProjects\test_space\venv\Scripts\pip3.exe freeze > requirements.txt
# Huggingface Spaces have 16GB RAM and 8 CPU cores
# See https://huggingface.co/docs/hub/spaces-overview#hardware-resources
pretrained_weights = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)
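# Convert raw text into a tensor of GPT-2 token ids for the fastai dataloaders.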
def tokenize(text):
toks = tokenizer.tokenize(text)
return tensor(tokenizer.convert_tokens_to_ids(toks))
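# fastai Transform wrapping the Hugging Face tokenizer: encodes text to token-id tensors
# on the way in and decodes tensors back to a displayable string on the way out.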
class TransformersTokenizer(Transform):
def __init__(self, tokenizer): self.tokenizer = tokenizer
def encodes(self, x):
return x if isinstance(x, Tensor) else tokenize(x)
def decodes(self, x): return TitledStr(self.tokenizer.decode(x.cpu().numpy()))
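# The Hugging Face GPT-2 model returns a tuple; keep only its first element (the logits)
# so fastai's loss function receives a single tensor.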
class DropOutput(Callback):
def after_pred(self): self.learn.pred = self.pred[0]
# initialize only once
# Takes about 2 minutes (126 seconds) to generate an image in Huggingface spaces on CPU
# NOTE: as of 2022-11-13 min-dalle is broken, so this switched to a Stable Diffusion model for images
# model = MinDalle(
# models_root='./pretrained',
# dtype=float32,
# device='cpu',
# is_mega=True,
# is_reusable=True
# )
# Download pipeline, but overwrite scheduler
# Consider DPMSolverMultistepScheduler once added to diffusers
scheduler = EulerAncestralDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler",
use_auth_token=AUTH_TOKEN)
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_mega",
torch_dtype=torch.float32,
scheduler=scheduler, use_auth_token=AUTH_TOKEN)
# pipeline.enable_attention_slicing()
# pipeline.to("cuda")
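# Generate a 256x256 monster portrait from the description using the Stable Diffusion mega pipeline.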
def gen_image(prompt):
prompt = f"{prompt}, fantasy painting by Greg Rutkowski"
# See https://huggingface.co/spaces/pootow/min-dalle/blob/main/app.py
    # Hugging Face Spaces seems to run out of memory if grads are not disabled
# torch.set_grad_enabled(False)
print(f'RUNNING gen_image with prompt: {prompt}')
images = pipeline.text2img(prompt, width=256, height=256, num_inference_steps=20).images
# images = model.generate_images(
# text=prompt,
# seed=-1,
# grid_size=1, # grid size above 2 causes out of memory on 12 GB 3080Ti; grid size 2 gives 4 images
# is_seamless=False,
# temperature=1,
# top_k=256,
# supercondition_factor=16,
# is_verbose=True
# )
print('COMPLETED GENERATION')
# images = images.to('cpu').numpy()
# images = images.astype(np.uint8)
# return Image.fromarray(images[0])
return images[0]
gpu = False
# init only once
learner = load_learner('export.pkl',
cpu=not gpu) # cpu=False uses GPU; make sure installed torch is GPU e.g. `pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116`
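# Pull the first sentence of the "Description:" line out of the generated stat block,
# falling back to a generic phrase if no description was generated.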
def parse_monster_description(name, text):
match = re.search(r"Description: (.*)", text)
if not match:
return f"{name} is a monster."
description = match.group(1)
print(description.split('.')[0])
return description.split('.')[0]
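# Generate a full monster stat block from a name with the fine-tuned GPT-2 learner.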
def gen_monster_text(name):
prompt = f"Name: {name}\r\n"
print(f'GENERATING MONSTER TEXT with prompt: {prompt}')
prompt_ids = tokenizer.encode(prompt)
if gpu:
inp = tensor(prompt_ids)[None].cuda() # Use .cuda() for torch GPU
else:
inp = tensor(prompt_ids)[None]
preds = learner.model.generate(inp, max_length=512, num_beams=5, temperature=1.5, do_sample=True,
repetition_penalty=1.2)
result = tokenizer.decode(preds[0].cpu().numpy())
result = result.split('###')[0].replace(r'\r\n', '\n').replace('\r', '').replace(r'\r', '')
print(f'GENERATING MONSTER COMPLETE')
print(result)
return result
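# Return the text following '<header>: ' in the stat block, or '' if the header is absent.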
def extract_text_for_header(text, header):
match = re.search(fr"{header}: (.*)", text)
if match is None:
return ''
return match.group(1)
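# Drop the <li> element with the given class from the card HTML (used when a stat is missing).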
def remove_section(html, html_class):
    match = re.search(rf'<li class="{html_class}"([\w\W])*?li>', html)
if match is not None:
html = html.replace(match.group(0), '')
return html
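# Fill the monsterMakerTemplate.html template with the parsed stat block fields and the
# base64-encoded portrait, removing template sections for any stats the model did not generate.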
def format_monster_card(monster_text, image_data):
print('FORMATTING MONSTER TEXT')
# see giffyglyph's monster maker https://giffyglyph.com/monstermaker/app/
# Different Formatting style examples and some json export formats
card = pathlib.Path('monsterMakerTemplate.html').read_text()
if not isinstance(image_data, (bytes, bytearray)):
card = card.replace('{image_data}', f'{image_data}')
else:
card = card.replace('{image_data}', f'data:image/png;base64,{image_data.decode("utf-8")}')
name = extract_text_for_header(monster_text, 'Name')
card = card.replace('{name}', name)
monster_type = extract_text_for_header(monster_text, 'Type')
card = card.replace('{monster_type}', monster_type)
armor_class = extract_text_for_header(monster_text, 'Armor Class')
card = card.replace('{armor_class}', armor_class)
hit_points = extract_text_for_header(monster_text, 'Hit Points')
card = card.replace('{hit_points}', hit_points)
speed = extract_text_for_header(monster_text, 'Speed')
card = card.replace('{speed}', speed)
str_stat = extract_text_for_header(monster_text, 'STR')
card = card.replace('{str_stat}', str_stat)
dex_stat = extract_text_for_header(monster_text, 'DEX')
card = card.replace('{dex_stat}', dex_stat)
con_stat = extract_text_for_header(monster_text, 'CON')
card = card.replace('{con_stat}', con_stat)
int_stat = extract_text_for_header(monster_text, 'INT')
card = card.replace('{int_stat}', int_stat)
wis_stat = extract_text_for_header(monster_text, 'WIS')
card = card.replace('{wis_stat}', wis_stat)
cha_stat = extract_text_for_header(monster_text, 'CHA')
card = card.replace('{cha_stat}', cha_stat)
saving_throws = extract_text_for_header(monster_text, 'Saving Throws')
card = card.replace('{saving_throws}', saving_throws)
if not saving_throws:
card = remove_section(card, 'monster-saves')
skills = extract_text_for_header(monster_text, 'Skills')
card = card.replace('{skills}', skills)
if not skills:
card = remove_section(card, 'monster-skills')
damage_vulnerabilities = extract_text_for_header(monster_text, 'Damage Vulnerabilities')
card = card.replace('{damage_vulnerabilities}', damage_vulnerabilities)
if not damage_vulnerabilities:
card = remove_section(card, 'monster-vulnerabilities')
damage_resistances = extract_text_for_header(monster_text, 'Damage Resistances')
card = card.replace('{damage_resistances}', damage_resistances)
if not damage_resistances:
card = remove_section(card, 'monster-resistances')
damage_immunities = extract_text_for_header(monster_text, 'Damage Immunities')
card = card.replace('{damage_immunities}', damage_immunities)
if not damage_immunities:
card = remove_section(card, 'monster-immunities')
condition_immunities = extract_text_for_header(monster_text, 'Condition Immunities')
card = card.replace('{condition_immunities}', condition_immunities)
if not condition_immunities:
card = remove_section(card, 'monster-conditions')
senses = extract_text_for_header(monster_text, 'Senses')
card = card.replace('{senses}', senses)
if not senses:
card = remove_section(card, 'monster-senses')
languages = extract_text_for_header(monster_text, 'Languages')
card = card.replace('{languages}', languages)
if not languages:
card = remove_section(card, 'monster-languages')
challenge = extract_text_for_header(monster_text, 'Challenge')
card = card.replace('{challenge}', challenge)
if not challenge:
card = remove_section(card, 'monster-challenge')
description = extract_text_for_header(monster_text, 'Description')
card = card.replace('{description}', description)
match = re.search(r"Passives:\n([\w\W]*)", monster_text)
if match is None:
passives = ''
else:
passives = match.group(1)
p = passives.split(':')
if len(p) > 1:
p = ":".join(p)
p = p.split('\n')
passives_data = ''
for x in p:
x = x.split(':')
if len(x) > 1:
trait = x[0]
if trait == "Passives":
continue
if 'Action' in trait:
break
detail = ":".join(x[1:])
passives_data += f'<div class="monster-trait"><p><span class="name">{trait}</span> <span class="detail">{detail}</span></p></div>'
card = card.replace('{passives}', passives_data)
else:
card = card.replace('{passives}', f'<div class="monster-trait"><p>{passives}</p></div>')
match = re.search(r"Actions:\n([\w\W]*)", monster_text)
if match is None:
actions = ''
else:
actions = match.group(1)
a = actions.split(':')
if len(a) > 1:
a = ":".join(a)
a = a.split('\n')
actions_data = ''
for x in a:
x = x.split(':')
if len(x) > 1:
action = x[0]
if action == "Actions":
continue
if 'Passive' in action:
break
detail = ":".join(x[1:])
actions_data += f'<div class="monster-action"><p><span class="name">{action}</span> <span class="detail">{detail}</span></p></div>'
card = card.replace('{actions}', actions_data)
else:
card = card.replace('{actions}', f'<div class="monster-action"><p>{actions}</p></div>')
# TODO: Legendary actions, reactions, make column count for format an option (1 or 2 column layout)
card = card.replace('Melee or Ranged Weapon Attack:', '<i>Melee or Ranged Weapon Attack:</i>')
card = card.replace('Melee Weapon Attack:', '<i>Melee Weapon Attack:</i>')
card = card.replace('Ranged Weapon Attack:', '<i>Ranged Weapon Attack:</i>')
card = card.replace('Hit:', '<i>Hit:</i>')
print('FORMATTING MONSTER TEXT COMPLETE')
return card
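# Encode a PIL image as base64 PNG bytes for embedding in the card HTML as a data URI.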
def pil_to_base64(image):
print('CONVERTING PIL IMAGE TO BASE64 STRING')
buffered = BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue())
print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
return img_str
hti = Html2Image(output_path='rendered_cards')
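# Crop the image to the bounding box of everything that differs from the given border color.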
def trim(im, border):
bg = Image.new(im.mode, im.size, border)
diff = ImageChops.difference(im, bg)
bbox = diff.getbbox()
    if bbox:
        return im.crop(bbox)
    # If nothing differs from the border color, fall back to the original image
    return im
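# Flood-fill the area around the card (from the top-right corner) to white, then trim the white margin.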
def crop_background(image):
white = (255, 255, 255)
ImageDraw.floodfill(image, (image.size[0] - 1, 0), white, thresh=50)
image = trim(image, white)
return image
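# Render the card HTML to a PNG with html2image, then crop away the page background.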
def html_to_png(html):
print('CONVERTING HTML CARD TO PNG IMAGE')
paths = hti.screenshot(html_str=html, css_file="monstermaker.css", save_as="test.png", size=(800, 1440))
path = paths[0]
print('OPENING IMAGE FROM FILE')
img = Image.open(path).convert("RGB")
print('CROPPING BACKGROUND')
img = crop_background(img)
print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
return img
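# Gradio entry point: generate the stat block, portrait, and rendered card for a monster name.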
def run(name: str) -> (Image, str, Image, str):
start = time.time()
print(f'BEGINNING RUN FOR {name}')
if not name:
placeholder_image = Image.new(mode="RGB", size=(256, 256))
return placeholder_image, 'No name provided; enter a name and try again', placeholder_image, ''
text = gen_monster_text(name)
description = parse_monster_description(name, text)
pil = gen_image(description)
image_data = pil_to_base64(pil)
card_html = format_monster_card(text, image_data)
card_image = html_to_png(card_html)
end = time.time()
print(f'RUN COMPLETED IN {int(end - start)} seconds')
return card_image, text, pil, card_html
app_description = (
"""
# Create your own D&D monster!
Enter a name, click Submit, and wait for about 4 minutes to see the result.
""").strip()
input_box = gr.Textbox(label="Enter a monster name", placeholder="Jabberwock")
output_monster_card = gr.Image(label="Monster Card", type='pil', value="examples/jabberwock_card.png")
output_text_box = gr.Textbox(label="Monster Text", value=pathlib.Path("examples/jabberwock.txt").read_text('utf-8'))
output_monster_image = gr.Image(label="Monster Image", type='pil', value="examples/jabberwock.png")
output_monster_html = gr.HTML(label="Monster HTML", visible=False, show_label=False)
iface = gr.Interface(title="MonsterGen", theme="default", description=app_description, fn=run, inputs=[input_box],
outputs=[output_monster_card, output_text_box, output_monster_image, output_monster_html])
iface.launch()
# TODO: Add examples, larger language model?, document process, log silences, "Passives" => "Traits", log timestamps
# Fine tune dalle-mini? https://blog.paperspace.com/dalle-mini/
# The API works, assuming a query takes no longer than 30 seconds (longer requests hit a 504 gateway timeout)
# Looks like API page improvements are in progress: https://github.com/gradio-app/gradio/issues/1325
# Example code below:
# import requests
# r = requests.post(url='https://hf.space/embed/gstaff/test_space/+/api/predict', json={"data": [""]},
# timeout=None)
# print(r.json())
# Looks like Huggingface uses the queue push api, then polls for status:
# fetch("https://hf.space/embed/gstaff/test_space/api/queue/push/", {
# "headers": {
# "accept": "*/*",
# "accept-language": "en-US,en;q=0.9",
# "content-type": "application/json",
# "sec-ch-ua": "\".Not/A)Brand\";v=\"99\", \"Google Chrome\";v=\"103\", \"Chromium\";v=\"103\"",
# "sec-ch-ua-mobile": "?0",
# "sec-ch-ua-platform": "\"Windows\"",
# "sec-fetch-dest": "empty",
# "sec-fetch-mode": "cors",
# "sec-fetch-site": "same-origin"
# },
# "referrer": "https://hf.space/embed/gstaff/test_space/+?__theme=light",
# "referrerPolicy": "strict-origin-when-cross-origin",
# "body": "{\"fn_index\":0,\"data\":[\"Jabberwock\"],\"action\":\"predict\",\"session_hash\":\"v9ehgfho3p\"}",
# "method": "POST",
# "mode": "cors",
# "credentials": "omit"
# });
# fetch("https://hf.space/embed/gstaff/test_space/api/queue/status/", {
# "headers": {
# "accept": "*/*",
# "accept-language": "en-US,en;q=0.9",
# "content-type": "application/json",
# "sec-ch-ua": "\".Not/A)Brand\";v=\"99\", \"Google Chrome\";v=\"103\", \"Chromium\";v=\"103\"",
# "sec-ch-ua-mobile": "?0",
# "sec-ch-ua-platform": "\"Windows\"",
# "sec-fetch-dest": "empty",
# "sec-fetch-mode": "cors",
# "sec-fetch-site": "same-origin"
# },
# "referrer": "https://hf.space/embed/gstaff/test_space/+?__theme=light",
# "referrerPolicy": "strict-origin-when-cross-origin",
# "body": "{\"hash\":\"09f5369a7a414169aa48948bad5fd93d\"}",
# "method": "POST",
# "mode": "cors",
# "credentials": "omit"
# });
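# A rough Python sketch of the same push/poll flow using requests (untested; the response
# field names "hash", "status", and "data" are assumptions based on the fetch calls above):
# import requests, time
# base = 'https://hf.space/embed/gstaff/test_space/api/queue'
# push = requests.post(f'{base}/push/',
#                      json={"fn_index": 0, "data": ["Jabberwock"], "action": "predict",
#                            "session_hash": "v9ehgfho3p"}).json()
# job_hash = push["hash"]  # assumed field name, matching the status request body above
# while True:
#     status = requests.post(f'{base}/status/', json={"hash": job_hash}).json()
#     if status.get("status") == "COMPLETE":  # assumed status value
#         print(status["data"])
#         break
#     time.sleep(5)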