Spaces:
Runtime error
Runtime error
import base64 | |
import pathlib | |
import re | |
import time | |
from io import BytesIO | |
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image, ImageChops, ImageDraw | |
from fastai.callback.core import Callback | |
from fastai.learner import * | |
from fastai.torch_core import TitledStr | |
from html2image import Html2Image | |
# from min_dalle import MinDalle | |
from torch import tensor, Tensor, float16, float32 | |
from torch.distributions import Transform | |
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler | |
# These utility functions need to be in main (or otherwise where created) because fastai loads from that module, see: | |
# https://docs.fast.ai/learner.html#load_learner | |
from transformers import GPT2TokenizerFast | |
import os | |
# Access token passed to the Hugging Face download calls below; None when the
# AUTH_TOKEN environment variable is not set.
AUTH_TOKEN = os.getenv('AUTH_TOKEN')

# To refresh requirements.txt, run from the project's virtualenv:
#   C:\Users\Grant\PycharmProjects\test_space\venv\Scripts\pip3.exe freeze > requirements.txt
# Huggingface Spaces have 16GB RAM and 8 CPU cores, see
# https://huggingface.co/docs/hub/spaces-overview#hardware-resources

# Name of the pretrained tokenizer weights, and the tokenizer shared by the
# text-generation helpers in this module.
pretrained_weights = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)
def tokenize(text):
    """Return the token ids of ``text`` as a tensor.

    The text is split into tokens with the module-level ``tokenizer`` and the
    tokens are then mapped to their vocabulary ids.
    """
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    return tensor(token_ids)
class TransformersTokenizer(Transform):
    """Transform that converts between raw text and token-id tensors.

    NOTE(review): because of the import order in this file, ``Transform``
    resolves to ``torch.distributions.Transform``; presumably the fastai /
    fastcore ``Transform`` was intended -- confirm before changing, since the
    exported learner is loaded against the classes defined in this module.
    """

    def __init__(self, tokenizer):
        # Tokenizer used by ``decodes``; ``encodes`` goes through the
        # module-level ``tokenize`` helper instead.
        self.tokenizer = tokenizer

    def encodes(self, x):
        """Return ``x`` unchanged if it is already a Tensor, else tokenize it."""
        if isinstance(x, Tensor):
            return x
        return tokenize(x)

    def decodes(self, x):
        """Decode a tensor of token ids back into a ``TitledStr``."""
        token_ids = x.cpu().numpy()
        return TitledStr(self.tokenizer.decode(token_ids))
class DropOutput(Callback):
    """Callback that keeps only the first element of the model's prediction."""

    def after_pred(self):
        # Replace the learner's prediction with its first element.
        self.learn.pred = self.pred[0]
# initialize only once | |
# Takes about 2 minutes (126 seconds) to generate an image in Huggingface spaces on CPU | |
# NOTE as of 2022-11-13 min-dalle is broken, switch to using a stable diffusion model for images | |
# model = MinDalle( | |
# models_root='./pretrained', | |
# dtype=float32, | |
# device='cpu', | |
# is_mega=True, | |
# is_reusable=True | |
# ) | |
# Download pipeline, but overwrite scheduler | |
# Consider DPMSolverMultistepScheduler once added to diffusers | |
from diffusers import EulerAncestralDiscreteScheduler | |
# Both the scheduler config and the pipeline weights come from the same repo.
_SD_MODEL_ID = "runwayml/stable-diffusion-v1-5"

# Build the Euler-ancestral scheduler from the repo's "scheduler" subfolder so
# it can replace the pipeline's default scheduler.
# NOTE(review): newer diffusers releases appear to load schedulers with
# `from_pretrained` instead of `from_config` -- confirm against the pinned version.
scheduler = EulerAncestralDiscreteScheduler.from_config(
    _SD_MODEL_ID,
    subfolder="scheduler",
    use_auth_token=AUTH_TOKEN,
)

# Load the "stable_diffusion_mega" custom pipeline in float32 with the
# scheduler created above.
pipeline = DiffusionPipeline.from_pretrained(
    _SD_MODEL_ID,
    custom_pipeline="stable_diffusion_mega",
    torch_dtype=torch.float32,
    scheduler=scheduler,
    use_auth_token=AUTH_TOKEN,
)
# pipe.enable_attention_slicing() | |
# pipeline.to("cuda") | |
def gen_image(prompt):
    """Generate one 256x256 image for ``prompt`` with the module-level pipeline.

    A fixed art-style suffix is appended to the prompt before it is passed to
    ``pipeline.text2img`` (20 inference steps). The first generated image is
    returned.
    """
    styled_prompt = f"{prompt}, fantasy painting by Greg Rutkowski"
    # Author's note: Hugging Face Spaces seem to run out of memory if grads are
    # not disabled (see https://huggingface.co/spaces/pootow/min-dalle/blob/main/app.py);
    # the explicit `torch.set_grad_enabled(False)` call is currently not used.
    print(f'RUNNING gen_image with prompt: {styled_prompt}')
    generated = pipeline.text2img(
        styled_prompt,
        width=256,
        height=256,
        num_inference_steps=20,
    ).images
    # The previous min-dalle based generation (model.generate_images + numpy
    # conversion) was replaced by the stable diffusion pipeline call above.
    print('COMPLETED GENERATION')
    return generated[0]
# Set to True to run text generation on the GPU; make sure the installed torch
# is a GPU build, e.g.
# `pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116`
gpu = False

# Load the exported learner once at import time (cpu=False uses the GPU).
# NOTE(review): presumably `load_learner` unpickles export.pkl, so only a
# trusted export file should be placed next to this script -- confirm.
learner = load_learner('export.pkl', cpu=not gpu)
def parse_monster_description(name, text):
    """Return the first sentence of the ``Description:`` line in ``text``.

    The description is everything after ``"Description: "`` up to the end of
    that line; only the part before its first ``.`` is returned (and printed).
    When ``text`` has no description line, ``"<name> is a monster."`` is
    returned instead.
    """
    found = re.search(r"Description: (.*)", text)
    if found is None:
        return f"{name} is a monster."
    first_sentence, _, _ = found.group(1).partition('.')
    print(first_sentence)
    return first_sentence
def gen_monster_text(name):
    """Generate the text of a monster called ``name`` with the loaded learner.

    The prompt ``"Name: <name>\\r\\n"`` is encoded with the module-level
    tokenizer and fed to ``learner.model.generate``. The decoded output is cut
    at the first ``###`` marker and carriage returns (real or spelled out as
    the two characters ``\\r``) are removed before the text is returned.
    """
    prompt = f"Name: {name}\r\n"
    print(f'GENERATING MONSTER TEXT with prompt: {prompt}')

    # Add a leading batch dimension; move to the GPU only when enabled.
    inp = tensor(tokenizer.encode(prompt))[None]
    if gpu:
        inp = inp.cuda()  # Use .cuda() for torch GPU

    preds = learner.model.generate(
        inp,
        max_length=512,
        num_beams=5,
        temperature=1.5,
        do_sample=True,
        repetition_penalty=1.2,
    )
    decoded = tokenizer.decode(preds[0].cpu().numpy())

    # Keep only the text before the first '###' marker, then normalise newlines.
    result = decoded.split('###')[0]
    result = result.replace(r'\r\n', '\n')
    result = result.replace('\r', '')
    result = result.replace(r'\r', '')

    print('GENERATING MONSTER COMPLETE')
    print(result)
    return result
def extract_text_for_header(text, header):
    """Return the rest of the line following the first ``"<header>: "`` in ``text``.

    Args:
        text: The text to search.
        header: Literal header label (without the trailing colon). It is
            matched verbatim, so regex metacharacters in it have no special
            meaning.

    Returns:
        The text after ``"<header>: "`` up to the end of that line, or ``''``
        when the header does not occur.
    """
    # re.escape keeps labels such as "Cost (gp)" from being read as a pattern.
    match = re.search(rf"{re.escape(header)}: (.*)", text)
    return '' if match is None else match.group(1)
def remove_section(html, html_class):
    """Remove the ``<li class="<html_class>" ...>...li>`` element from ``html``.

    The first element whose opening tag starts with
    ``<li class="<html_class>"`` is located, spanning up to the nearest
    following ``li>``; every occurrence of that exact text is then removed.

    Args:
        html: The HTML document as a string.
        html_class: Literal class attribute value to look for (matched
            verbatim, not as a regular expression).

    Returns:
        ``html`` without the matched element, or ``html`` unchanged when no
        such element exists.
    """
    # Raw string avoids the invalid "\w" escape; re.escape keeps the class
    # name literal.
    pattern = rf'<li class="{re.escape(html_class)}"[\w\W]*?li>'
    match = re.search(pattern, html)
    if match is None:
        return html
    return html.replace(match.group(0), '')
# (template placeholder, text header) pairs that are always substituted.
_REQUIRED_CARD_FIELDS = (
    ('{name}', 'Name'),
    ('{monster_type}', 'Type'),
    ('{armor_class}', 'Armor Class'),
    ('{hit_points}', 'Hit Points'),
    ('{speed}', 'Speed'),
    ('{str_stat}', 'STR'),
    ('{dex_stat}', 'DEX'),
    ('{con_stat}', 'CON'),
    ('{int_stat}', 'INT'),
    ('{wis_stat}', 'WIS'),
    ('{cha_stat}', 'CHA'),
)

# (template placeholder, text header, <li> class) triples; the <li> section is
# removed from the card when the header has no value.
_OPTIONAL_CARD_FIELDS = (
    ('{saving_throws}', 'Saving Throws', 'monster-saves'),
    ('{skills}', 'Skills', 'monster-skills'),
    ('{damage_vulnerabilities}', 'Damage Vulnerabilities', 'monster-vulnerabilities'),
    ('{damage_resistances}', 'Damage Resistances', 'monster-resistances'),
    ('{damage_immunities}', 'Damage Immunities', 'monster-immunities'),
    ('{condition_immunities}', 'Condition Immunities', 'monster-conditions'),
    ('{senses}', 'Senses', 'monster-senses'),
    ('{languages}', 'Languages', 'monster-languages'),
    ('{challenge}', 'Challenge', 'monster-challenge'),
)

# Attack phrases that are italicised on the finished card. A single pass with
# the longest alternative first avoids wrapping "Ranged Weapon Attack:" a
# second time inside "Melee or Ranged Weapon Attack:".
_ITALIC_PHRASES = re.compile(r'(?:Melee or Ranged|Melee|Ranged) Weapon Attack:|Hit:')


def _render_card_entries(monster_text, section, stop_word, css_class):
    """Render the ``Label: detail`` lines that follow ``"<section>:"`` as HTML.

    Everything after the first ``"<section>:\\n"`` in ``monster_text`` is
    scanned line by line. Lines without a colon are skipped, a line whose
    label equals ``section`` is skipped, and scanning stops at the first label
    containing ``stop_word``. When the section is missing or contains no colon
    at all, its whole (possibly empty) text is wrapped in a single div.
    """
    from html import escape  # local import keeps this block self-contained

    match = re.search(rf"{section}:\n([\w\W]*)", monster_text)
    body = '' if match is None else match.group(1)
    if ':' not in body:
        return f'<div class="{css_class}"><p>{escape(body)}</p></div>'

    rendered = []
    for line in body.split('\n'):
        label, colon, detail = line.partition(':')
        if not colon:
            continue
        if label == section:
            continue
        if stop_word in label:
            # Reached the header of the other section; nothing more belongs here.
            break
        rendered.append(
            f'<div class="{css_class}"><p><span class="name">{escape(label)}</span> '
            f'<span class="detail">{escape(detail)}</span></p></div>'
        )
    return ''.join(rendered)


def format_monster_card(monster_text, image_data):
    """Fill the ``monsterMakerTemplate.html`` template with a monster's data.

    Args:
        monster_text: Monster text containing ``"Header: value"`` lines and
            optional ``"Passives:"`` / ``"Actions:"`` sections.
        image_data: ``bytes``/``bytearray`` holding base64 text, which is
            embedded as a ``data:image/png;base64,...`` URI; any other value
            is interpolated into the template as-is.

    Returns:
        The card HTML as a string. Values taken from ``monster_text`` are
        HTML-escaped before insertion because that text is model output seeded
        by the user-supplied name.
    """
    from html import escape  # local import keeps this block self-contained

    print('FORMATTING MONSTER TEXT')
    # see giffyglyph's monster maker https://giffyglyph.com/monstermaker/app/
    # Different Formatting style examples and some json export formats
    # Read as UTF-8 explicitly, like the example text loaded for the UI.
    card = pathlib.Path('monsterMakerTemplate.html').read_text(encoding='utf-8')

    if isinstance(image_data, (bytes, bytearray)):
        image_src = f'data:image/png;base64,{image_data.decode("utf-8")}'
    else:
        image_src = f'{image_data}'
    card = card.replace('{image_data}', image_src)

    for placeholder, header in _REQUIRED_CARD_FIELDS:
        value = extract_text_for_header(monster_text, header)
        card = card.replace(placeholder, escape(value))

    for placeholder, header, html_class in _OPTIONAL_CARD_FIELDS:
        value = extract_text_for_header(monster_text, header)
        card = card.replace(placeholder, escape(value))
        if not value:
            card = remove_section(card, html_class)

    description = extract_text_for_header(monster_text, 'Description')
    card = card.replace('{description}', escape(description))

    card = card.replace(
        '{passives}',
        _render_card_entries(monster_text, 'Passives', 'Action', 'monster-trait'),
    )
    card = card.replace(
        '{actions}',
        _render_card_entries(monster_text, 'Actions', 'Passive', 'monster-action'),
    )

    # TODO: Legendary actions, reactions, make column count for format an option (1 or 2 column layout)
    card = _ITALIC_PHRASES.sub(lambda m: f'<i>{m.group(0)}</i>', card)
    print('FORMATTING MONSTER TEXT COMPLETE')
    return card
def pil_to_base64(image):
    """Serialise ``image`` as PNG and return the base64-encoded bytes.

    ``image`` only needs a ``save(file_object, format=...)`` method; the PNG
    data is written to an in-memory buffer, never to disk.
    """
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return encoded
# Shared HTML renderer; screenshots are written to the 'rendered_cards' directory.
hti = Html2Image(output_path='rendered_cards')
def trim(im, border):
    """Crop ``im`` to the bounding box of the pixels that differ from ``border``.

    Args:
        im: The image to crop.
        border: Colour used to fill the comparison image (same mode and size
            as ``im``).

    Returns:
        The cropped image, or ``im`` itself when no bounding box is found
        (nothing differs from ``border``), so callers always get an image
        back rather than ``None``.
    """
    bg = Image.new(im.mode, im.size, border)
    diff = ImageChops.difference(im, bg)
    bbox = diff.getbbox()
    if bbox is None:
        # Nothing differs from the border colour: nothing to trim.
        return im
    return im.crop(bbox)
def crop_background(image):
    """Whiten the background of ``image`` and trim the resulting white border.

    A flood fill with white (threshold 50) is started from the top-right
    pixel, modifying ``image`` in place; the result of ``trim`` against white
    is returned.
    """
    white = (255, 255, 255)
    top_right = (image.size[0] - 1, 0)
    ImageDraw.floodfill(image, top_right, white, thresh=50)
    return trim(image, white)
def html_to_png(html):
    """Render ``html`` to an image and return it with the background cropped.

    The page is screenshotted through the module-level ``hti`` renderer using
    ``monstermaker.css`` at 800x1440 and saved as ``test.png``; the first
    returned path is opened, converted to RGB and passed to
    ``crop_background``.
    """
    print('CONVERTING HTML CARD TO PNG IMAGE')
    screenshot_paths = hti.screenshot(
        html_str=html,
        css_file="monstermaker.css",
        save_as="test.png",
        size=(800, 1440),
    )
    first_path = screenshot_paths[0]
    print('OPENING IMAGE FROM FILE')
    rendered = Image.open(first_path).convert("RGB")
    print('CROPPING BACKGROUND')
    cropped = crop_background(rendered)
    print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
    return cropped
def run(name: str) -> tuple:
    """Generate a full monster card for ``name``.

    Args:
        name: Monster name entered by the user. Surrounding whitespace is
            ignored; an empty, whitespace-only or missing name short-circuits
            with a placeholder result.

    Returns:
        A 4-tuple ``(card_image, text, monster_image, card_html)``: the
        rendered card image, the generated monster text, the generated
        monster image and the card HTML. When no name is given, both images
        are a blank 256x256 RGB placeholder, the text is an error message and
        the HTML is ``''``.
    """
    start = time.time()
    print(f'BEGINNING RUN FOR {name}')

    # Treat None and whitespace-only input like an empty name.
    name = (name or '').strip()
    if not name:
        placeholder_image = Image.new(mode="RGB", size=(256, 256))
        return placeholder_image, 'No name provided; enter a name and try again', placeholder_image, ''

    text = gen_monster_text(name)
    description = parse_monster_description(name, text)
    pil = gen_image(description)
    image_data = pil_to_base64(pil)
    card_html = format_monster_card(text, image_data)
    card_image = html_to_png(card_html)

    end = time.time()
    print(f'RUN COMPLETED IN {int(end - start)} seconds')
    return card_image, text, pil, card_html
# Markdown shown at the top of the interface.
app_description = (
    """
# Create your own D&D monster!
Enter a name, click Submit, and wait for about 4 minutes to see the result.
""").strip()

# Single text input: the monster's name.
input_box = gr.Textbox(label="Enter a monster name", placeholder="Jabberwock")

# Output components, pre-populated with the bundled Jabberwock example.
output_monster_card = gr.Image(
    label="Monster Card",
    type='pil',
    value="examples/jabberwock_card.png",
)
output_text_box = gr.Textbox(
    label="Monster Text",
    value=pathlib.Path("examples/jabberwock.txt").read_text('utf-8'),
)
output_monster_image = gr.Image(
    label="Monster Image",
    type='pil',
    value="examples/jabberwock.png",
)
# The card HTML is returned to the client but not displayed.
output_monster_html = gr.HTML(label="Monster HTML", visible=False, show_label=False)

# NOTE(review): `x` is not referenced by the interface below -- confirm it can be removed.
x = gr.components.Textbox()

# Wire `run` to the components above; output order matches run's return tuple.
iface = gr.Interface(
    title="MonsterGen",
    theme="default",
    description=app_description,
    fn=run,
    inputs=[input_box],
    outputs=[output_monster_card, output_text_box, output_monster_image, output_monster_html],
)
iface.launch()
# TODO: Add examples, larger language model?, document process, log silences, "Passives" => "Traits", log timestamps | |
# Fine tune dalle-mini? https://blog.paperspace.com/dalle-mini/ | |
# API works, assuming query takes no longer than 30 seconds (504 gateway timeout) | |
# Looks like API page improvements are in progress: https://github.com/gradio-app/gradio/issues/1325 | |
# Example code below: | |
# import requests | |
# r = requests.post(url='https://hf.space/embed/gstaff/test_space/+/api/predict', json={"data": [""]}, | |
# timeout=None) | |
# print(r.json()) | |
# Looks like Huggingface uses the queue push api, then polls for status: | |
# fetch("https://hf.space/embed/gstaff/test_space/api/queue/push/", { | |
# "headers": { | |
# "accept": "*/*", | |
# "accept-language": "en-US,en;q=0.9", | |
# "content-type": "application/json", | |
# "sec-ch-ua": "\".Not/A)Brand\";v=\"99\", \"Google Chrome\";v=\"103\", \"Chromium\";v=\"103\"", | |
# "sec-ch-ua-mobile": "?0", | |
# "sec-ch-ua-platform": "\"Windows\"", | |
# "sec-fetch-dest": "empty", | |
# "sec-fetch-mode": "cors", | |
# "sec-fetch-site": "same-origin" | |
# }, | |
# "referrer": "https://hf.space/embed/gstaff/test_space/+?__theme=light", | |
# "referrerPolicy": "strict-origin-when-cross-origin", | |
# "body": "{\"fn_index\":0,\"data\":[\"Jabberwock\"],\"action\":\"predict\",\"session_hash\":\"v9ehgfho3p\"}", | |
# "method": "POST", | |
# "mode": "cors", | |
# "credentials": "omit" | |
# }); | |
# fetch("https://hf.space/embed/gstaff/test_space/api/queue/status/", { | |
# "headers": { | |
# "accept": "*/*", | |
# "accept-language": "en-US,en;q=0.9", | |
# "content-type": "application/json", | |
# "sec-ch-ua": "\".Not/A)Brand\";v=\"99\", \"Google Chrome\";v=\"103\", \"Chromium\";v=\"103\"", | |
# "sec-ch-ua-mobile": "?0", | |
# "sec-ch-ua-platform": "\"Windows\"", | |
# "sec-fetch-dest": "empty", | |
# "sec-fetch-mode": "cors", | |
# "sec-fetch-site": "same-origin" | |
# }, | |
# "referrer": "https://hf.space/embed/gstaff/test_space/+?__theme=light", | |
# "referrerPolicy": "strict-origin-when-cross-origin", | |
# "body": "{\"hash\":\"09f5369a7a414169aa48948bad5fd93d\"}", | |
# "method": "POST", | |
# "mode": "cors", | |
# "credentials": "omit" | |
# }); |