Spaces:
Runtime error
Runtime error
# pip install html2image
import base64
import html
import os
import pathlib
import random
import re
from io import BytesIO

import gradio as gr
import requests
import torch
from gradio_client import Client
from html2image import Html2Image
from PIL import Image
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, Pipeline
# A Hugging Face token is mandatory: the chat model below is reached through the
# hosted Inference API, so fail fast at import time when it is missing.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise Exception("HF_TOKEN environment variable is required to call remote API.")

# Chat model used by generate_text() to write and edit card text.
API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
headers = {"Authorization": "Bearer " + HF_TOKEN}

# Remote Gradio Space that generate_image() calls for card art.
client = Client("https://latent-consistency-super-fast-lcm-lora-sd1-5.hf.space")
def init_speech_to_text_model() -> Pipeline:
    """Build the distil-whisper speech-recognition pipeline used by transcribe().

    On a machine with CUDA the model is placed on ``cuda:0`` in float16;
    otherwise it runs on the CPU in float32.
    """
    gpu_available = torch.cuda.is_available()
    if gpu_available:
        device, torch_dtype = "cuda:0", torch.float16
    else:
        device, torch_dtype = "cpu", torch.float32

    checkpoint = "distil-whisper/distil-medium.en"
    asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        checkpoint,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    asr_model.to(device)
    speech_processor = AutoProcessor.from_pretrained(checkpoint)

    return pipeline(
        "automatic-speech-recognition",
        model=asr_model,
        tokenizer=speech_processor.tokenizer,
        feature_extractor=speech_processor.feature_extractor,
        max_new_tokens=128,
        torch_dtype=torch_dtype,
        device=device,
    )


# Built once at import time and reused by transcribe() for every request.
whisper_pipe = init_speech_to_text_model()
def query(payload: dict, timeout: float = 120.0):
    """POST *payload* to the hosted chat model (``API_URL``) and return the decoded JSON.

    Args:
        payload: request body, e.g. ``{"inputs": prompt, "parameters": {...}}``.
        timeout: seconds to wait for the API before giving up. The original call had
            no timeout, so a stalled request could block the worker indefinitely.

    Returns:
        The decoded JSON body on success. Transport failures and non-JSON bodies are
        reported as ``{"error": "..."}`` so the caller's existing
        ``'error' in output`` check handles them instead of crashing.
    """
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as err:
        return {'error': f'Request to {API_URL} failed: {err}'}
    try:
        return response.json()
    except ValueError:
        # For example, an HTML error page from a gateway. requests' JSONDecodeError
        # is a ValueError subclass.
        return {'error': f'HTTP {response.status_code}: non-JSON response: {response.text[:200]}'}
# Shared body of the system prompt for both "create" and "edit" requests: output
# rules plus one example card showing the expected field headers. This is a plain
# (non-f) string, so the mana braces are written single.
_CARD_FORMAT_RULES = """# RULES
- In your response always generate a new card.
- Only generate one card, no other dialogue.
- Surround card info in triple backticks (```).
- Format the card text using headers like in the example below:
```
Name: Band of Brothers
ManaCost: {3}{W}{W}
Type: Creature — Phyrexian Human Soldier
Rarity: rare
Text: Vigilance
{W}, {T}: Attach target creature you control to target creature. (Any number of attacking creatures with total power 5 or less can attack in a band. A band deals damage to that creature.)
FlavorText: "This time we will be stronger."
—Elder brotherhood blessing
Power: 2
Toughness: 2
Color: ['W']
```"""


def _build_prompt(card_text: str, user_request: str) -> str:
    """Return the Zephyr chat-formatted prompt for a create or an edit request.

    The edit variant, which embeds the current card, is used only when the text box
    holds something other than the untouched demo card (``starting_text``).
    """
    if card_text and card_text != starting_text:
        verb = "edit"
        user_turn = f"# CARD TO EDIT\n```\n{card_text}\n```\n# EDIT REQUEST\n{user_request}"
    else:
        verb = "create"
        user_turn = user_request
    return (
        "<|system|>\n"
        f"You {verb} Magic the Gathering cards based on the user's request.\n"
        f"{_CARD_FORMAT_RULES}</s>\n"
        f"<|user|>\n{user_turn}</s>\n"
        "<|assistant|>\n"
    )


def generate_text(card_text: str, user_request: str) -> (str, str, str):
    """Ask the chat model to create a new card or edit *card_text* per *user_request*.

    Returns:
        ``(assistant_reply, new_card_text, None)``. ``new_card_text`` is the fenced
        block from the reply, or else its second paragraph. It falls back to the
        unchanged *card_text* when neither exists. The trailing None is bound to the
        microphone component in the UI.

    Raises:
        gr.Error: if the API reports an error or returns an unexpected payload.
    """
    # Prompt must apply the correct chat template for the model see:
    # https://huggingface.co/docs/transformers/main/en/chat_templating
    prompt = _build_prompt(card_text, user_request)
    print(f"Calling API with prompt:\n{prompt}")
    output = query({"inputs": prompt, "parameters": {"max_new_tokens": 512}})

    if isinstance(output, dict) and 'error' in output:
        print(f'Language model call failed: {output["error"]}')
        # gr.Warning only shows a notification and is not an exception class, so
        # `raise gr.Warning(...)` raised TypeError. gr.Error is the exception Gradio
        # reports to the user.
        raise gr.Error(f'Language model call failed: {output["error"]}')

    try:
        generated_text = output[0]["generated_text"]
    except (KeyError, IndexError, TypeError) as err:
        raise gr.Error(f'Unexpected language model response: {output!r}') from err
    print(f'API RESPONSE SIZE: {len(generated_text)}')

    # generated_text appears to echo the prompt, so the reply follows the prompt's
    # <|assistant|> tag. If the tag is absent, treat the whole text as the reply.
    segments = generated_text.split('<|assistant|>')
    assistant_reply = segments[1] if len(segments) > 1 else generated_text
    print(f'ASSISTANT REPLY:\n{assistant_reply}')

    fenced = assistant_reply.split('```')
    if len(fenced) > 1:
        return assistant_reply, fenced[1].strip() + '\n', None
    paragraphs = assistant_reply.split('\n\n')
    if len(paragraphs) < 2:
        return assistant_reply, card_text, None
    return assistant_reply, paragraphs[1].strip() + '\n', None
# Inline icon style used in rules text. format_html swaps it for the smaller
# variant when the text box is crowded, so both strings must match the generated
# icon markup exactly.
_ICON_STYLE = 'style="top:0px;float:none;height: 18px;width: 18px;font-size: 13px;"'
_SMALL_ICON_STYLE = 'style="top:0px;float:none;height: 16px;width: 16px;font-size: 11px;"'

# (marker in card text, background colour). The first marker found wins, so
# mono-coloured cards are checked before the Land/Artifact type lines. Anything
# unmatched (e.g. multicoloured "['U', 'W']") gets the default.
_CARD_BACKGROUNDS = (
    ("['U']", "#5a73ab"),
    ("['W']", "#f0e3d0"),
    ("['G']", "#325433"),
    ("['B']", "#1a1b1e"),
    ("['R']", "#c2401c"),
    ("Type: Land", "#aa8c71"),
    ("Type: Artifact", "#9ba7bc"),
)
_DEFAULT_CARD_BACKGROUND = "#edd99d"

# One mana symbol: either a braced token such as {W}, {10} or {W/U}, or a bare
# non-space character as in "2WW".
_MANA_TOKEN = re.compile(r"\{([^{}]*)\}|([^{}\s])")
# Inline symbol in rules text, e.g. {T} or {2/W}.
_INLINE_SYMBOL = re.compile(r"\{(.*?)\}")
# Rules-text symbols whose icon class is not simply the lower-cased symbol.
_SPECIAL_SYMBOLS = {"T": "tap", "UT": "untap", "E": "instant"}
# Rules text runs from "Text:" to the next FlavorText/Power/Toughness/Color header,
# or to the end of the card, so it may span several lines.
_RULES_TEXT = re.compile(
    r"^Text: (.*?)(?=\n(?:Flavor.?Text|Power|Toughness|Color):|\Z)",
    re.MULTILINE | re.DOTALL,
)
# Flavor text (possibly multi-line, e.g. a quote plus attribution) up to the next
# Power/Toughness/Color header or the end of the card.
_FLAVOR_TEXT = re.compile(
    r"Flavor.?Text: (.*?)(?=\n(?:Power|Toughness|Color):|\Z)",
    re.DOTALL,
)


def _card_field(pattern, text):
    """Return the stripped first capture of *pattern* in *text*, or None if absent."""
    match = re.search(pattern, text)
    return match.group(1).strip() if match else None


def _mana_cost_html(mana_cost):
    """Return the header mana-cost icons for *mana_cost* (missing/"None" -> hidden icon)."""
    if mana_cost in (None, "", "None"):
        # A hidden icon rather than nothing, presumably to keep the header's height.
        return '<i class="ms ms-cost" style="visibility: hidden"></i>'
    symbols = (
        # "{W/U}" -> "wu", the hybrid class name the original tried to build.
        (braced or bare).lower().replace("/", "")
        for braced, bare in _MANA_TOKEN.findall(mana_cost)
    )
    icons = [f'<i class="ms ms-{html.escape(s)} ms-cost ms-shadow"></i>' for s in symbols if s]
    # Emitted last-to-first as before; presumably the template floats them right.
    return "\n".join(reversed(icons))


def _symbol_icon(match):
    """re.sub callback: turn one ``{X}`` rules-text symbol into an inline icon."""
    code = match.group(1)
    css_name = _SPECIAL_SYMBOLS.get(code) or code.lower().replace("/", "")
    return f'<i class="ms ms-{css_name} ms-cost" {_ICON_STYLE}></i>'


def _rules_text_html(rules_text):
    """Return one ``<p>`` per rules-text line, with symbols as icons and (...) italic."""
    paragraphs = []
    for line in rules_text.splitlines():
        # Escape before adding markup: card text comes from the model/user.
        line = _INLINE_SYMBOL.sub(_symbol_icon, html.escape(line))
        line = line.replace("(", "(<i>").replace(")", "</i>)")
        paragraphs.append(f"<p>{line}</p>")
    return paragraphs


def _power_toughness_html(text):
    """Return the power/toughness box, or "" when the card has no (non-empty) Power."""
    power = _card_field(r"Power: (.*)", text)
    if not power:
        return ""
    toughness = _card_field(r"Toughness: (.*)", text) or ""
    return (
        '<header class="powerToughness"><div><h2 style="font-family: \'Beleren\';font-size: 19px;">'
        f"{html.escape(power)}/{html.escape(toughness)}</h2></div></header>"
    )


def format_html(text, image_data):
    """Fill ``./card_template.html`` with the fields parsed from *text*.

    Args:
        text: card text in the "Header: value" format produced by generate_text.
        image_data: card art. ``bytes``/``bytearray`` are treated as base64 PNG data
            and embedded as a data URL; anything else is inserted verbatim
            (e.g. a file path).

    Returns:
        The filled-in HTML. A copy is also written to ``scratch.html``.

    Missing fields render as empty instead of raising IndexError. Model-provided
    text is HTML-escaped before it is inserted.
    """
    template = pathlib.Path("./card_template.html").read_text(encoding='utf-8')

    background = next(
        (colour for marker, colour in _CARD_BACKGROUNDS if marker in text),
        _DEFAULT_CARD_BACKGROUND,
    )
    name = _card_field(r"Name: (.*)", text) or ""
    card_type = _card_field(r"Type: (.*)", text) or ""
    # Lower-cased so values like "Rare" still map to the keyrune "ss-rare" class.
    rarity = (_card_field(r"Rarity: (.*)", text) or "").lower()

    rules_match = _RULES_TEXT.search(text)
    rules_text = rules_match.group(1).strip() if rules_match else ""
    text_lines = _rules_text_html(rules_text)

    flavor_match = _FLAVOR_TEXT.search(text)
    flavor_text = flavor_match.group(1).strip() if flavor_match else ""
    flavor_html = ""
    if flavor_text:
        flavor_lines = "\n".join(f"<p>{html.escape(line)}</p>" for line in flavor_text.splitlines())
        flavor_html = f"<blockquote>{flavor_lines}</blockquote>"

    if isinstance(image_data, (bytes, bytearray)):
        image_src = f'data:image/png;base64,{image_data.decode("utf-8")}'
    else:
        image_src = f'{image_data}'

    # Long rules + flavor text switches to the smaller font and icons.
    crowded = len(rules_text) + len(flavor_text) > 170 or len(text_lines) > 3

    # Same placeholder order as before.
    replacements = (
        ("{card_color}", f'style="background-color:{background}"'),
        ("{name}", html.escape(name)),
        ("{mana_cost}", _mana_cost_html(_card_field(r"Mana.?Cost: (.*)", text))),
        ("{image_data}", image_src),
        ("{card_type}", html.escape(card_type)),
        ("{type_size}", "16" if len(card_type) > 30 else "18"),
        ("{rarity}", f"ss-{html.escape(rarity)}"),
        ("{card_text}", "\n".join(text_lines)),
        ("{flavor_text}", flavor_html),
        ("{text_size}", "16" if crowded else "18"),
        ("{power_toughness}", _power_toughness_html(text)),
    )
    for placeholder, value in replacements:
        template = template.replace(placeholder, value)
    if crowded:
        template = template.replace(f'ms-cost" {_ICON_STYLE}></i>', f'ms-cost" {_SMALL_ICON_STYLE}></i>')

    # Debug copy of the most recently formatted card.
    pathlib.Path("scratch.html").write_text(template, encoding='utf-8')
    return template
def get_savename(directory, name, extension):
    """Return a file name for *name* that does not yet exist in *directory*.

    Tries ``<name>.<extension>`` first, then ``<name>-1.<extension>``,
    ``<name>-2.<extension>``, and so on.

    The previous version rebuilt candidates with ``.split('-')[0]``, which truncated
    names that contain a hyphen ("Fire-Breather" became "Fire-1.png"). Path separators
    and characters invalid on common filesystems are replaced with ``_`` because
    *name* comes from model-generated card text.
    """
    safe_name = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '_', name)
    candidate = f"{safe_name}.{extension}"
    suffix = 1
    while os.path.exists(os.path.join(directory, candidate)):
        candidate = f"{safe_name}-{suffix}.{extension}"
        suffix += 1
    return candidate
def html_to_png(card_name, html):
    """Screenshot the card *html* into ``rendered_cards/`` and return it as an RGB image.

    The page is captured at 450x600, then cropped to the 400x550 region starting
    50px down. The cropped PNG overwrites the screenshot file.
    """
    rendered_card_dir = 'rendered_cards'
    save_name = get_savename(rendered_card_dir, card_name, 'png')
    print('CONVERTING HTML CARD TO PNG IMAGE')
    # Expected output location, used if Html2Image fails before reporting its path.
    path = os.path.join(rendered_card_dir, save_name)
    try:
        hti = Html2Image(output_path=rendered_card_dir)
        paths = hti.screenshot(html_str=html,
                               css_file=['./css/mtg_custom.css', './css/mana.css',
                                         './css/keyrune.css'],
                               save_as=save_name, size=(450, 600))
        print(paths)
        path = paths[0]
    except Exception as err:
        # Best-effort, as before: try the expected path anyway. The error is now
        # logged instead of silently swallowed (a bare except also caught
        # KeyboardInterrupt).
        print(f'HTML TO PNG SCREENSHOT FAILED: {err}')
    print('OPENING IMAGE FROM FILE')
    with Image.open(path) as img:
        print('CROPPING BACKGROUND')
        # (left, upper, right, lower): a 400x550 card. The old discarded
        # resize((400, 550)) call was a no-op and is removed.
        cropped_img = img.crop((0, 50, 400, 600))
    # Saved after the source file is closed, since we overwrite that same path.
    cropped_img.save(path)
    print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
    return cropped_img.convert('RGB')
def get_initial_card():
    """Return the sample card image displayed before any card has been rendered."""
    sample_card_path = 'SampleCard.png'
    return Image.open(sample_card_path)
def pil_to_base64(image):
    """Encode *image* as PNG and return the base64 encoding as ``bytes`` (not ``str``)."""
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return encoded
def generate_card(image: str, card_text: str):
    """Render *card_text* with the art at path *image* into the final card image.

    Args:
        image: path to the card art. None, which Gradio passes when the image input
            was cleared, falls back to ``placeholder.png``, the UI's initial art.
        card_text: card text in the "Header: value" format.

    Returns:
        The rendered card as an RGB PIL image (from html_to_png).
    """
    # Context manager closes the art file; the original leaked the handle.
    with Image.open(image or 'placeholder.png') as art:
        image_data = pil_to_base64(art)
    # Named card_html to avoid shadowing the `html` module.
    card_html = format_html(card_text, image_data)
    name_match = re.search('Name: (.*)', card_text)
    # A card without a Name line still gets a file name instead of raising IndexError.
    card_name = name_match.group(1) if name_match else 'Untitled'
    return html_to_png(card_name, card_html)
def transcribe(audio: str) -> (str, str):
    """Transcribe the recorded audio file at *audio* with the Whisper pipeline.

    Returns the transcript plus None; the UI binds these to the prompt box and the
    microphone component respectively.
    """
    transcription = whisper_pipe(audio)["text"]
    return transcription, None
# Demo card pre-filled in the card-text box. generate_text() treats this exact text
# as "no card yet" and uses the create prompt instead of the edit prompt.
starting_text = "\n".join([
    "Name: Wizards of the Coast",
    "ManaCost: {0}",
    "Type: Enchantment",
    "Rarity: mythic rare",
    'Text: At the beginning of your upkeep, reveal the top card of your library. If it\'s a card named "Magic: The Gathering", put it into your hand. Otherwise, put it into your graveyard.',
    'FlavorText: "We are the guardians of the multiverse, and we will protect it at all costs."',
    "Color: ['U']",
])
def generate_image(card_text: str):
    """Request card art for *card_text* from the remote image Space (``client``).

    The prompt combines the card's Type and Name lines. If the remote call fails,
    the error is printed and ``'placeholder.png'`` is returned so rendering can
    still proceed.

    Raises:
        IndexError: if *card_text* has no "Name:" or no "Type:" line.
    """
    card_name = re.findall('Name: (.*)', card_text)[0]
    type_line = re.findall('Type: (.*)', card_text)[0]
    prompt = f"fantasy illustration of a {type_line} {card_name}, by Greg Rutkowski"
    print(f'Calling image generation with prompt: {prompt}')
    try:
        generated = client.predict(
            prompt,  # str: 'parameter_5' Textbox
            0.3,  # float in [0.0, 5]: 'Guidance' slider
            4,  # float in [2, 10]: 'Steps' slider
            random.randint(0, 12013012031030),  # float in [0, 12013012031030]: 'Seed' slider
            api_name="/predict",
        )
        print(generated)
        return generated
    except Exception as exc:
        print(f'Failed to generate image from client: {exc}')
        return 'placeholder.png'
def add_hotkeys() -> str:
    """Return the JavaScript in ``hotkeys.js``, which the UI runs on page load.

    Read explicitly as UTF-8, matching how card_template.html is read, so the result
    does not depend on the server's locale encoding.
    """
    return pathlib.Path("hotkeys.js").read_text(encoding="utf-8")
| # Gradio UI: voice/typed request inputs, the assistant's raw reply, editable card text and art, and the rendered card. | |
| with gr.Blocks(title='MagicGen') as demo: | |
| gr.Markdown("# 🎴 MagicGenV2") | |
| gr.Markdown("## Generate and Edit Magic the Gathering Cards with a Chat Assistant") | |
| with gr.Row(): | |
| with gr.Column(): | |
| with gr.Group(): | |
| audio_in = gr.Microphone(label="Record a voice request (click or press ctrl + ` to start/stop)", | |
| type='filepath', elem_classes=["record-btn"]) | |
| prompt_in = gr.Textbox(label="Or type a text request and press Enter", interactive=True, | |
| placeholder="Need an idea? Try one of these:\n- Create a creature card named 'WiFi Elemental'\n- Make it an instant\n- Change the color") | |
| with gr.Accordion(label='🤖 Chat Assistant Response', open=False): | |
| bot_text = gr.TextArea(label='Response', interactive=False) | |
| with gr.Row(): | |
| with gr.Column(): | |
| in_text = gr.TextArea(label="Card Text (Shift+Enter to submit)", value=starting_text) | |
| gen_image_button = gr.Button('🖼️ Generate Card Image') | |
| in_image = gr.Image(label="Card Image (400px x 550px)", type='filepath', value='placeholder.png') | |
| render_button = gr.Button('🎴 Render Card', variant="primary") | |
| gr.ClearButton([audio_in, prompt_in, in_text, in_image]) | |
| with gr.Column(): | |
| out_image = gr.Image(label="Rendered Card", value=get_initial_card()) | |
| # Shared fn/inputs/outputs specs reused by the event listeners below. transcribe and generate_text return None for audio_in, which resets the recorder. | |
| transcribe_params = {'fn': transcribe, 'inputs': [audio_in], 'outputs': [prompt_in, audio_in]} | |
| generate_text_params = {'fn': generate_text, 'inputs': [in_text, prompt_in], | |
| 'outputs': [bot_text, in_text, audio_in]} | |
| generate_image_params = {'fn': generate_image, 'inputs': [in_text], 'outputs': [in_image]} | |
| generate_card_params = {'fn': generate_card, 'inputs': [in_image, in_text], 'outputs': [out_image]} | |
| # Voice and typed requests run the full pipeline: (transcribe ->) generate text -> generate art -> render card. | |
| # NOTE(review): .then() runs the next step even if the previous one raised (e.g. a failed LLM call); .success() would stop the chain instead. | |
| audio_in.stop_recording(**transcribe_params).then(**generate_text_params).then(**generate_image_params).then( | |
| **generate_card_params) | |
| prompt_in.submit(**generate_text_params).then(**generate_image_params).then(**generate_card_params) | |
| # Shift + Enter in the card-text area re-renders the card from the edited text (no LLM call). | |
| in_text.submit(**generate_card_params) | |
| render_button.click(**generate_card_params) | |
| # Regenerate only the art for the current card text, then re-render. | |
| gen_image_button.click(**generate_image_params).then(**generate_card_params) | |
| # Runs hotkeys.js in the browser on page load (presumably the ctrl + ` recording shortcut named in the microphone label). | |
| demo.load(None, None, None, js=add_hotkeys()) | |
| # Enable Gradio's request queue and start the server. | |
| if __name__ == "__main__": | |
| demo.queue().launch(favicon_path="favicon-96x96.png") | |