import gradio as gr
import pathlib
import base64
import re
import time
from io import BytesIO
import imgkit
import os
from PIL import Image
from fastai.callback.core import Callback
from fastai.learner import *
from fastai.torch_core import TitledStr
from fastcore.transform import Transform  # fastcore's Transform provides the encodes/decodes dispatch used below
from torch import tensor, Tensor
import random

# These utility functions need to be defined in __main__ (or whichever module they were created in)
# because fastai's load_learner unpickles them from that module, see:
# https://docs.fast.ai/learner.html#load_learner
from transformers import GPT2TokenizerFast
import torch
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

gpu = False
AUTH_TOKEN = os.environ.get('AUTH_TOKEN')

if gpu:
    pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                                 custom_pipeline="stable_diffusion_mega",
                                                 torch_dtype=torch.float16,
                                                 revision="fp16",
                                                 use_auth_token=AUTH_TOKEN)
    scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler = scheduler
    pipeline.to("cuda")
else:
    pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                                 custom_pipeline="stable_diffusion_mega",
                                                 use_auth_token=AUTH_TOKEN)
    scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler = scheduler

# Hugging Face Spaces have 16GB RAM and 8 CPU cores
# See https://huggingface.co/docs/hub/spaces-overview#hardware-resources

pretrained_weights = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)


def tokenize(text):
    toks = tokenizer.tokenize(text)
    return tensor(tokenizer.convert_tokens_to_ids(toks))


class TransformersTokenizer(Transform):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encodes(self, x):
        return x if isinstance(x, Tensor) else tokenize(x)

    def decodes(self, x):
        return TitledStr(self.tokenizer.decode(x.cpu().numpy()))


class DropOutput(Callback):
    def after_pred(self):
        self.learn.pred = self.pred[0]


def gen_card_text(name):
    if name == '':
        # No name given: seed the prompt with a random first letter and let the model invent the rest
        prompt = f"Name: {random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ')}"
    else:
        prompt = f"Name: {name}\r\n"
    print(f'GENERATING CARD TEXT with prompt: {prompt}')

    prompt_ids = tokenizer.encode(prompt)
    if gpu:
        inp = tensor(prompt_ids)[None].cuda()  # move the input ids to the GPU
    else:
        inp = tensor(prompt_ids)[None]
    preds = learner.model.generate(inp,
                                   max_length=512,
                                   num_beams=5,
                                   temperature=1.5,
                                   do_sample=True,
                                   repetition_penalty=1.2)
    result = tokenizer.decode(preds[0].cpu().numpy())
    # Keep only the text before the first '###' separator and normalise real and escaped CR/LF sequences
    result = result.split('###')[0].replace(r'\r\n', '\n').replace('\r', '').replace(r'\r', '')
    print('GENERATING CARD COMPLETE')
    print(result)

    if name == '':
        pattern = re.compile('Name: (.*)')
        name = pattern.findall(result)[0]
    return name, result


# init only once
# cpu=False runs on the GPU; make sure the installed torch is a CUDA build, e.g.
# `pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116`
learner = load_learner('./colab-data-test/export.pkl', cpu=not gpu)
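
# Illustrative example (not executed): the generated card text is a plain key/value block
# that run() and format_html() below parse with regexes, e.g.
#   name, text = gen_card_text('')   # empty name -> the model invents one from the random first letter
#   # `text` contains lines such as "Name: ...", "ManaCost: ...", "Type: ...", "Rarity: ...",
#   # "Text: ...", "FlavorText: ...", "Power: ..." and "Toughness: ...", plus a colour list
#   # like "['U']" that picks the card frame colour.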

pathlib.Path('card_data').mkdir(parents=True, exist_ok=True)
pathlib.Path('card_images').mkdir(parents=True, exist_ok=True)
pathlib.Path('card_html').mkdir(parents=True, exist_ok=True)
pathlib.Path('rendered_cards').mkdir(parents=True, exist_ok=True)


def run(name):
    start = time.time()
    print(f'BEGINNING RUN FOR {name}')

    name, text = gen_card_text(name)
    save_name = get_savename('card_data', name, 'txt')
    pathlib.Path(f'card_data/{save_name}').write_text(text, encoding='utf-8')

    pattern = re.compile('Type: (.*)')
    card_type = pattern.findall(text)[0]
    prompt_template = f"fantasy illustration of a {card_type} {name}, by Greg Rutkowski"
    print(f"GENERATING IMAGE FOR {prompt_template}")
    # Regarding sizing see https://huggingface.co/blog/stable_diffusion#:~:text=When%20choosing%20image%20sizes%2C%20we%20advise%20the%20following%3A
    images = pipeline.text2img(prompt_template, width=512, height=368, num_inference_steps=20).images
    card_image = None
    for image in images:
        save_name = get_savename('card_images', name, 'png')
        image.save(f"card_images/{save_name}")
        card_image = image

    image_data = pil_to_base64(card_image)
    html = format_html(text, image_data)
    save_name = get_savename('card_html', name, 'html')
    pathlib.Path(f'card_html/{save_name}').write_text(html, encoding='utf-8')
    rendered = html_to_png(name, html)

    end = time.time()
    print(f'RUN COMPLETED IN {int(end - start)} seconds')
    return rendered, text, card_image, html
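
# Illustrative usage: run() returns the rendered card (via html_to_png), the raw generated
# card text, the Stable Diffusion artwork and the intermediate HTML, e.g.
#   rendered, text, card_image, html = run('Storm Crow')   # hypothetical card name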


def pil_to_base64(image):
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return img_str


def format_html(text, image_data):
    template = pathlib.Path("colab-data-test/card_template.html").read_text(encoding='utf-8')

    # Pick the card frame colour from the generated colour list / type line
    if "['U']" in text:
        template = template.replace("{card_color}", 'style="background-color:#5a73ab"')
    elif "['W']" in text:
        template = template.replace("{card_color}", 'style="background-color:#f0e3d0"')
    elif "['G']" in text:
        template = template.replace("{card_color}", 'style="background-color:#325433"')
    elif "['B']" in text:
        template = template.replace("{card_color}", 'style="background-color:#1a1b1e"')
    elif "['R']" in text:
        template = template.replace("{card_color}", 'style="background-color:#c2401c"')
    elif "Type: Land" in text:
        template = template.replace("{card_color}", 'style="background-color:#aa8c71"')
    elif "Type: Artifact" in text:
        template = template.replace("{card_color}", 'style="background-color:#9ba7bc"')
    else:
        template = template.replace("{card_color}", 'style="background-color:#edd99d"')

    pattern = re.compile('Name: (.*)')
    name = pattern.findall(text)[0]
    template = template.replace("{name}", name)

    pattern = re.compile('ManaCost: (.*)')
    mana_cost = pattern.findall(text)[0]
    if mana_cost == "None":
        template = template.replace("{mana_cost}", '')
    else:
        symbols = []
        for c in mana_cost:
            if c in {"{", "}"}:
                continue
            else:
                symbols.append(c.lower())
        formatted_symbols = []
        for s in symbols:
            # Per-symbol icon markup for the template's mana-symbol font (exact tag assumed)
            formatted_symbols.append(f'<i class="ms ms-{s} ms-cost"></i>')
        template = template.replace("{mana_cost}", "\n".join(formatted_symbols[::-1]))

    # Embed the art: raw base64 bytes become a data URI, anything else is inserted verbatim
    if not isinstance(image_data, (bytes, bytearray)):
        template = template.replace('{image_data}', f'{image_data}')
    else:
        template = template.replace('{image_data}', f'data:image/png;base64,{image_data.decode("utf-8")}')

    pattern = re.compile('Type: (.*)')
    card_type = pattern.findall(text)[0]
    template = template.replace("{card_type}", card_type)
    if len(card_type) > 30:
        template = template.replace("{type_size}", "16")
    else:
        template = template.replace("{type_size}", "18")

    pattern = re.compile('Rarity: (.*)')
    rarity = pattern.findall(text)[0]
    template = template.replace("{rarity}", f"ss-{rarity}")

    pattern = re.compile('Text: (.*)\nFlavorText', re.MULTILINE | re.DOTALL)
    card_text = pattern.findall(text)[0]
    text_lines = []
    for line in card_text.splitlines():
        line = line.replace('{T}', '')
        line = line.replace('{UT}', '')
        line = line.replace('{E}', '')
        line = re.sub(r"{(.*?)}", r''.lower(), line)
        line = re.sub(r"ms-(.)/(.)", r''.lower(), line)
        line = line.replace('(', '(').replace(')', ')')
        text_lines.append(f"{line}")
    template = template.replace("{card_text}", "\n".join(text_lines))
") template = template.replace("{flavor_text}", "" + "\n".join(flavor_text_lines) + "") else: template = template.replace("{flavor_text}", "") if len(card_text) + len(flavor_text or '') > 170 or len(text_lines) > 3: template = template.replace("{text_size}", '16') template = template.replace('ms-cost" style="top:0px;float:none;height: 18px;width: 18px;font-size: 13px;">', 'ms-cost" style="top:0px;float:none;height: 16px;width: 16px;font-size: 11px;">') else: template = template.replace("{text_size}", '18') pattern = re.compile('Power: (.*)') power = pattern.findall(text) if power: power = power[0] if not power: template = template.replace("{power_toughness}", "") pattern = re.compile('Toughness: (.*)') toughness = pattern.findall(text)[0] template = template.replace("{power_toughness}", f'