"""Gradio demo: build a Stable Diffusion prompt from an arXiv paper title.

The app resolves an arXiv link or ID to the paper title, optionally decorates
it with a Pokémon type and extra prompt fragments, and feeds the resulting
prompt to the ``lambdalabs/sd-pokemon-diffusers`` pipeline.
"""
from contextlib import nullcontext
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import urllib
import urllib.parse
import urllib.request
import os
from xml.etree import ElementTree
import random
import re
from typing import List

pokemon_types = [
    "Normal", "Water", "Fire", "Ice", "Psychic", "Rock", "Dark", "Electric",
    "Grass", "Fighting", "Poison", "Ground", "Flying", "Bug", "Ghost",
    "Dragon", "Steel", "Fairy",
]

# Choices offered by the type selector: two special entries plus every type.
type_choices = ["None", "Random"]
type_choices.extend(pokemon_types)

# Title of the most recently resolved paper; shared by the prompt callbacks.
paper_name = None

# XML namespace prefix used when searching the Atom feed returned by arXiv.
_ATOM_NS = "{http://www.w3.org/2005/Atom}"

device = "cuda" if torch.cuda.is_available() else "cpu"
context = autocast if device == "cuda" else nullcontext
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
    "lambdalabs/sd-pokemon-diffusers", torch_dtype=dtype
)
pipe = pipe.to(device)

# Sometimes the nsfw checker is confused by the Pokémon images, you can disable
# it at your own risk here
disable_safety = True

if disable_safety:
    def null_safety(images, **kwargs):
        """Return ``images`` untouched together with ``False`` (nothing flagged)."""
        return images, False

    pipe.safety_checker = null_safety


def infer(prompt, n_samples, steps, scale):
    """Generate ``n_samples`` images for ``prompt`` with the diffusion pipeline.

    Args:
        prompt: Text prompt, repeated once per requested image.
        n_samples: Number of images to generate.
        steps: Number of inference steps.
        scale: Guidance scale forwarded to the pipeline.

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    # Coerce defensively: list repetition needs an int, and the values come
    # straight from UI sliders.
    n_samples = int(n_samples)
    steps = int(steps)
    with context("cuda"):
        images = pipe(
            n_samples * [prompt],
            guidance_scale=scale,
            num_inference_steps=steps,
        ).images
    return images


def get_paper_name(url: str) -> str:
    """Look up the title of an arXiv paper.

    Args:
        url: An arXiv link (abstract or pdf) or a bare paper ID. Only the last
            path component is used, with any ``.pdf`` suffix removed.

    Returns:
        The paper title with newlines removed and runs of spaces collapsed.

    Raises:
        ValueError: If no ID can be extracted from ``url`` or the arXiv
            response contains no titled entry.
        urllib.error.URLError: If the request to the arXiv API fails.
    """
    # Ignore surrounding whitespace and a trailing slash so that
    # "https://arxiv.org/abs/1234.5678/" still yields an ID.
    paper_id = os.path.basename(url.strip().rstrip("/"))
    paper_id = paper_id.split(".pdf")[0]
    if not paper_id:
        raise ValueError("No arXiv link or ID was given.")

    # The ID is user-supplied text: percent-encode it so it cannot inject
    # extra query parameters or produce an invalid URL.
    query_url = (
        "http://export.arxiv.org/api/query?id_list="
        f"{urllib.parse.quote(paper_id)}"
    )
    hdr = {"Content-Type": "application/atom+xml"}
    req = urllib.request.Request(query_url, headers=hdr)
    # Close the HTTP response deterministically instead of leaking it.
    with urllib.request.urlopen(req) as response:
        tree = ElementTree.fromstring(response.read().decode("utf-8"))

    paper_title = tree.findtext(f"{_ATOM_NS}entry/{_ATOM_NS}title")
    if not paper_title:
        raise ValueError(f"No arXiv paper found for ID {paper_id!r}.")

    paper_title = paper_title.replace("\n", "")
    paper_title = re.sub(r" +", " ", paper_title)
    return paper_title


def resolve_poke_type(pok_type: str) -> str:
    """Translate the type selector value into a concrete type name.

    Returns ``""`` for ``"None"``, a randomly drawn entry of
    ``pokemon_types`` for ``"Random"``, and ``pok_type`` itself otherwise.
    """
    if pok_type == "None":
        return ""
    if pok_type == "Random":
        return random.choice(pokemon_types)
    return pok_type


def build_prompt_text(paper_name, pok_type, add_ideas) -> str:
    """Assemble the prompt from the title, the type and the extra ideas.

    Args:
        paper_name: Paper title that opens the prompt.
        pok_type: Resolved type name, or ``""`` for no type clause.
        add_ideas: Extra prompt fragments appended space-separated.

    Returns:
        The prompt text, without trailing whitespace when no ideas are given.
    """
    prompt_text = f"{paper_name} as {pok_type} type" if pok_type != "" else f"{paper_name}"
    ideas_text = " ".join(add_ideas)
    # Only add the separator when there is something to append.
    return f"{prompt_text} {ideas_text}" if ideas_text else prompt_text


def update_prompt_link(new_link: str, pok_type: str, prompt_ideas: List[str]):
    """Rebuild the prompt after the paper link changed.

    Fetches the title for ``new_link``, stores it in the module-level
    ``paper_name`` and returns the freshly built prompt text.
    """
    global paper_name
    # Drop the cached title first: if the lookup below fails, later type/idea
    # changes must not silently reuse the title of the previous paper.
    paper_name = None
    paper_name = get_paper_name(new_link)
    pok_type = resolve_poke_type(pok_type)
    return build_prompt_text(paper_name, pok_type, prompt_ideas)


def update_prompt_type(paper_link: str, pok_type: str, prompt_ideas: List[str]):
    """Rebuild the prompt after the type or the extra ideas changed.

    Reuses the cached ``paper_name`` and only queries arXiv with
    ``paper_link`` when no title has been cached yet.
    """
    global paper_name
    if paper_name is None:
        paper_name = get_paper_name(paper_link)
    pok_type = resolve_poke_type(pok_type)
    return build_prompt_text(paper_name, pok_type, prompt_ideas)


block = gr.Blocks()

# Each row matches the example inputs below: paper link, image count, scale.
examples = [
    [
        "https://arxiv.org/abs/1706.03762",
        2,
        7.5,
    ],
    [
        "https://arxiv.org/abs/1404.5997v2",
        2,
        7.5,
    ],
    [
        "https://arxiv.org/abs/2010.11929",
        2,
        7.5,
    ],
    [
        "https://arxiv.org/abs/1810.04805v2",
        2,
        7.5,
    ],
]

with block:
    gr.HTML(
        """

Paper to Pokémon

Generate new Pokémon from an arXiv link. Just paste the link to the overview, the pdf or just give the ID of the paper. It will create a prompt with the paper title, which you can then modify as you like or submit as it is. For general better quality increase the step size. (This will also increase the processing time)

"""
    )
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Link or ID for paper",
                    show_label=False,
                    max_lines=1,
                    placeholder="Give arXiv link or ID for the paper",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )

        poke_type = gr.Radio(choices=type_choices, value="None", label="Pokemon Type")
        prompt_ideas = gr.CheckboxGroup(
            choices=[
                "as a bird",
                "with four legs",
                "with wings",
                "as a koala",
                "with a beak",
                "looking like a llama",
            ],
            label="Additional prompt ideas",
        )
        prompt_box = gr.Textbox(
            placeholder="Your prompt appears here", interactive=True, label="Prompt"
        )

        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")

        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
            steps = gr.Slider(label="Steps", minimum=5, maximum=50, value=25, step=5)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )

        # NOTE(review): `infer` takes four arguments (prompt, n_samples, steps,
        # scale) but only three inputs are listed here, and the first one is
        # the paper link rather than a prompt. Confirm `fn` is never invoked
        # for examples before enabling `cache_examples`.
        ex = gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[text, samples, scale],
            outputs=gallery,
            cache_examples=False,
        )
        ex.dataset.headers = [""]

        # A new link refetches the title; type/idea changes reuse the cached one.
        text.change(update_prompt_link, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)
        text.submit(update_prompt_link, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)
        poke_type.change(update_prompt_type, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)
        prompt_ideas.change(update_prompt_type, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)
        btn.click(infer, inputs=[prompt_box, samples, steps, scale], outputs=gallery)
        gr.HTML(
            """ """
        )

block.launch()