import gradio as gr
import pathlib
import base64
import re
import time
from io import BytesIO
import imgkit
import os
import random
from PIL import Image
from fastai.callback.core import Callback
from fastai.learner import *
from fastai.torch_core import TitledStr
from fastcore.transform import Transform
from torch import tensor, Tensor

# These utility functions and classes need to be defined in __main__ (or otherwise where they
# were created at export time) because fastai's load_learner unpickles them from that module, see:
# https://docs.fast.ai/learner.html#load_learner
from transformers import GPT2TokenizerFast
import torch
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

gpu = False
AUTH_TOKEN = os.environ.get('AUTH_TOKEN')

if gpu:
    pipeline = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        custom_pipeline="stable_diffusion_mega",
        torch_dtype=torch.float16,
        revision="fp16",
        use_auth_token=AUTH_TOKEN,
    )
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
    pipeline.to("cuda")
else:
    pipeline = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        custom_pipeline="stable_diffusion_mega",
        use_auth_token=AUTH_TOKEN,
    )
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)

# Huggingface Spaces have 16GB RAM and 8 CPU cores
# See https://huggingface.co/docs/hub/spaces-overview#hardware-resources

pretrained_weights = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)


def tokenize(text):
    toks = tokenizer.tokenize(text)
    return tensor(tokenizer.convert_tokens_to_ids(toks))


class TransformersTokenizer(Transform):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encodes(self, x):
        return x if isinstance(x, Tensor) else tokenize(x)

    def decodes(self, x):
        return TitledStr(self.tokenizer.decode(x.cpu().numpy()))


class DropOutput(Callback):
    def after_pred(self):
        self.learn.pred = self.pred[0]


def gen_card_text(name):
    """Generate card text with the fine-tuned GPT-2 model, seeded with the given card name."""
    if name == '':
        # No name given: seed the model with a random starting letter and let it invent a name.
        prompt = f"Name: {random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ')}"
    else:
        prompt = f"Name: {name}\r\n"
    print(f'GENERATING CARD TEXT with prompt: {prompt}')
    prompt_ids = tokenizer.encode(prompt)
    if gpu:
        inp = tensor(prompt_ids)[None].cuda()  # Use .cuda() for torch GPU
    else:
        inp = tensor(prompt_ids)[None]
    preds = learner.model.generate(inp, max_length=512, num_beams=5, temperature=1.5,
                                   do_sample=True, repetition_penalty=1.2)
    result = tokenizer.decode(preds[0].cpu().numpy())
    result = result.split('###')[0].replace(r'\r\n', '\n').replace('\r', '').replace(r'\r', '')
    print('GENERATING CARD COMPLETE')
    print(result)
    if name == '':
        pattern = re.compile('Name: (.*)')
        name = pattern.findall(result)[0]
    return name, result


# init only once
# cpu=False uses GPU; make sure the installed torch is a GPU build, e.g.
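# Rough sketch of where ./colab-data-test/export.pkl is assumed to come from. The real
# training happens elsewhere (in the Colab data referenced above), so the exact arguments
# below are guesses based on the fastai transformers tutorial, not the code that produced
# the pickle:
#
#     tls = TfmdLists(all_texts, TransformersTokenizer(tokenizer), splits=splits,
#                     dl_type=LMDataLoader)
#     dls = tls.dataloaders(bs=8, seq_len=256)
#     learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(),
#                     cbs=[DropOutput], metrics=Perplexity())
#     learn.fit_one_cycle(1, 1e-4)
#     learn.export('export.pkl')   # the file loaded by load_learner above
#
# This is also why TransformersTokenizer and DropOutput must exist in this module:
# load_learner unpickles references to them.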
# `pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116`
learner = load_learner('./colab-data-test/export.pkl', cpu=not gpu)

pathlib.Path('card_data').mkdir(parents=True, exist_ok=True)
pathlib.Path('card_images').mkdir(parents=True, exist_ok=True)
pathlib.Path('card_html').mkdir(parents=True, exist_ok=True)
pathlib.Path('rendered_cards').mkdir(parents=True, exist_ok=True)


def run(name):
    """Full pipeline: generate card text, generate artwork, fill the HTML template, render a PNG."""
    start = time.time()
    print(f'BEGINNING RUN FOR {name}')
    name, text = gen_card_text(name)
    save_name = get_savename('card_data', name, 'txt')
    pathlib.Path(f'card_data/{save_name}').write_text(text, encoding='utf-8')

    pattern = re.compile('Type: (.*)')
    card_type = pattern.findall(text)[0]
    prompt_template = f"fantasy illustration of a {card_type} {name}, by Greg Rutkowski"
    print(f"GENERATING IMAGE FOR {prompt_template}")
    # Regarding sizing see
    # https://huggingface.co/blog/stable_diffusion#:~:text=When%20choosing%20image%20sizes%2C%20we%20advise%20the%20following%3A
    images = pipeline.text2img(prompt_template, width=512, height=368, num_inference_steps=20).images
    card_image = None
    for image in images:
        save_name = get_savename('card_images', name, 'png')
        image.save(f"card_images/{save_name}")
        card_image = image

    image_data = pil_to_base64(card_image)
    html = format_html(text, image_data)
    save_name = get_savename('card_html', name, 'html')
    pathlib.Path(f'card_html/{save_name}').write_text(html, encoding='utf-8')
    rendered = html_to_png(name, html)
    end = time.time()
    print(f'RUN COMPLETED IN {int(end - start)} seconds')
    return rendered, text, card_image, html


def pil_to_base64(image):
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return img_str
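# format_html below pulls fields out of the generated card text with simple "Field: value"
# regexes, so the model output is assumed to look roughly like the made-up example below.
# The field names come from the regexes; the values are purely illustrative, and the frame
# colour is chosen by searching the whole text for a substring such as "['R']":
#
#     Name: Emberfang Wolf
#     ManaCost: {1}{R}
#     Type: Creature - Wolf
#     Rarity: common
#     Text: Haste
#     FlavorText: It hunts by the light of burning fields.
#     Power: 2
#     Toughness: 2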
def format_html(text, image_data):
    """Fill in the HTML card template with the generated text and artwork."""
    template = pathlib.Path("colab-data-test/card_template.html").read_text(encoding='utf-8')

    # Card frame colour, keyed off the colour identity (e.g. "['U']") or the card type.
    if "['U']" in text:
        template = template.replace("{card_color}", 'style="background-color:#5a73ab"')
    elif "['W']" in text:
        template = template.replace("{card_color}", 'style="background-color:#f0e3d0"')
    elif "['G']" in text:
        template = template.replace("{card_color}", 'style="background-color:#325433"')
    elif "['B']" in text:
        template = template.replace("{card_color}", 'style="background-color:#1a1b1e"')
    elif "['R']" in text:
        template = template.replace("{card_color}", 'style="background-color:#c2401c"')
    elif "Type: Land" in text:
        template = template.replace("{card_color}", 'style="background-color:#aa8c71"')
    elif "Type: Artifact" in text:
        template = template.replace("{card_color}", 'style="background-color:#9ba7bc"')
    else:
        template = template.replace("{card_color}", 'style="background-color:#edd99d"')

    pattern = re.compile('Name: (.*)')
    name = pattern.findall(text)[0]
    template = template.replace("{name}", name)

    pattern = re.compile('ManaCost: (.*)')
    mana_cost = pattern.findall(text)[0]
    if mana_cost == "None":
        template = template.replace("{mana_cost}", '')
    else:
        # Strip the braces from costs like "{1}{R}" and keep the lowercased symbols.
        symbols = []
        for c in mana_cost:
            if c in {"{", "}"}:
                continue
            symbols.append(c.lower())
        formatted_symbols = []
        for s in symbols:
            formatted_symbols.append(f'')
        template = template.replace("{mana_cost}", "\n".join(formatted_symbols[::-1]))

    if not isinstance(image_data, (bytes, bytearray)):
        template = template.replace('{image_data}', f'{image_data}')
    else:
        template = template.replace('{image_data}', f'data:image/png;base64,{image_data.decode("utf-8")}')

    pattern = re.compile('Type: (.*)')
    card_type = pattern.findall(text)[0]
    template = template.replace("{card_type}", card_type)
    if len(card_type) > 30:
        template = template.replace("{type_size}", "16")
    else:
        template = template.replace("{type_size}", "18")

    pattern = re.compile('Rarity: (.*)')
    rarity = pattern.findall(text)[0]
    template = template.replace("{rarity}", f"ss-{rarity}")

    # Rules text: strip or rewrite mana-symbol tokens line by line.
    pattern = re.compile('Text: (.*)\nFlavorText', re.MULTILINE | re.DOTALL)
    card_text = pattern.findall(text)[0]
    text_lines = []
    for line in card_text.splitlines():
        line = line.replace('{T}', '')
        line = line.replace('{UT}', '')
        line = line.replace('{E}', '')
        line = re.sub(r"{(.*?)}", r''.lower(), line)
        line = re.sub(r"ms-(.)/(.)", r''.lower(), line)
        line = line.replace('(', '(').replace(')', ')')
        text_lines.append(f"{line}")
    template = template.replace("{card_text}", "\n".join(text_lines))

    pattern = re.compile('FlavorText: (.*)\nPower', re.MULTILINE | re.DOTALL)
    flavor_text = pattern.findall(text)
    if flavor_text:
        flavor_text = flavor_text[0]
        flavor_text_lines = []
        for line in flavor_text.splitlines():
            flavor_text_lines.append(f"{line}")
        template = template.replace("{flavor_text}", "\n".join(flavor_text_lines))
    else:
        template = template.replace("{flavor_text}", "")

    # Shrink the body text (and the inline mana-symbol icons) when the card is wordy.
    if len(card_text) + len(flavor_text or '') > 170 or len(text_lines) > 3:
        template = template.replace("{text_size}", '16')
        template = template.replace(
            'ms-cost" style="top:0px;float:none;height: 18px;width: 18px;font-size: 13px;">',
            'ms-cost" style="top:0px;float:none;height: 16px;width: 16px;font-size: 11px;">')
    else:
        template = template.replace("{text_size}", '18')

    pattern = re.compile('Power: (.*)')
    power = pattern.findall(text)
    if power:
        power = power[0]
        if not power:
            template = template.replace("{power_toughness}", "")
        pattern = re.compile('Toughness: (.*)')
        toughness = pattern.findall(text)[0]
        template = template.replace("{power_toughness}", f'{power}/{toughness}')
    else:
        template = template.replace("{power_toughness}", "")

    pathlib.Path("test.html").write_text(template, encoding='utf-8')
    return template


def get_savename(directory, name, extension):
    """Return a filename in `directory` that does not collide with existing files."""
    save_name = f"{name}.{extension}"
    i = 1
    while os.path.exists(os.path.join(directory, save_name)):
        save_name = save_name.replace(f'.{extension}', '').split('-')[0] + f"-{i}.{extension}"
        i += 1
    return save_name


def html_to_png(card_name, html):
    """Render the card HTML to a PNG with imgkit and crop it to the card area."""
    save_name = get_savename('rendered_cards', card_name, 'png')
    print('CONVERTING HTML CARD TO PNG IMAGE')
    path = os.path.join('rendered_cards', save_name)
    try:
        css = ['./colab-data-test/css/mana.css',
               './colab-data-test/css/keyrune.css',
               './colab-data-test/css/mtg_custom.css']
        imgkit.from_string(html, path, {"xvfb": ""}, css=css)
    except Exception:
        # The renderer can report errors even when the PNG was written; carry on and
        # try to open whatever was produced.
        pass
    print('OPENING IMAGE FROM FILE')
    img = Image.open(path)
    print('CROPPING BACKGROUND')
    area = (0, 50, 400, 600)
    cropped_img = img.crop(area)
    cropped_img = cropped_img.resize((400, 550))
    cropped_img.save(os.path.join(path))
    print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
    return cropped_img.convert('RGB')


app_description = (
    """
# Create your own Magic: The Gathering cards!

Enter a name and click Submit; it may take up to 10 minutes to run on the free CPU-only instance.
""").strip()

input_box = gr.Textbox(label="Enter a card name", placeholder="Firebolt")
rendered_card = gr.Image(label="Card", type='pil', value="examples/card.png")
output_text_box = gr.Textbox(label="Card Text", value=pathlib.Path("examples/text.txt").read_text('utf-8'))
output_card_image = gr.Image(label="Card Image", type='pil', value="examples/image.png")
output_card_html = gr.HTML(label="Card HTML", visible=False, show_label=False)
x = gr.components.Textbox()

iface = gr.Interface(title="MagicGen",
                     theme="default",
                     description=app_description,
                     fn=run,
                     inputs=[input_box],
                     outputs=[rendered_card, output_text_box, output_card_image, output_card_html])
iface.launch()
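# Deployment notes (best-guess assumptions about the environment, not verified here):
#
# * format_html expects colab-data-test/card_template.html to contain the placeholders it
#   substitutes: {card_color}, {name}, {mana_cost}, {image_data}, {card_type}, {type_size},
#   {rarity}, {card_text}, {flavor_text}, {text_size} and {power_toughness}.
# * imgkit wraps the wkhtmltoimage CLI, so that binary has to be available on the Space, and
#   the {"xvfb": ""} option suggests it is run under a virtual framebuffer; on a Hugging Face
#   Space both would typically be installed via packages.txt (e.g. wkhtmltopdf and xvfb).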