|
import gradio as gr |
|
|
|
import pathlib |
|
import base64 |
|
import re |
|
import time |
|
from io import BytesIO |
|
|
|
import imgkit |
|
import os |
|
from PIL import Image |
|
from fastai.callback.core import Callback |
|
from fastai.learner import * |
|
from fastai.torch_core import TitledStr |
|
from torch import tensor, Tensor |
|
from torch.distributions import Transform |
|
import random |
|
|
|
|
|
|
|
from transformers import GPT2TokenizerFast |
|
|
|
import torch |
|
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler |
|
|
|
# Set to True to run Stable Diffusion in half precision on a CUDA device.
gpu = False

# Access token forwarded to the model hub; None when the variable is unset.
AUTH_TOKEN = os.environ.get('AUTH_TOKEN')

# Keyword arguments shared by both configurations; the GPU run additionally
# asks for the fp16 revision of the weights.
_pipeline_options = {
    "custom_pipeline": "stable_diffusion_mega",
    "use_auth_token": AUTH_TOKEN,
}
if gpu:
    _pipeline_options["torch_dtype"] = torch.float16
    _pipeline_options["revision"] = "fp16"

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", **_pipeline_options)

# Replace the pipeline's default scheduler with Euler-ancestral, reusing the
# configuration of the scheduler that shipped with the checkpoint.
scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
pipeline.scheduler = scheduler

if gpu:
    pipeline.to("cuda")


# Tokenizer used both to encode prompts and to decode generated card text.
pretrained_weights = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)
|
|
|
|
|
def tokenize(text):
    """Return the token ids of *text* as a tensor, using the module-level tokenizer."""
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    return tensor(token_ids)
|
|
|
|
|
class TransformersTokenizer(Transform):
    """Transform that converts between raw text and token-id tensors.

    NOTE(review): ``Transform`` is imported from ``torch.distributions`` in
    this file; a text-tokenizing transform would normally derive from fastai's
    ``Transform`` -- confirm the import is the intended one.
    """

    def __init__(self, tokenizer):
        # The base-class initialiser is intentionally not invoked, exactly as
        # in the original definition.
        self.tokenizer = tokenizer

    def encodes(self, x):
        """Pass tensors through untouched; tokenize anything else."""
        if isinstance(x, Tensor):
            return x
        return tokenize(x)

    def decodes(self, x):
        """Decode a tensor of token ids back into a ``TitledStr``."""
        token_ids = x.cpu().numpy()
        return TitledStr(self.tokenizer.decode(token_ids))
|
|
|
|
|
class DropOutput(Callback):
    """Callback that keeps only the first element of the model's prediction."""

    def after_pred(self):
        # presumably the model returns a tuple whose first item is the logits
        # -- TODO confirm against the exported learner
        first_output = self.pred[0]
        self.learn.pred = first_output
|
|
|
|
|
def gen_card_text(name):
    """Generate the text of a card with the loaded language model.

    Args:
        name: Card name to seed the prompt with. An empty string asks for a
            random card: the prompt is seeded with a random capital letter
            and the name is read back from the generated text.

    Returns:
        A ``(name, text)`` tuple, where *text* is the generated output up to
        the first ``###`` with carriage returns removed.
    """
    pick_random_name = name == ''
    if pick_random_name:
        first_letter = random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
        prompt = f"Name: {first_letter}"
    else:
        prompt = f"Name: {name}\r\n"
    print(f'GENERATING CARD TEXT with prompt: {prompt}')

    # [None] adds a leading batch dimension before the ids go to generate().
    inp = tensor(tokenizer.encode(prompt))[None]
    if gpu:
        inp = inp.cuda()

    preds = learner.model.generate(
        inp,
        max_length=512,
        num_beams=5,
        temperature=1.5,
        do_sample=True,
        repetition_penalty=1.2,
    )
    decoded = tokenizer.decode(preds[0].cpu().numpy())

    # Keep what precedes the first '###', then drop both literal
    # backslash-r sequences and real carriage-return characters.
    card_text = decoded.split('###')[0]
    for old, new in ((r'\r\n', '\n'), ('\r', ''), (r'\r', '')):
        card_text = card_text.replace(old, new)

    print('GENERATING CARD COMPLETE')
    print(card_text)

    if pick_random_name:
        name = re.findall('Name: (.*)', card_text)[0]
    return name, card_text
|
|
|
|
|
|
|
# Exported fastai learner holding the card-text model; loaded on the CPU
# unless `gpu` is enabled.
# NOTE(review): export.pkl is presumably a pickle -- only load trusted files.
learner = load_learner('./colab-data-test/export.pkl', cpu=not gpu)

# Output directories for the artefacts of each run.
for _output_dir in ('card_data', 'card_images', 'card_html', 'rendered_cards'):
    pathlib.Path(_output_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
|
def run(name):
    """Generate a complete card for *name* and persist every artefact.

    Produces the card text, an illustration, the card HTML and a rendered PNG,
    saving each under its own output directory.

    Args:
        name: Requested card name; an empty string lets the model pick one.

    Returns:
        A ``(rendered, text, card_image, html)`` tuple: the rendered card
        image, the generated text, the illustration and the card HTML.
    """
    started_at = time.time()
    print(f'BEGINNING RUN FOR {name}')

    name, text = gen_card_text(name)
    text_file = get_savename('card_data', name, 'txt')
    pathlib.Path(f'card_data/{text_file}').write_text(text, encoding='utf-8')

    card_type = re.findall('Type: (.*)', text)[0]
    prompt_template = f"fantasy illustration of a {card_type} {name}, by Greg Rutkowski"
    print(f"GENERATING IMAGE FOR {prompt_template}")

    generated = pipeline.text2img(
        prompt_template,
        width=512,
        height=368,
        num_inference_steps=20,
    ).images

    # Every generated image is written to disk; the last one goes on the card.
    card_image = None
    for picture in generated:
        image_file = get_savename('card_images', name, 'png')
        picture.save(f"card_images/{image_file}")
        card_image = picture

    html = format_html(text, pil_to_base64(card_image))
    html_file = get_savename('card_html', name, 'html')
    pathlib.Path(f'card_html/{html_file}').write_text(html, encoding='utf-8')
    rendered = html_to_png(name, html)

    finished_at = time.time()
    print(f'RUN COMPLETED IN {int(finished_at - started_at)} seconds')
    return rendered, text, card_image, html
|
|
|
|
|
def pil_to_base64(image):
    """Serialise *image* as PNG and return the base64-encoded bytes.

    Args:
        image: Object exposing ``save(file, format=...)``, such as a PIL image.

    Returns:
        ``bytes`` holding the base64 encoding of the PNG data.
    """
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return encoded
|
|
|
|
|
# Inline style carried by every icon embedded in rules text. format_html()
# swaps the large variant for the small one when the text box is crowded, so
# both strings must keep exactly this shape.
_ICON_STYLE_LARGE = 'top:0px;float:none;height: 18px;width: 18px;font-size: 13px;'
_ICON_STYLE_SMALL = 'top:0px;float:none;height: 16px;width: 16px;font-size: 11px;'

# Frame colours, checked in order; the first marker found in the text wins.
_CARD_COLORS = (
    ("['U']", '#5a73ab'),
    ("['W']", '#f0e3d0'),
    ("['G']", '#325433'),
    ("['B']", '#1a1b1e'),
    ("['R']", '#c2401c'),
    ("Type: Land", '#aa8c71'),
    ("Type: Artifact", '#9ba7bc'),
)
_DEFAULT_CARD_COLOR = '#edd99d'

# Rules-text tokens that map to an icon class that is not simply the token.
_NAMED_SYMBOLS = {'T': 'tap', 'UT': 'untap', 'E': 'instant'}


def _symbol_class(symbol):
    """Return the icon class suffix for one symbol, e.g. ``'W/U'`` -> ``'wu'``."""
    return symbol.replace('/', '').lower()


def _mana_symbols(mana_cost):
    """Split a mana cost into icon class suffixes, keeping ``{10}`` as one symbol."""
    tokens = re.findall(r'\{([^{}]*)\}|([^{}])', mana_cost)
    # Each match fills exactly one group: a braced symbol or a loose character.
    return [_symbol_class(braced or loose) for braced, loose in tokens if braced or loose]


def _rules_icon(match):
    """``re.sub`` callback turning one ``{...}`` rules-text token into an icon tag."""
    token = match.group(1)
    css = _NAMED_SYMBOLS.get(token) or _symbol_class(token)
    return f'<i class="ms ms-{css} ms-cost" style="{_ICON_STYLE_LARGE}"></i>'


def format_html(text, image_data):
    """Fill the card HTML template with the fields parsed from *text*.

    The template is read from ``colab-data-test/card_template.html`` and its
    ``{placeholder}`` markers are replaced one by one. Compared with a plain
    substitution this function also:

    * lower-cases symbol names and collapses hybrid symbols (``{W/U}`` ->
      ``ms-wu``), so the icon classes match the lowercase naming used for the
      named icons;
    * keeps multi-character mana symbols such as ``{10}`` intact;
    * HTML-escapes every free-text field, since the card name comes from the
      user and the rest from a language model.

    Args:
        text: Generated card text containing ``Name:``, ``ManaCost:``,
            ``Type:``, ``Rarity:``, ``Text:`` and ``FlavorText:`` fields, and
            optionally ``Power:`` / ``Toughness:``.
        image_data: Base64 PNG data as ``bytes``/``bytearray`` (wrapped in a
            data URI), or any other value, which is inserted via ``str()``.

    Returns:
        The completed HTML document as a string.

    Raises:
        IndexError: If a mandatory field is missing from *text*.
    """
    from html import escape  # local import: `html` is not imported at file level

    template = pathlib.Path("colab-data-test/card_template.html").read_text(encoding='utf-8')

    color = next((code for marker, code in _CARD_COLORS if marker in text), _DEFAULT_CARD_COLOR)
    template = template.replace("{card_color}", f'style="background-color:{color}"')

    name = re.findall('Name: (.*)', text)[0]
    template = template.replace("{name}", escape(name))

    mana_cost = re.findall('ManaCost: (.*)', text)[0]
    if mana_cost == "None":
        # Invisible placeholder keeps the layout of the title bar stable.
        mana_html = '<i class="ms ms-cost" style="visibility: hidden"></i>'
    else:
        icons = [
            f'<i class="ms ms-{escape(css)} ms-cost ms-shadow"></i>'
            for css in _mana_symbols(mana_cost)
        ]
        # The symbols are emitted in reverse order, as the original code did.
        mana_html = "\n".join(reversed(icons))
    template = template.replace("{mana_cost}", mana_html)

    if isinstance(image_data, (bytes, bytearray)):
        image_src = f'data:image/png;base64,{image_data.decode("utf-8")}'
    else:
        image_src = f'{image_data}'
    template = template.replace('{image_data}', image_src)

    card_type = re.findall('Type: (.*)', text)[0]
    template = template.replace("{card_type}", escape(card_type))
    template = template.replace("{type_size}", "16" if len(card_type) > 30 else "18")

    rarity = re.findall('Rarity: (.*)', text)[0]
    template = template.replace("{rarity}", f"ss-{escape(rarity)}")

    card_text = re.findall('Text: (.*)\nFlavorText', text, flags=re.MULTILINE | re.DOTALL)[0]
    text_lines = []
    for line in card_text.splitlines():
        # Escape first: the icon markup inserted next must stay unescaped.
        line = re.sub(r'\{([^{}]*)\}', _rules_icon, escape(line))
        # Text in parentheses is set in italics.
        line = line.replace('(', '(<i>').replace(')', '</i>)')
        text_lines.append(f"<p>{line}</p>")
    template = template.replace("{card_text}", "\n".join(text_lines))

    flavor_matches = re.findall('FlavorText: (.*)\nPower', text, flags=re.MULTILINE | re.DOTALL)
    if flavor_matches:
        flavor_text = flavor_matches[0]
        paragraphs = "\n".join(f"<p>{escape(line)}</p>" for line in flavor_text.splitlines())
        flavor_html = "<blockquote>" + paragraphs + "</blockquote>"
    else:
        flavor_text = ''
        flavor_html = ""
    template = template.replace("{flavor_text}", flavor_html)

    # Crowded text boxes get a smaller font and smaller inline icons. This
    # must run after {card_text} has been substituted so the icons are found.
    if len(card_text) + len(flavor_text) > 170 or len(text_lines) > 3:
        template = template.replace("{text_size}", '16')
        template = template.replace(f'ms-cost" style="{_ICON_STYLE_LARGE}"></i>',
                                    f'ms-cost" style="{_ICON_STYLE_SMALL}"></i>')
    else:
        template = template.replace("{text_size}", '18')

    power_matches = re.findall('Power: (.*)', text)
    power = power_matches[0] if power_matches else ''
    if power:
        toughness = re.findall('Toughness: (.*)', text)[0]
        pt_html = (f'<header class="powerToughness"><div><h2 style="font-family: \'Beleren\';font-size: 19px;">'
                   f'{escape(power)}/{escape(toughness)}</h2></div></header>')
    else:
        # No Power line, or an empty one: the box is omitted.
        pt_html = ""
    template = template.replace("{power_toughness}", pt_html)

    # Snapshot of the last generated card, kept from the original code.
    pathlib.Path("test.html").write_text(template, encoding='utf-8')
    return template
|
|
|
|
|
def get_savename(directory, name, extension):
    """Return a file name for *name* that does not yet exist in *directory*.

    The first candidate is ``<name>.<extension>``; on a collision a counter is
    appended (``<name>-1.<extension>``, ``<name>-2.<extension>``, ...).

    The counter is always appended to the full name. Deriving the stem by
    splitting on ``-`` would truncate names that contain a dash themselves.
    Path separators in the name are replaced with ``_`` so the result is a
    single path component that can be joined onto *directory*.

    Args:
        directory: Directory that is checked for existing files.
        name: Base name, typically a card name.
        extension: File extension without the leading dot.

    Returns:
        The file name only, without the directory.
    """
    stem = re.sub(r'[\\/]', '_', name)
    save_name = f"{stem}.{extension}"
    i = 1
    while os.path.exists(os.path.join(directory, save_name)):
        save_name = f"{stem}-{i}.{extension}"
        i += 1
    return save_name
|
|
|
|
|
def html_to_png(card_name, html):
    """Render the card HTML to a PNG file and return the cropped image.

    The HTML is rendered with imgkit into ``rendered_cards/``, cropped to the
    box ``(0, 50, 400, 600)``, resized to 400x550, written back over the
    rendered file and returned.

    Args:
        card_name: Card name used to derive the output file name.
        html: Complete HTML document of the card.

    Returns:
        The cropped card as an RGB PIL image.

    Raises:
        FileNotFoundError: If rendering produced no output file.
    """
    save_name = get_savename('rendered_cards', card_name, 'png')
    print('CONVERTING HTML CARD TO PNG IMAGE')

    path = os.path.join('rendered_cards', save_name)
    css = ['./colab-data-test/css/mana.css', './colab-data-test/css/keyrune.css', './colab-data-test/css/mtg_custom.css']
    try:
        imgkit.from_string(html, path, {"xvfb": ""}, css=css)
    except Exception as err:
        # Rendering stays best-effort, as in the original: the error is
        # reported and we carry on, in case an image was still written.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        print(f'WARNING: imgkit reported an error while rendering: {err}')

    print('OPENING IMAGE FROM FILE')
    with Image.open(path) as img:
        print('CROPPING BACKGROUND')
        area = (0, 50, 400, 600)
        # resize() returns a new image, so the result has to be kept.
        cropped_img = img.crop(area).resize((400, 550))
    cropped_img.save(path)
    print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
    return cropped_img.convert('RGB')
|
|
|
|
|
# Markdown shown above the form.
app_description = (
    "# Create your own Magic: The Gathering cards!\n"
    "Enter a name, click Submit, it may take up to 10 minutes to run on the free CPU only instance."
)

# Input: the card name typed by the user.
input_box = gr.Textbox(label="Enter a card name", placeholder="Firebolt")

# Outputs, pre-filled with the bundled example card.
rendered_card = gr.Image(label="Card", type='pil', value="examples/card.png")
output_text_box = gr.Textbox(
    label="Card Text",
    value=pathlib.Path("examples/text.txt").read_text(encoding='utf-8'),
)
output_card_image = gr.Image(label="Card Image", type='pil', value="examples/image.png")
output_card_html = gr.HTML(label="Card HTML", visible=False, show_label=False)

# NOTE(review): this component is not referenced by the interface below.
x = gr.components.Textbox()

iface = gr.Interface(
    title="MagicGen",
    theme="default",
    description=app_description,
    fn=run,
    inputs=[input_box],
    outputs=[rendered_card, output_text_box, output_card_image, output_card_html],
)

iface.launch()
|
|