MagicGen / app.py
gstaff's picture
Add output examples, use EulerA sampler and reduce inference steps to 20 to reduce CPU runtime from 20 minutes.
ebf6ce9
raw
history blame
12.1 kB
import gradio as gr
import pathlib
import base64
import re
import time
from io import BytesIO
import imgkit
import os
from PIL import Image
from fastai.callback.core import Callback
from fastai.learner import *
from fastai.torch_core import TitledStr
from torch import tensor, Tensor
from torch.distributions import Transform
import random
# These utility functions need to be in main (or otherwise where created) because fastai loads from that module, see:
# https://docs.fast.ai/learner.html#load_learner
from transformers import GPT2TokenizerFast
import torch
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
# Single switch selecting CUDA (fp16 weights) or CPU inference for the models below.
gpu = False
AUTH_TOKEN = os.environ.get('AUTH_TOKEN')

# Both modes load the same checkpoint and community pipeline; GPU mode additionally
# requests the half-precision weights.
_pipeline_options = {"custom_pipeline": "stable_diffusion_mega", "use_auth_token": AUTH_TOKEN}
if gpu:
    _pipeline_options.update(torch_dtype=torch.float16, revision="fp16")
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", **_pipeline_options)

# Swap the pipeline's default scheduler for Euler Ancestral, reusing its configuration.
scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
pipeline.scheduler = scheduler
if gpu:
    pipeline.to("cuda")

# Huggingface Spaces have 16GB RAM and 8 CPU cores
# See https://huggingface.co/docs/hub/spaces-overview#hardware-resources
pretrained_weights = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)
def tokenize(text):
    """Return the token ids of ``text`` as a tensor, using the module-level ``tokenizer``."""
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    return tensor(token_ids)
class TransformersTokenizer(Transform):
    """Transform that converts between text and token-id tensors.

    Kept in this module under this exact name because the exported learner is
    loaded from here (see the note next to the imports).
    """

    def __init__(self, tokenizer):
        # Only used by ``decodes``; ``encodes`` goes through the module-level ``tokenize``.
        self.tokenizer = tokenizer

    def encodes(self, x):
        """Pass tensors through unchanged; tokenize anything else."""
        if isinstance(x, Tensor):
            return x
        return tokenize(x)

    def decodes(self, x):
        """Decode a tensor of token ids back into a ``TitledStr``."""
        token_ids = x.cpu().numpy()
        return TitledStr(self.tokenizer.decode(token_ids))
class DropOutput(Callback):
    """Callback that replaces the learner's prediction with its first element."""

    def after_pred(self):
        first_output = self.pred[0]
        self.learn.pred = first_output
def gen_card_text(name):
    """Generate the text of a card with the language model.

    Args:
        name: Requested card name. An empty string lets the model invent a
            name that starts with a randomly chosen capital letter.

    Returns:
        A ``(name, text)`` tuple: the card name (parsed from the generated
        text when none was given) and the generated text up to the first
        ``###`` marker with carriage returns removed.
    """
    name_was_blank = name == ''
    if name_was_blank:
        seed_letter = random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
        prompt = f"Name: {seed_letter}"
    else:
        prompt = f"Name: {name}\r\n"
    print(f'GENERATING CARD TEXT with prompt: {prompt}')

    # Add a leading batch dimension and move to the GPU only when enabled.
    inp = tensor(tokenizer.encode(prompt))[None]
    if gpu:
        inp = inp.cuda()  # Use .cuda() for torch GPU

    preds = learner.model.generate(
        inp,
        max_length=512,
        num_beams=5,
        temperature=1.5,
        do_sample=True,
        repetition_penalty=1.2,
    )
    decoded = tokenizer.decode(preds[0].cpu().numpy())

    # Keep only the first card, then normalize both literal "\r\n" sequences
    # and real carriage returns.
    card_text = decoded.split('###')[0]
    for old, new in ((r'\r\n', '\n'), ('\r', ''), (r'\r', '')):
        card_text = card_text.replace(old, new)
    print('GENERATING CARD COMPLETE')
    print(card_text)

    if name_was_blank:
        name = re.compile('Name: (.*)').findall(card_text)[0]
    return name, card_text
# init only once
# cpu=False uses GPU; make sure installed torch is GPU e.g.
# `pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116`
learner = load_learner('./colab-data-test/export.pkl', cpu=not gpu)

# Output folders for each artifact written during a run.
for _output_dir in ('card_data', 'card_images', 'card_html', 'rendered_cards'):
    pathlib.Path(_output_dir).mkdir(parents=True, exist_ok=True)
def _safe_file_stem(name):
    """Return ``name`` reduced to something safe to use as a file name stem."""
    # Path separators, reserved punctuation and control characters become "_",
    # so a name such as "../../x" cannot escape the output folders.
    stem = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '_', name)
    # Leading/trailing dots and spaces are dropped (no "." / ".." names) and
    # the length is capped to stay well under common file name limits.
    stem = stem.strip(' .')[:100].strip(' .')
    return stem or 'card'


def run(name):
    """Generate a complete card for ``name`` and return its artifacts.

    The card text, illustration, HTML and rendered PNG are also written to the
    ``card_data``, ``card_images``, ``card_html`` and ``rendered_cards``
    folders.

    Args:
        name: Card name entered by the user; may be empty, in which case the
            text generator picks a name.

    Returns:
        A ``(rendered, text, card_image, html)`` tuple: the rendered card
        image, the generated card text, the illustration and the card HTML.

    Raises:
        IndexError: If the generated text has no ``Type:`` line. This check
            happens before the slow image generation step.
    """
    start = time.time()
    print(f'BEGINNING RUN FOR {name}')
    name, text = gen_card_text(name)

    # The name comes from the user or the model, so it is never used verbatim
    # in a file path.
    file_stem = _safe_file_stem(name)

    save_name = get_savename('card_data', file_stem, 'txt')
    pathlib.Path('card_data', save_name).write_text(text, encoding='utf-8')

    card_type = re.compile('Type: (.*)').findall(text)[0]
    prompt_template = f"fantasy illustration of a {card_type} {name}, by Greg Rutkowski"
    print(f"GENERATING IMAGE FOR {prompt_template}")
    # Regarding sizing see https://huggingface.co/blog/stable_diffusion#:~:text=When%20choosing%20image%20sizes%2C%20we%20advise%20the%20following%3A
    images = pipeline.text2img(prompt_template, width=512, height=368, num_inference_steps=20).images

    # Every generated image is saved; the last one is used on the card.
    card_image = None
    for image in images:
        save_name = get_savename('card_images', file_stem, 'png')
        image.save(os.path.join('card_images', save_name))
        card_image = image

    image_data = pil_to_base64(card_image)
    html = format_html(text, image_data)
    save_name = get_savename('card_html', file_stem, 'html')
    pathlib.Path('card_html', save_name).write_text(html, encoding='utf-8')

    rendered = html_to_png(file_stem, html)
    end = time.time()
    print(f'RUN COMPLETED IN {int(end - start)} seconds')
    return rendered, text, card_image, html
def pil_to_base64(image):
    """Return ``image`` saved as PNG and base64-encoded, as ``bytes``."""
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return encoded
# Trailing part of every inline rules-text icon. format_html shrinks these icons
# with a literal search-and-replace, so both strings must stay byte-for-byte stable.
_ICON_TAIL = 'ms-cost" style="top:0px;float:none;height: 18px;width: 18px;font-size: 13px;"></i>'
_ICON_TAIL_SMALL = 'ms-cost" style="top:0px;float:none;height: 16px;width: 16px;font-size: 11px;"></i>'

# Rules-text symbols that use a specially named icon (matched case-sensitively).
_NAMED_ICONS = {'T': 'tap', 'UT': 'untap', 'E': 'instant'}

# (marker looked for in the card text, background colour); the first hit wins.
_CARD_BACKGROUNDS = (
    ("['U']", '#5a73ab'),
    ("['W']", '#f0e3d0'),
    ("['G']", '#325433'),
    ("['B']", '#1a1b1e'),
    ("['R']", '#c2401c'),
    ("Type: Land", '#aa8c71'),
    ("Type: Artifact", '#9ba7bc'),
)
_DEFAULT_BACKGROUND = '#edd99d'


def _card_field(text, label, default=''):
    """Return the first single-line ``label: value`` value in ``text``, or ``default``."""
    found = re.search(f'{label}: (.*)', text)
    return found.group(1) if found else default


def _card_color_style(text):
    """Return the ``style`` attribute that sets the card's background colour."""
    for marker, color in _CARD_BACKGROUNDS:
        if marker in text:
            return f'style="background-color:{color}"'
    return f'style="background-color:{_DEFAULT_BACKGROUND}"'


def _icon_class(symbol):
    """Normalize a symbol such as ``W/U`` to its icon class suffix (``wu``)."""
    # Lower-case and keep only letters and digits so the value is always a
    # safe CSS class fragment.
    return re.sub(r'[^a-z0-9]', '', symbol.lower())


def _mana_cost_html(mana_cost):
    """Return the icon markup for a mana cost such as ``{2}{W/U}``."""
    if mana_cost == "None":
        return '<i class="ms ms-cost" style="visibility: hidden"></i>'
    # A braced group is one symbol (so "{10}" stays one icon); any character
    # outside braces is treated as a symbol of its own.
    tokens = [braced or loose for braced, loose in re.findall(r'\{([^{}]*)\}|([^{}])', mana_cost)]
    icons = [
        f'<i class="ms ms-{suffix} ms-cost ms-shadow"></i>'
        for suffix in map(_icon_class, tokens)
        if suffix
    ]
    # Icons are emitted in reverse order, as the original markup did.
    return "\n".join(reversed(icons))


def _rules_line_html(line):
    """Turn one already HTML-escaped rules-text line into a ``<p>`` with icons."""

    def icon(match):
        symbol = match.group(1)
        suffix = _NAMED_ICONS.get(symbol) or _icon_class(symbol)
        return f'<i class="ms ms-{suffix} {_ICON_TAIL}'

    line = re.sub(r"\{(.*?)\}", icon, line)
    # Text in parentheses is set in italics.
    line = line.replace('(', '(<i>').replace(')', '</i>)')
    return f"<p>{line}</p>"


def format_html(text, image_data):
    """Fill the card HTML template with the fields parsed from ``text``.

    Fields missing from ``text`` are rendered as empty values (a missing mana
    cost is rendered like ``None``) instead of raising. Values taken from
    ``text`` are HTML-escaped before insertion.

    Args:
        text: Card text made of ``Label: value`` lines (``Name``, ``ManaCost``,
            ``Type``, ``Rarity``, ``Text``, ``FlavorText``, ``Power``,
            ``Toughness``).
        image_data: Base64-encoded PNG as ``bytes``/``bytearray``, which is
            turned into a ``data:`` URI, or any other value, which is inserted
            as its string form.

    Returns:
        The filled-in HTML. A copy is also written to ``test.html`` in the
        working directory.
    """
    # Function-scope import: a module-level ``html`` would collide with the
    # ``html`` variables used elsewhere in this file.
    from html import escape

    template = pathlib.Path("colab-data-test/card_template.html").read_text(encoding='utf-8')

    template = template.replace("{card_color}", _card_color_style(text))
    template = template.replace("{name}", escape(_card_field(text, 'Name')))
    template = template.replace("{mana_cost}", _mana_cost_html(_card_field(text, 'ManaCost', 'None')))

    if isinstance(image_data, (bytes, bytearray)):
        image_src = f'data:image/png;base64,{image_data.decode("utf-8")}'
    else:
        image_src = f'{image_data}'
    template = template.replace('{image_data}', image_src)

    card_type = _card_field(text, 'Type')
    template = template.replace("{card_type}", escape(card_type))
    # Long type lines get a smaller font.
    template = template.replace("{type_size}", "16" if len(card_type) > 30 else "18")

    template = template.replace("{rarity}", f"ss-{escape(_card_field(text, 'Rarity'))}")

    # Rules text may span several lines; it ends where the FlavorText field starts.
    rules_match = re.search('Text: (.*)\nFlavorText', text, re.MULTILINE | re.DOTALL)
    card_text = rules_match.group(1) if rules_match else ''
    text_lines = [_rules_line_html(escape(line, quote=False)) for line in card_text.splitlines()]
    template = template.replace("{card_text}", "\n".join(text_lines))

    flavor_match = re.search('FlavorText: (.*)\nPower', text, re.MULTILINE | re.DOTALL)
    if flavor_match:
        flavor_text = flavor_match.group(1)
        paragraphs = "\n".join(f"<p>{escape(line, quote=False)}</p>" for line in flavor_text.splitlines())
        flavor_html = f"<blockquote>{paragraphs}</blockquote>"
    else:
        flavor_text = ''
        flavor_html = ''
    template = template.replace("{flavor_text}", flavor_html)

    # Crowded text boxes use a smaller font and smaller inline icons.
    if len(card_text) + len(flavor_text) > 170 or len(text_lines) > 3:
        template = template.replace("{text_size}", '16')
        template = template.replace(_ICON_TAIL, _ICON_TAIL_SMALL)
    else:
        template = template.replace("{text_size}", '18')

    # The power/toughness box is shown only when a non-empty Power is present.
    power = _card_field(text, 'Power')
    if power:
        toughness = _card_field(text, 'Toughness')
        pt_html = (
            '<header class="powerToughness"><div><h2 style="font-family: \'Beleren\';font-size: 19px;">'
            f'{escape(power)}/{escape(toughness)}</h2></div></header>'
        )
    else:
        pt_html = ''
    template = template.replace("{power_toughness}", pt_html)

    # Snapshot of the last rendered card, kept from the original implementation.
    pathlib.Path("test.html").write_text(template, encoding='utf-8')
    return template
def get_savename(directory, name, extension):
    """Return a file name for ``name`` that does not yet exist in ``directory``.

    The first candidate is ``name.extension``; if it is taken, ``name-1``,
    ``name-2``, ... are tried in order. The name is used verbatim, so hyphens
    or dots inside ``name`` are preserved.

    Args:
        directory: Folder in which to look for existing files.
        name: File name stem.
        extension: File extension without the leading dot.

    Returns:
        The file name only, without the directory.
    """
    save_name = f"{name}.{extension}"
    i = 1
    while os.path.exists(os.path.join(directory, save_name)):
        # Rebuild from the original stem rather than from the previous
        # candidate, so nothing inside ``name`` is mistaken for a counter.
        save_name = f"{name}-{i}.{extension}"
        i += 1
    return save_name
def html_to_png(card_name, html):
    """Render card HTML to a PNG in ``rendered_cards`` and return the cropped image.

    Args:
        card_name: Stem used for the output file name.
        html: Complete card HTML to render.

    Returns:
        The cropped card as an RGB PIL image. The same crop is saved over the
        rendered file.

    Raises:
        FileNotFoundError: If the renderer produced no output file.
    """
    save_name = get_savename('rendered_cards', card_name, 'png')
    print('CONVERTING HTML CARD TO PNG IMAGE')
    path = os.path.join('rendered_cards', save_name)
    css = ['./colab-data-test/css/mana.css', './colab-data-test/css/keyrune.css', './colab-data-test/css/mtg_custom.css']

    # Rendering stays best-effort as in the original (presumably because the
    # renderer can report an error even though it wrote the image - confirm),
    # but the error is now logged instead of silently dropped.
    render_error = None
    try:
        imgkit.from_string(html, path, {"xvfb": ""}, css=css)
    except Exception as err:
        render_error = err
        print(f'WARNING: imgkit reported an error while rendering {path}: {err!r}')
    if not os.path.exists(path):
        raise FileNotFoundError(f'Card rendering produced no file at {path}') from render_error

    print('OPENING IMAGE FROM FILE')
    # Close the source file before the same path is overwritten below.
    with Image.open(path) as img:
        print('CROPPING BACKGROUND')
        area = (0, 50, 400, 600)
        cropped_img = img.crop(area)
        cropped_img.load()
    # resize() returns a new image, so the result has to be kept.
    cropped_img = cropped_img.resize((400, 550))
    cropped_img.save(path)
    print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
    return cropped_img.convert('RGB')
# Markdown shown above the form.
app_description = "\n".join((
    "# Create your own Magic: The Gathering cards!",
    "Enter a name, click Submit, it may take up to 10 minutes to run on the free CPU only instance.",
))

# Input: the card name. Outputs start out filled with the bundled examples.
input_box = gr.Textbox(label="Enter a card name", placeholder="Firebolt")
rendered_card = gr.Image(label="Card", type='pil', value="examples/card.png")
output_text_box = gr.Textbox(
    label="Card Text",
    value=pathlib.Path("examples/text.txt").read_text('utf-8'),
)
output_card_image = gr.Image(label="Card Image", type='pil', value="examples/image.png")
output_card_html = gr.HTML(label="Card HTML", visible=False, show_label=False)
x = gr.components.Textbox()

# Output order matches the tuple returned by run().
iface = gr.Interface(
    title="MagicGen",
    theme="default",
    description=app_description,
    fn=run,
    inputs=[input_box],
    outputs=[rendered_card, output_text_box, output_card_image, output_card_html],
)
iface.launch()