import base64
import os
import pathlib
import re
import time
from html import escape as html_escape
from io import BytesIO

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageChops, ImageDraw
from fastai.callback.core import Callback
from fastai.learner import *
from fastai.torch_core import TitledStr
from html2image import Html2Image
# from min_dalle import MinDalle
from torch import tensor, Tensor, float16, float32
# NOTE(review): TransformersTokenizer implements fastai-style encodes/decodes, which
# torch.distributions.Transform does not use; confirm whether fastcore's Transform was
# intended. Left unchanged because swapping the base class may affect unpickling export.pkl.
from torch.distributions import Transform
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
# These utility functions need to be in main (or otherwise where created) because fastai loads from that module, see:
# https://docs.fast.ai/learner.html#load_learner
from transformers import GPT2TokenizerFast

AUTH_TOKEN = os.environ.get('AUTH_TOKEN')

# update requirements.txt with:
# C:\Users\Grant\PycharmProjects\test_space\venv\Scripts\pip3.exe freeze > requirements.txt

# Huggingface Spaces have 16GB RAM and 8 CPU cores
# See https://huggingface.co/docs/hub/spaces-overview#hardware-resources

pretrained_weights = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)


def tokenize(text):
    """Convert ``text`` into a tensor of GPT-2 token ids using the module-level tokenizer."""
    toks = tokenizer.tokenize(text)
    return tensor(tokenizer.convert_tokens_to_ids(toks))


class TransformersTokenizer(Transform):
    """Transform wrapping a Hugging Face tokenizer.

    Must live in this (main) module so fastai can find it when loading ``export.pkl``.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encodes(self, x):
        # Tensors are assumed to be already tokenized and pass through untouched.
        return x if isinstance(x, Tensor) else tokenize(x)

    def decodes(self, x):
        return TitledStr(self.tokenizer.decode(x.cpu().numpy()))


class DropOutput(Callback):
    """Replace the learner's prediction with the first element of the model output
    (presumably the logits of a Hugging Face model output tuple -- confirm)."""

    def after_pred(self):
        self.learn.pred = self.pred[0]


# initialize only once
# Takes about 2 minutes (126 seconds) to generate an image in Huggingface spaces on CPU
# NOTE as of 2022-11-13 min-dalle is broken, switch to using a stable diffusion model for images
# model = MinDalle(
#     models_root='./pretrained',
#     dtype=float32,
#     device='cpu',
#     is_mega=True,
#     is_reusable=True
# )

# Download pipeline, but overwrite scheduler
# Consider DPMSolverMultistepScheduler once added to diffusers
scheduler = EulerAncestralDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler",
                                                        use_auth_token=AUTH_TOKEN)
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_mega",
                                             torch_dtype=torch.float32, scheduler=scheduler,
                                             use_auth_token=AUTH_TOKEN)


# pipe.enable_attention_slicing()
# pipeline.to("cuda")


def gen_image(prompt):
    """Generate a 256x256 image for ``prompt`` (with a fixed style suffix) and return the first result."""
    prompt = f"{prompt}, fantasy painting by Greg Rutkowski"
    # See https://huggingface.co/spaces/pootow/min-dalle/blob/main/app.py
    # Hugging Space faces seems to run out of memory if grads are not disabled
    # torch.set_grad_enabled(False)
    print(f'RUNNING gen_image with prompt: {prompt}')
    images = pipeline.text2img(prompt, width=256, height=256, num_inference_steps=20).images
    # images = model.generate_images(
    #     text=prompt,
    #     seed=-1,
    #     grid_size=1,  # grid size above 2 causes out of memory on 12 GB 3080Ti; grid size 2 gives 4 images
    #     is_seamless=False,
    #     temperature=1,
    #     top_k=256,
    #     supercondition_factor=16,
    #     is_verbose=True
    # )
    print('COMPLETED GENERATION')
    # images = images.to('cpu').numpy()
    # images = images.astype(np.uint8)
    # return Image.fromarray(images[0])
    return images[0]


gpu = False

# init only once
# load_learner unpickles export.pkl; only load this trusted, locally shipped file.
learner = load_learner('export.pkl', cpu=not gpu)  # cpu=False uses GPU; make sure installed torch is GPU e.g. `pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116`


def parse_monster_description(name, text):
    """Return the first sentence of the ``Description:`` line in ``text``.

    Falls back to ``"{name} is a monster."`` when there is no description line or its
    first sentence is blank, so the image prompt is never empty.
    """
    fallback = f"{name} is a monster."
    match = re.search(r"Description: (.*)", text)
    if not match:
        return fallback
    first_sentence = match.group(1).split('.')[0].strip()
    print(first_sentence)
    return first_sentence or fallback


def gen_monster_text(name):
    """Generate a monster stat block for ``name`` with the fine-tuned GPT-2 learner."""
    prompt = f"Name: {name}\r\n"
    print(f'GENERATING MONSTER TEXT with prompt: {prompt}')
    prompt_ids = tokenizer.encode(prompt)
    inp = tensor(prompt_ids)[None]  # add a batch dimension
    if gpu:
        inp = inp.cuda()  # Use .cuda() for torch GPU
    preds = learner.model.generate(inp, max_length=512, num_beams=5, temperature=1.5, do_sample=True,
                                   repetition_penalty=1.2)
    result = tokenizer.decode(preds[0].cpu().numpy())
    # Keep only the text before the first '###' (presumably the record separator from
    # the training data), then normalize both real and literal backslash-escaped CRs.
    result = result.split('###')[0].replace(r'\r\n', '\n').replace('\r', '').replace(r'\r', '')
    print('GENERATING MONSTER COMPLETE')
    print(result)
    return result


def extract_text_for_header(text, header):
    """Return the rest of the first line following ``"{header}: "`` in ``text``, or '' if absent."""
    match = re.search(fr"{re.escape(header)}: (.*)", text)
    return match.group(1) if match else ''


def remove_section(html, html_class):
    """Remove the ``<li class="{html_class}" ...>...</li>`` element from ``html``.

    NOTE(review): the original pattern was lost in this copy of the file; this assumes the
    card template wraps each optional section in an ``<li>`` with exactly this class.
    """
    match = re.search(rf'<li class="{re.escape(html_class)}"[\w\W]*?</li>', html)
    if match is not None:
        html = html.replace(match.group(0), '')
    return html


# (template placeholder, header in the generated text, <li> class to drop when the value is empty)
# Order matches the original sequence of replacements.
_CARD_FIELDS = (
    ('{name}', 'Name', None),
    ('{monster_type}', 'Type', None),
    ('{armor_class}', 'Armor Class', None),
    ('{hit_points}', 'Hit Points', None),
    ('{speed}', 'Speed', None),
    ('{str_stat}', 'STR', None),
    ('{dex_stat}', 'DEX', None),
    ('{con_stat}', 'CON', None),
    ('{int_stat}', 'INT', None),
    ('{wis_stat}', 'WIS', None),
    ('{cha_stat}', 'CHA', None),
    ('{saving_throws}', 'Saving Throws', 'monster-saves'),
    ('{skills}', 'Skills', 'monster-skills'),
    ('{damage_vulnerabilities}', 'Damage Vulnerabilities', 'monster-vulnerabilities'),
    ('{damage_resistances}', 'Damage Resistances', 'monster-resistances'),
    ('{damage_immunities}', 'Damage Immunities', 'monster-immunities'),
    ('{condition_immunities}', 'Condition Immunities', 'monster-conditions'),
    ('{senses}', 'Senses', 'monster-senses'),
    ('{languages}', 'Languages', 'monster-languages'),
    ('{challenge}', 'Challenge', 'monster-challenge'),
    ('{description}', 'Description', None),
)

# One alternation (longest label first) so "Melee or Ranged Weapon Attack:" is wrapped once,
# instead of its "Ranged Weapon Attack:" suffix being wrapped a second time.
_ATTACK_LABEL_RE = re.compile(r'Melee or Ranged Weapon Attack:|Melee Weapon Attack:|Ranged Weapon Attack:|Hit:')


def _format_entries(monster_text, header, stop_word, item_class):
    """Render the ``{header}:`` section of ``monster_text`` as HTML entries.

    Each ``name: detail`` line becomes one entry. Parsing stops at the first entry whose
    name contains ``stop_word`` (the start of the other section). If the section has no
    ``name: detail`` pairs, its whole text is wrapped as one entry.

    NOTE(review): the original entry markup was lost in this copy of the file; confirm the
    element/class names below against monsterMakerTemplate.html / monstermaker.css.
    """
    match = re.search(fr"{header}:\n([\w\W]*)", monster_text)
    section = match.group(1) if match else ''
    if ':' not in section:
        return f'<div class="{item_class}"><p>{html_escape(section)}</p></div>'
    entries = []
    for line in section.split('\n'):
        entry_name, sep, detail = line.partition(':')
        if not sep or entry_name == header:
            continue
        if stop_word in entry_name:
            break
        entries.append(f'<div class="{item_class}"><p><span class="{item_class}-name">'
                       f'{html_escape(entry_name)}</span> {html_escape(detail)}</p></div>')
    return ''.join(entries)


def format_monster_card(monster_text, image_data):
    """Fill monsterMakerTemplate.html with the fields parsed from ``monster_text``.

    ``image_data`` is either base64-encoded PNG bytes (embedded as a data URI) or any other
    value, which is inserted as-is (e.g. a URL). Generated text is HTML-escaped because it
    echoes the user-supplied monster name. Returns the card HTML as a string.
    """
    print('FORMATTING MONSTER TEXT')
    # see giffyglyph's monster maker https://giffyglyph.com/monstermaker/app/
    # Different Formatting style examples and some json export formats
    card = pathlib.Path('monsterMakerTemplate.html').read_text('utf-8')
    if isinstance(image_data, (bytes, bytearray)):
        image_src = f'data:image/png;base64,{image_data.decode("utf-8")}'
    else:
        image_src = f'{image_data}'
    card = card.replace('{image_data}', image_src)

    for placeholder, header, section_class in _CARD_FIELDS:
        value = extract_text_for_header(monster_text, header)
        card = card.replace(placeholder, html_escape(value))
        if section_class and not value:
            card = remove_section(card, section_class)

    card = card.replace('{passives}', _format_entries(monster_text, 'Passives', 'Action', 'monster-trait'))
    card = card.replace('{actions}', _format_entries(monster_text, 'Actions', 'Passive', 'monster-action'))

    # TODO: Legendary actions, reactions, make column count for format an option (1 or 2 column layout)
    # NOTE(review): the original wrapping tags were lost in this copy; italics assumed.
    card = _ATTACK_LABEL_RE.sub(r'<i>\g<0></i>', card)
    print('FORMATTING MONSTER TEXT COMPLETE')
    return card


def pil_to_base64(image):
    """Encode a PIL image as PNG and return the base64 bytes."""
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return img_str


hti = Html2Image(output_path='rendered_cards')


def trim(im, border):
    """Crop ``im`` to the bounding box of pixels that differ from the ``border`` colour.

    Returns ``im`` unchanged when every pixel equals ``border`` (nothing to crop) rather
    than ``None``.
    """
    bg = Image.new(im.mode, im.size, border)
    diff = ImageChops.difference(im, bg)
    bbox = diff.getbbox()
    return im.crop(bbox) if bbox else im


def crop_background(image):
    """Flood-fill the background white from the top-right corner, then trim it away."""
    white = (255, 255, 255)
    ImageDraw.floodfill(image, (image.size[0] - 1, 0), white, thresh=50)
    image = trim(image, white)
    return image


def html_to_png(html):
    """Render the card HTML to a PNG with html2image and return it cropped as an RGB image."""
    print('CONVERTING HTML CARD TO PNG IMAGE')
    # NOTE(review): fixed output name; concurrent requests would overwrite each other's file.
    paths = hti.screenshot(html_str=html, css_file="monstermaker.css", save_as="test.png", size=(800, 1440))
    path = paths[0]
    print('OPENING IMAGE FROM FILE')
    img = Image.open(path).convert("RGB")
    print('CROPPING BACKGROUND')
    img = crop_background(img)
    print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
    return img


def run(name: str) -> (Image, str, Image, str):
    """Gradio handler: return (card image, monster text, monster image, card HTML) for ``name``."""
    start = time.time()
    print(f'BEGINNING RUN FOR {name}')
    if not name:
        placeholder_image = Image.new(mode="RGB", size=(256, 256))
        return placeholder_image, 'No name provided; enter a name and try again', placeholder_image, ''
    text = gen_monster_text(name)
    description = parse_monster_description(name, text)
    pil = gen_image(description)
    image_data = pil_to_base64(pil)
    card_html = format_monster_card(text, image_data)
    card_image = html_to_png(card_html)
    end = time.time()
    print(f'RUN COMPLETED IN {int(end - start)} seconds')
    return card_image, text, pil, card_html


app_description = (
    """
# Create your own D&D monster!
Enter a name, click Submit, and wait for about 4 minutes to see the result.
""").strip()
input_box = gr.Textbox(label="Enter a monster name", placeholder="Jabberwock")
output_monster_card = gr.Image(label="Monster Card", type='pil', value="examples/jabberwock_card.png")
output_text_box = gr.Textbox(label="Monster Text", value=pathlib.Path("examples/jabberwock.txt").read_text('utf-8'))
output_monster_image = gr.Image(label="Monster Image", type='pil', value="examples/jabberwock.png")
output_monster_html = gr.HTML(label="Monster HTML", visible=False, show_label=False)
iface = gr.Interface(title="MonsterGen", theme="default", description=app_description, fn=run, inputs=[input_box],
                     outputs=[output_monster_card, output_text_box, output_monster_image, output_monster_html])
iface.launch()

# TODO: Add examples, larger language model?, document process, log silences, "Passives" => "Traits", log timestamps
# Fine tune dalle-mini? https://blog.paperspace.com/dalle-mini/
# API works, assuming query takes no longer than 30 seconds (504 gateway timeout)
# Looks like API page improvements are in progress: https://github.com/gradio-app/gradio/issues/1325
# Example code below:
# import requests
# r = requests.post(url='https://hf.space/embed/gstaff/test_space/+/api/predict', json={"data": [""]},
#                   timeout=None)
# print(r.json())

# Looks like Huggingface uses the queue push api, then polls for status:
# fetch("https://hf.space/embed/gstaff/test_space/api/queue/push/", {
#   "headers": {
#     "accept": "*/*",
#     "accept-language": "en-US,en;q=0.9",
#     "content-type": "application/json",
#     "sec-ch-ua": "\".Not/A)Brand\";v=\"99\", \"Google Chrome\";v=\"103\", \"Chromium\";v=\"103\"",
#     "sec-ch-ua-mobile": "?0",
#     "sec-ch-ua-platform": "\"Windows\"",
#     "sec-fetch-dest": "empty",
#     "sec-fetch-mode": "cors",
#     "sec-fetch-site": "same-origin"
#   },
#   "referrer": "https://hf.space/embed/gstaff/test_space/+?__theme=light",
#   "referrerPolicy": "strict-origin-when-cross-origin",
#   "body": "{\"fn_index\":0,\"data\":[\"Jabberwock\"],\"action\":\"predict\",\"session_hash\":\"v9ehgfho3p\"}",
#   "method": "POST",
#   "mode": "cors",
#   "credentials": "omit"
# });
# fetch("https://hf.space/embed/gstaff/test_space/api/queue/status/", {
#   "headers": {
#     "accept": "*/*",
#     "accept-language": "en-US,en;q=0.9",
#     "content-type": "application/json",
#     "sec-ch-ua": "\".Not/A)Brand\";v=\"99\", \"Google Chrome\";v=\"103\", \"Chromium\";v=\"103\"",
#     "sec-ch-ua-mobile": "?0",
#     "sec-ch-ua-platform": "\"Windows\"",
#     "sec-fetch-dest": "empty",
#     "sec-fetch-mode": "cors",
#     "sec-fetch-site": "same-origin"
#   },
#   "referrer": "https://hf.space/embed/gstaff/test_space/+?__theme=light",
#   "referrerPolicy": "strict-origin-when-cross-origin",
#   "body": "{\"hash\":\"09f5369a7a414169aa48948bad5fd93d\"}",
#   "method": "POST",
#   "mode": "cors",
#   "credentials": "omit"
# });