import base64
import os
import pathlib
import random
import re
import time
from io import BytesIO

import gradio as gr
import imgkit
import torch
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, pipeline

# True: load Stable Diffusion with fp16 weights and move it to CUDA.
# False: load the default (full precision) weights and stay on the CPU.
gpu = False
AUTH_TOKEN = os.environ.get('AUTH_TOKEN')
BASE_MODEL = "gpt2"
MERGED_MODEL = "gpt2-magic-card"
IMAGE_MODEL = "runwayml/stable-diffusion-v1-5"


def _load_image_pipeline(use_gpu):
    """Load the Stable Diffusion pipeline with an Euler-ancestral scheduler.

    Args:
        use_gpu: when true, load the fp16 revision in float16 and move the
            pipeline to "cuda"; otherwise load the default weights on the CPU.

    Returns:
        The configured ``DiffusionPipeline``.
    """
    if use_gpu:
        pipe = DiffusionPipeline.from_pretrained(
            IMAGE_MODEL,
            torch_dtype=torch.float16,
            revision="fp16",
            use_auth_token=AUTH_TOKEN,
        )
    else:
        pipe = DiffusionPipeline.from_pretrained(IMAGE_MODEL, use_auth_token=AUTH_TOKEN)
    # Replace the default scheduler, reusing the existing scheduler's config.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    if use_gpu:
        pipe.to("cuda")
    return pipe


image_pipeline = _load_image_pipeline(gpu)
# Module-level alias kept so existing references to `scheduler` still work.
scheduler = image_pipeline.scheduler

# Huggingface Spaces have 16GB RAM and 8 CPU cores
# See https://huggingface.co/docs/hub/spaces-overview#hardware-resources
model = GPT2LMHeadModel.from_pretrained(MERGED_MODEL)
tokenizer = GPT2TokenizerFast.from_pretrained(BASE_MODEL)
# Separator between cards in the generated text; its token id(s) are passed
# to generation as eos_token_id, and the output is truncated at it.
END_TOKEN = '###'
eos_id = tokenizer.encode(END_TOKEN)
text_pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)

_NAME_PATTERN = re.compile('Name: (.*)')


def gen_card_text(name):
    """Generate the text of a card with the fine-tuned GPT-2 model.

    Args:
        name: card name entered by the user. An empty, whitespace-only or
            ``None`` value lets the model invent the name: the prompt is then
            seeded with a random capital letter.

    Returns:
        ``(name, text)``: the card name (the user's, stripped, or the one the
        model generated) and the generated card text, cut at the first
        ``END_TOKEN``.
    """
    name = (name or '').strip()
    if name == '':
        # No trailing newline, so the model continues the name from this letter.
        prompt = f"Name: {random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ')}"
    else:
        prompt = f"Name: {name}\n"
    print(f'GENERATING CARD TEXT with prompt: {prompt}')
    output = text_pipeline(
        prompt,
        max_length=512,
        num_return_sequences=1,
        num_beams=5,
        temperature=1.5,
        do_sample=True,
        repetition_penalty=1.2,
        eos_token_id=eos_id,
    )
    # Keep only the first card. The raw-string patterns r'\r\n' and r'\r' match
    # *literal* backslash escape sequences present as text; '\r' removes real
    # carriage returns.
    result = (
        output[0]['generated_text']
        .split(END_TOKEN)[0]
        .replace(r'\r\n', '\n')
        .replace('\r', '')
        .replace(r'\r', '')
    )
    print('GENERATING CARD COMPLETE')
    print(result)
    if name == '':
        # generated_text includes the prompt by default, so a "Name:" line is
        # expected; fall back to the seed letter if it is somehow missing.
        match = _NAME_PATTERN.search(result)
        name = match.group(1).strip() if match else prompt[len('Name: '):]
    return name, result
# Output directories for every artefact produced by run().
for _directory in ('card_data', 'card_images', 'card_html', 'rendered_cards'):
    pathlib.Path(_directory).mkdir(parents=True, exist_ok=True)


def run(name):
    """Generate a complete card; this is the Gradio interface callback.

    Generates the card text, an illustration, the card HTML and a rendered PNG,
    saving them under card_data/, card_images/, card_html/ and rendered_cards/.

    Args:
        name: card name from the input box ('' lets the model choose one).

    Returns:
        ``(rendered_card, text, card_image, html)`` in the order of the
        Interface outputs.

    Raises:
        IndexError: if the generated text has no "Type:" line.
    """
    start = time.time()
    print(f'BEGINNING RUN FOR {name}')
    name, text = gen_card_text(name)
    save_name = get_savename('card_data', name, 'txt')
    pathlib.Path('card_data', save_name).write_text(text, encoding='utf-8')

    card_type = re.findall('Type: (.*)', text)[0]
    prompt_template = f"fantasy illustration of a {card_type} {name}, by Greg Rutkowski"
    print(f"GENERATING IMAGE FOR {prompt_template}")
    # Regarding sizing see https://huggingface.co/blog/stable_diffusion#:~:text=When%20choosing%20image%20sizes%2C%20we%20advise%20the%20following%3A
    images = image_pipeline(prompt_template, width=512, height=368, num_inference_steps=20).images
    card_image = None
    for image in images:
        save_name = get_savename('card_images', name, 'png')
        image.save(os.path.join('card_images', save_name))
        card_image = image  # the last generated image is the one used on the card

    image_data = pil_to_base64(card_image)
    html = format_html(text, image_data)
    save_name = get_savename('card_html', name, 'html')
    pathlib.Path('card_html', save_name).write_text(html, encoding='utf-8')
    rendered = html_to_png(name, html)

    end = time.time()
    print(f'RUN COMPLETED IN {int(end - start)} seconds')
    return rendered, text, card_image, html


def pil_to_base64(image):
    """Encode a PIL image as PNG and return the base64 encoding as ``bytes``."""
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return img_str


# Card background colour: the first marker found in the card text wins.
_CARD_COLORS = (
    ("['U']", '#5a73ab'),
    ("['W']", '#f0e3d0'),
    ("['G']", '#325433'),
    ("['B']", '#1a1b1e'),
    ("['R']", '#c2401c'),
    ("Type: Land", '#aa8c71'),
    ("Type: Artifact", '#9ba7bc'),
)
_DEFAULT_CARD_COLOR = '#edd99d'


def format_html(text, image_data):
    """Fill the card HTML template with the fields parsed from *text*.

    Placeholders are replaced in a fixed order (colour, name, mana cost, image,
    type, rarity, card text, flavor text, text size, power/toughness).

    Args:
        text: card text containing "Name:", "ManaCost:", "Type:", "Rarity:",
            "Text:", optional "FlavorText:" and optional "Power:"/"Toughness:"
            lines.
        image_data: base64 PNG ``bytes``/``bytearray`` (embedded as a data URI);
            any other value is inserted verbatim via an f-string.

    Returns:
        The filled-in HTML. A copy is also written to ``test.html``.

    Raises:
        IndexError: if Name, ManaCost, Type, Rarity or the Text...FlavorText
            section is missing, or a non-empty Power has no Toughness line.
    """
    template = pathlib.Path("colab-data-test/card_template.html").read_text(encoding='utf-8')

    color = next((hex_code for marker, hex_code in _CARD_COLORS if marker in text), _DEFAULT_CARD_COLOR)
    template = template.replace("{card_color}", f'style="background-color:{color}"')

    name = re.findall('Name: (.*)', text)[0]
    template = template.replace("{name}", name)

    mana_cost = re.findall('ManaCost: (.*)', text)[0]
    if mana_cost == "None":
        template = template.replace("{mana_cost}", '')
    else:
        symbols = [c.lower() for c in mana_cost if c not in {"{", "}"}]
        # NOTE(review): the per-symbol markup seems to be missing from this copy
        # of the file (the f-string is empty and never uses the symbol) - restore it.
        formatted_symbols = [f'' for _symbol in symbols]
        template = template.replace("{mana_cost}", "\n".join(formatted_symbols[::-1]))

    if isinstance(image_data, (bytes, bytearray)):
        image_src = f'data:image/png;base64,{image_data.decode("utf-8")}'
    else:
        image_src = f'{image_data}'
    template = template.replace('{image_data}', image_src)

    card_type = re.findall('Type: (.*)', text)[0]
    template = template.replace("{card_type}", card_type)
    # Smaller type-line font for long type strings.
    template = template.replace("{type_size}", "16" if len(card_type) > 30 else "18")

    rarity = re.findall('Rarity: (.*)', text)[0]
    template = template.replace("{rarity}", f"ss-{rarity}")

    card_text = re.findall('Text: (.*)\nFlavorText', text, re.MULTILINE | re.DOTALL)[0]
    text_lines = []
    for line in card_text.splitlines():
        # NOTE(review): the replacement markup in these substitutions looks
        # stripped from this copy: as written, {T}/{UT}/{E} and every other
        # {...} symbol are deleted, the "ms-X/Y" substitution inserts '' and the
        # parenthesis replacements are no-ops.
        line = line.replace('{T}', '')
        line = line.replace('{UT}', '')
        line = line.replace('{E}', '')
        line = re.sub(r"{(.*?)}", r''.lower(), line)
        line = re.sub(r"ms-(.)/(.)", r''.lower(), line)
        line = line.replace('(', '(').replace(')', ')')
        text_lines.append(f"\n\n{line}\n\n")
    template = template.replace("{card_text}", "\n".join(text_lines))

    flavor_matches = re.findall('FlavorText: (.*)\nPower', text, re.MULTILINE | re.DOTALL)
    flavor_text = flavor_matches[0] if flavor_matches else ''
    if flavor_matches:
        flavor_text_lines = [f"\n\n{line}\n\n" for line in flavor_text.splitlines()]
        template = template.replace("{flavor_text}", "\n" + "\n".join(flavor_text_lines) + "\n")
    else:
        template = template.replace("{flavor_text}", "")

    if len(card_text) + len(flavor_text) > 170 or len(text_lines) > 3:
        template = template.replace("{text_size}", '16')
        # Shrink the inline mana symbols to match the smaller text size.
        template = template.replace('ms-cost" style="top:0px;float:none;height: 18px;width: 18px;font-size: 13px;">', 'ms-cost" style="top:0px;float:none;height: 16px;width: 16px;font-size: 11px;">')
    else:
        template = template.replace("{text_size}", '18')

    power_matches = re.findall('Power: (.*)', text)
    power = power_matches[0] if power_matches else ''
    if power:
        toughness = re.findall('Toughness: (.*)', text)[0]
        template = template.replace("{power_toughness}", f'\n\n{power}/{toughness}\n\n')
    else:
        # No Power line, or an empty one: the card has no power/toughness box.
        template = template.replace("{power_toughness}", "")

    # Debug copy of the most recently formatted card; overwritten on every call.
    pathlib.Path("test.html").write_text(template, encoding='utf-8')
    return template


# Path separators, characters Windows forbids in filenames, and control chars.
_UNSAFE_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|\x00-\x1f]')


def _safe_filename(name):
    """Return *name* usable as a single path component ('card' if nothing remains)."""
    # Card names come from user input / model output: never let them escape the
    # target directory (e.g. "../x") or contain characters that break file creation.
    cleaned = _UNSAFE_FILENAME_CHARS.sub('_', name).strip().strip('.')
    return cleaned or 'card'


def get_savename(directory, name, extension):
    """Return a file name for *name* that does not yet exist in *directory*.

    The name is sanitized first; on collision "-1", "-2", ... is appended to
    the base name ("Foo.png", "Foo-1.png", "Foo-2.png", ...).

    Args:
        directory: directory the file will be written to.
        name: card name used as the base of the file name.
        extension: file extension without the leading dot.

    Returns:
        The bare file name (not joined with *directory*).
    """
    base = _safe_filename(name)
    save_name = f"{base}.{extension}"
    i = 1
    while os.path.exists(os.path.join(directory, save_name)):
        # Derive from the fixed base so hyphenated names ("Fire-Bolt") keep their full name.
        save_name = f"{base}-{i}.{extension}"
        i += 1
    return save_name


def html_to_png(card_name, html):
    """Render the card HTML to rendered_cards/<name>.png and return it cropped.

    Tries imgkit (wkhtmltoimage) first, then falls back to html2image.

    Args:
        card_name: card name used for the output file name.
        html: complete card HTML.

    Returns:
        The cropped card as an RGB PIL image (also saved over the rendered PNG).

    Raises:
        RuntimeError: if neither renderer produced the PNG file.
    """
    save_name = get_savename('rendered_cards', card_name, 'png')
    print('CONVERTING HTML CARD TO PNG IMAGE')
    path = os.path.join('rendered_cards', save_name)
    try:
        css = ['./colab-data-test/css/mana.css', './colab-data-test/css/keyrune.css', './colab-data-test/css/mtg_custom.css']
        imgkit.from_string(html, path, {"xvfb": ""}, css=css)
    except Exception as imgkit_error:
        print(f'IMGKIT RENDERING FAILED: {imgkit_error}')
        try:
            # For Windows local, requires 'html2image' package from pip.
            from html2image import Html2Image
            rendered_card_dir = 'rendered_cards'
            hti = Html2Image(output_path=rendered_card_dir)
            paths = hti.screenshot(html_str=html, css_file=['./colab-data-test/css/mtg_custom.css', './colab-data-test/css/mana.css', './colab-data-test/css/keyrune.css'], save_as=save_name, size=(450, 600))
            print(paths)
            path = paths[0]
        except Exception as hti_error:
            # Best effort: imgkit may raise even when wkhtmltoimage still wrote
            # the PNG (e.g. on non-fatal warnings), so check for the file below
            # instead of failing here.
            print(f'HTML2IMAGE RENDERING FAILED: {hti_error}')
    if not os.path.exists(path):
        raise RuntimeError(f'Could not render card HTML to {path}')

    print('OPENING IMAGE FROM FILE')
    with Image.open(path) as img:
        print('CROPPING BACKGROUND')
        # The crop box is exactly 400x550, so no resize is needed afterwards.
        cropped_img = img.crop((0, 50, 400, 600))
    cropped_img.save(path)
    print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
    return cropped_img.convert('RGB')


app_description = (
    """
# Create your own Magic: The Gathering cards!

Enter a name, click Submit, it may take up to 10 minutes to run on the free CPU only instance.
""").strip()

input_box = gr.Textbox(label="Enter a card name", placeholder="Firebolt")
rendered_card = gr.Image(label="Card", type='pil', value="examples/card.png")
output_text_box = gr.Textbox(label="Card Text", value=pathlib.Path("examples/text.txt").read_text(encoding='utf-8'))
output_card_image = gr.Image(label="Card Image", type='pil', value="examples/image.png")
output_card_html = gr.HTML(label="Card HTML", visible=False, show_label=False)

iface = gr.Interface(
    title="MagicGen",
    theme="default",
    description=app_description,
    fn=run,
    inputs=[input_box],
    outputs=[rendered_card, output_text_box, output_card_image, output_card_html],
)
iface.launch()