# Johannes
# update README and app instructions
# a33222e
import os
import random
import re
import urllib, urllib.request
from contextlib import nullcontext
from functools import lru_cache
from typing import List
from xml.etree import ElementTree

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from torch import autocast
# The eighteen Pokémon types that can be mixed into a prompt.
pokemon_types = (
    "Normal Water Fire Ice Psychic Rock Dark Electric Grass "
    "Fighting Poison Ground Flying Bug Ghost Dragon Steel Fairy"
).split()

# Options for the type selector: "None" leaves the type out of the prompt,
# "Random" picks one of pokemon_types each time the prompt is rebuilt, and
# every concrete type follows.
type_choices = ["None", "Random", *pokemon_types]
# Module-level placeholder for a paper title; starts unset.
paper_name = None

# Run in half precision under autocast when a GPU is available; otherwise run
# in full precision on the CPU, with a no-op context standing in for autocast.
if torch.cuda.is_available():
    device, dtype, context = "cuda", torch.float16, autocast
else:
    device, dtype, context = "cpu", torch.float32, nullcontext

pipe = StableDiffusionPipeline.from_pretrained(
    "lambdalabs/sd-pokemon-diffusers", torch_dtype=dtype
).to(device)
# Sometimes the NSFW checker is confused by the Pokémon images; you can
# disable it at your own risk here.
disable_safety = True

if disable_safety:

    def null_safety(images, **kwargs):
        """Pass-through replacement for the pipeline's safety checker.

        Returns the images unchanged, together with one ``False`` flag per
        image. This matches the per-image list of booleans that diffusers
        documents for ``nsfw_content_detected``. A bare ``False`` is not
        iterable and breaks pipeline versions that loop over the flags.
        """
        return images, [False] * len(images)

    pipe.safety_checker = null_safety
def infer(prompt, n_samples, steps, scale):
    """Generate Pokémon images for a text prompt.

    Args:
        prompt: Text prompt for the diffusion pipeline.
        n_samples: Number of images to generate. The prompt is repeated this
            many times in a single batch.
        steps: Number of inference (denoising) steps.
        scale: Guidance scale passed to the pipeline.

    Returns:
        The ``.images`` attribute of the pipeline output.
    """
    # Gradio sliders are documented to hand values over as floats. List
    # repetition needs an int, and so does the scheduler's step count.
    n_samples = int(n_samples)
    steps = int(steps)
    # `context` is autocast on CUDA and a no-op context manager on CPU.
    with context("cuda"):
        images = pipe(
            n_samples * [prompt],
            guidance_scale=float(scale),
            num_inference_steps=steps,
        ).images
    return images
# Matches an arXiv identifier in either of two forms:
#   - modern: "1706.03762", optionally with a version suffix such as "v2";
#   - legacy: "hep-th/9901001" or "math.GT/0309136".
# The look-arounds keep it from matching inside a longer run of digits or
# starting in the middle of a word (e.g. a hostname).
_ARXIV_ID_RE = re.compile(
    r"(?<![\w.])"
    r"(\d{4}\.\d{4,5}(?:v\d+)?"
    r"|[a-z][a-z-]*(?:\.[A-Z]{2})?/\d{7}(?:v\d+)?)"
    r"(?!\d)"
)
_ATOM_NS = "{http://www.w3.org/2005/Atom}"


def get_paper_name(url: str, timeout: float = 10.0) -> str:
    """Look up the title of an arXiv paper.

    Args:
        url: An arXiv abstract or PDF link, e.g.
            ``https://arxiv.org/abs/1706.03762`` or
            ``https://arxiv.org/pdf/1706.03762.pdf``. It may also be a bare
            ID such as ``1706.03762``, ``1404.5997v2`` or ``hep-th/9901001``.
        timeout: Seconds to wait for the arXiv API before giving up.

    Returns:
        The paper title, with line breaks and runs of whitespace collapsed
        to single spaces.

    Raises:
        ValueError: If no arXiv ID can be found in ``url``. This is checked
            before any network access. Also raised if the API returns no
            usable entry for the ID.
        OSError: On network failures. ``urllib.error.URLError`` and socket
            timeouts are ``OSError`` subclasses.
        xml.etree.ElementTree.ParseError: If the response is not valid XML.
    """
    match = _ARXIV_ID_RE.search(url or "")
    if match is None:
        raise ValueError(f"Could not find an arXiv ID in {url!r}")
    paper_id = match.group(1)

    query_url = f"http://export.arxiv.org/api/query?id_list={paper_id}"
    req = urllib.request.Request(
        query_url, headers={"Accept": "application/atom+xml"}
    )
    # `with` closes the connection even if reading or parsing fails.
    with urllib.request.urlopen(req, timeout=timeout) as response:
        tree = ElementTree.fromstring(response.read())

    entry = tree.find(f"{_ATOM_NS}entry")
    # Per the arXiv API manual, a bad ID comes back as an "Error" entry
    # whose <id> points at /api/errors rather than at a paper.
    if entry is None or "/api/errors" in (entry.findtext(f"{_ATOM_NS}id") or ""):
        raise ValueError(f"arXiv returned no paper for ID {paper_id!r}")

    # Titles are wrapped over several indented lines in the feed.
    title = " ".join((entry.findtext(f"{_ATOM_NS}title") or "").split())
    if not title:
        raise ValueError(f"arXiv entry for {paper_id!r} has no title")
    return title
block = gr.Blocks()

# Example rows shown below the form. Each row is
# [arXiv link, number of images, guidance scale], in the same order as the
# components the examples are bound to (text, samples, scale).
examples = [
    [link, 2, 7.5]
    for link in (
        "https://arxiv.org/abs/1706.03762",
        "https://arxiv.org/abs/1404.5997v2",
        "https://arxiv.org/abs/2010.11929",
        "https://arxiv.org/abs/1810.04805v2",
    )
]
# Build the UI:
#   - a link box and a generate button;
#   - a type selector and extra prompt ideas;
#   - an editable prompt box and the result gallery;
#   - sliders for image count, steps and guidance scale.
# Editing the link, type or ideas rebuilds the prompt; the button runs infer.
with block:
    gr.HTML(
        """
<div style="text-align: center; max-width: 650px; margin: 50px auto;">
<div>
<h1 style="font-weight: 900; font-size: 3rem;">
Paper to Pokémon
</h1>
</div>
<p style="margin-bottom: 10px; margin-top: 30px; font-size: 94%">
Generate new Pokémon from an arXiv link. Just paste the link to the overview, the pdf or just give the ID of the paper.
It will create a prompt with the paper title, which you can then modify as you like or submit as it is.
For better quality, increase the number of steps (this also increases the processing time).
</p>
</div>
"""
    )
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Link or ID for paper",
                    show_label=False,
                    max_lines=1,
                    placeholder="Give arXiv link or ID for the paper",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        poke_type = gr.Radio(choices=type_choices, value="None", label="Pokemon Type")
        prompt_ideas = gr.CheckboxGroup(
            choices=[
                "as a bird",
                "with four legs",
                "with wings",
                "as a koala",
                "with a beak",
                "looking like a llama",
            ],
            label="Additional prompt ideas",
        )
        prompt_box = gr.Textbox(
            placeholder="Your prompt appears here", interactive=True, label="Prompt"
        )
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")

        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
            steps = gr.Slider(label="Steps", minimum=5, maximum=50, value=25, step=5)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )

        # The example rows only pre-fill (link, images, guidance scale).
        # They are deliberately not bound to `infer`: it takes four
        # arguments, but the rows supply three, so caching or running
        # examples through it would fail.
        ex = gr.Examples(
            examples=examples, inputs=[text, samples, scale], cache_examples=False
        )
        ex.dataset.headers = [""]

        @lru_cache(maxsize=128)
        def lookup_paper_name(link: str) -> str:
            """Fetch and memoize the title for `link` (failures are not cached)."""
            return get_paper_name(link)

        def try_paper_name(link: str):
            """Return the paper title for `link`, or None if it can't be resolved.

            Lookups run on every edit of the link box, so partial or invalid
            input is expected. Failures are reported on stdout instead of
            being raised as UI errors.
            """
            link = (link or "").strip()
            if not link:
                return None
            try:
                return lookup_paper_name(link)
            except (OSError, ValueError, AttributeError, ElementTree.ParseError) as err:
                print(f"Could not resolve paper title for {link!r}: {err}")
                return None

        def resolve_poke_type(pok_type: str) -> str:
            """Map the type selector value to the type name used in the prompt.

            Returns "" for "None", a randomly chosen type for "Random", and the
            value itself otherwise.
            """
            if pok_type == "None":
                return ""
            if pok_type == "Random":
                return random.choice(pokemon_types)
            return pok_type

        def build_prompt_text(paper_name, pok_type, add_ideas):
            """Compose "<title>[ as <type> type][ <idea> ...]", space-separated."""
            base = f"{paper_name} as {pok_type} type" if pok_type else f"{paper_name}"
            return " ".join([base, *(add_ideas or [])])

        def update_prompt_link(new_link: str, pok_type: str, prompt_ideas: List[str]):
            """Rebuild the prompt for the current link, type and ideas.

            The title is resolved from the link given to this call, not from
            shared module state, so concurrent sessions cannot see each
            other's papers. When the link can't be resolved, return
            gr.update() so the current prompt is left untouched.
            """
            paper = try_paper_name(new_link)
            if paper is None:
                return gr.update()
            return build_prompt_text(paper, resolve_poke_type(pok_type), prompt_ideas)

        def update_prompt_type(paper_link: str, pok_type: str, prompt_ideas: List[str]):
            """Rebuild the prompt after a type/idea change.

            The title for an unchanged link is served from lookup_paper_name's
            cache.
            """
            return update_prompt_link(paper_link, pok_type, prompt_ideas)

        prompt_inputs = [text, poke_type, prompt_ideas]
        text.change(update_prompt_link, inputs=prompt_inputs, outputs=prompt_box)
        text.submit(update_prompt_link, inputs=prompt_inputs, outputs=prompt_box)
        poke_type.change(update_prompt_type, inputs=prompt_inputs, outputs=prompt_box)
        prompt_ideas.change(update_prompt_type, inputs=prompt_inputs, outputs=prompt_box)
        btn.click(infer, inputs=[prompt_box, samples, steps, scale], outputs=gallery)
    gr.HTML(
        """
<div class="footer" style="text-align: center; max-width: 650px; margin: 50px auto;">
<p>Inspired by and cloned from the great <a href="https://huggingface.co/spaces/lambdalabs/text-to-pokemon">
Text-to-Pokémon</a> space by Lambda labs</p>
<p> Gradio Demo by johko</p>
</div>
"""
    )

block.launch()