File size: 7,773 Bytes
e95d2f0
 
 
 
 
9c281ac
 
 
ea23bc4
 
ed1dd9b
ea23bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e95d2f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea23bc4
e95d2f0
ea23bc4
e95d2f0
 
 
9c281ac
 
 
 
ea23bc4
9c281ac
 
ea23bc4
 
 
 
9c281ac
 
 
 
 
e95d2f0
 
 
ea23bc4
 
 
 
 
 
e95d2f0
 
 
ea23bc4
 
 
 
 
 
 
 
 
 
e95d2f0
 
 
 
 
9c281ac
e95d2f0
 
9c281ac
e95d2f0
 
9c281ac
72ed919
a33222e
72ed919
a33222e
 
e95d2f0
 
 
 
 
 
 
 
ea23bc4
e95d2f0
 
72ed919
e95d2f0
 
 
 
 
 
 
 
 
ea23bc4
 
 
 
 
 
 
 
 
 
 
e95d2f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea23bc4
 
 
 
 
 
 
 
 
ed1dd9b
ea23bc4
 
 
 
 
 
 
 
ed1dd9b
ea23bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e95d2f0
 
9c281ac
 
 
 
e95d2f0
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import functools
import os
import random
import re
import urllib
import urllib.parse
import urllib.request
from contextlib import nullcontext
from typing import List
from xml.etree import ElementTree

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from torch import autocast


# The 18 Pokémon types offered by the type selector, in display order.
pokemon_types = [
    "Normal", "Water", "Fire", "Ice", "Psychic", "Rock",
    "Dark", "Electric", "Grass", "Fighting", "Poison", "Ground",
    "Flying", "Bug", "Ghost", "Dragon", "Steel", "Fairy",
]

# Radio choices: the two special options first, then every concrete type.
type_choices = ["None", "Random", *pokemon_types]

# Title of the most recently looked-up paper, written by the prompt
# callbacks further down; None until a link has been resolved.
paper_name = None

# Choose the execution backend once: fp16 with autocast on a GPU,
# fp32 with a do-nothing context manager on the CPU.
if torch.cuda.is_available():
    device, context, dtype = "cuda", autocast, torch.float16
else:
    device, context, dtype = "cpu", nullcontext, torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
    "lambdalabs/sd-pokemon-diffusers", torch_dtype=dtype
).to(device)


# Sometimes the NSFW checker is confused by the Pokémon images; you can
# disable it at your own risk here.
disable_safety = True

if disable_safety:
    def null_safety(images, **kwargs):
        """Stand-in safety checker that flags no image as NSFW.

        Args:
            images: The batch of decoded images handed over by the pipeline.
            **kwargs: Any further arguments the pipeline passes (ignored).

        Returns:
            ``(images, flags)`` with the images untouched and one ``False``
            flag per image.
        """
        # Report one flag per image rather than a single bare False: the
        # pipeline's own checker returns a per-image list, and pipeline
        # versions that iterate over the flags would fail on a bool.
        return images, [False] * len(images)

    pipe.safety_checker = null_safety


def infer(prompt, n_samples, steps, scale):
    """Generate images for *prompt* with the Stable Diffusion pipeline.

    Args:
        prompt: Text prompt; it is repeated once per requested image.
        n_samples: Number of images to generate. Converted to ``int``
            because Gradio sliders may deliver their value as a float, and
            ``float * list`` raises a TypeError.
        steps: Number of inference steps, passed as ``num_inference_steps``
            (also converted to ``int``).
        scale: Passed to the pipeline as ``guidance_scale``.

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    n_samples = int(n_samples)
    steps = int(steps)
    # `context` is torch autocast on CUDA and a no-op nullcontext on CPU.
    with context("cuda"):
        images = pipe(
            n_samples * [prompt], guidance_scale=scale, num_inference_steps=steps
        ).images

    return images

def _extract_arxiv_id(url: str) -> str:
    """Return the arXiv paper ID contained in *url*, or "" if there is none.

    Accepts abstract links (``.../abs/<id>``), PDF links (``.../pdf/<id>``,
    optionally ending in ``.pdf``) and bare IDs, including old-style IDs
    such as ``hep-th/9901001``. Query strings, fragments, surrounding
    whitespace and trailing slashes are ignored.
    """
    path = urllib.parse.urlparse(url.strip()).path
    for marker in ("/abs/", "/pdf/"):
        if marker in path:
            # Keep everything after the marker so old-style IDs containing a
            # slash survive (os.path.basename would cut them in half).
            path = path.split(marker, 1)[1]
            break
    paper_id = path.strip("/")
    if paper_id.endswith(".pdf"):
        paper_id = paper_id[: -len(".pdf")]
    return paper_id


@functools.lru_cache(maxsize=256)
def get_paper_name(url: str, timeout: float = 30.0) -> str:
    """Look up the title of an arXiv paper via the arXiv export API.

    Results are memoised per (url, timeout), so repeated lookups of the same
    link do not hit the API again; failed lookups are not cached.

    Args:
        url: An arXiv abstract link, PDF link or bare paper ID.
        timeout: Seconds to wait for the API before giving up.

    Returns:
        The paper title with every run of whitespace (including the line
        breaks in the API's title text) collapsed to a single space.

    Raises:
        ValueError: If no paper ID can be extracted from *url*, the API
            response contains no entry/title, or arXiv reports an error
            for the ID.
        urllib.error.URLError: If the HTTP request fails.
    """
    paper_id = _extract_arxiv_id(url)
    if not paper_id:
        raise ValueError(f"Could not find an arXiv paper ID in {url!r}")

    query_url = "http://export.arxiv.org/api/query?" + urllib.parse.urlencode(
        {"id_list": paper_id}
    )
    hdr = {"Content-Type": "application/atom+xml"}
    req = urllib.request.Request(query_url, headers=hdr)
    # Close the connection deterministically; the timeout keeps a stalled
    # request from blocking the calling event handler indefinitely.
    with urllib.request.urlopen(req, timeout=timeout) as response:
        payload = response.read()

    # fromstring accepts bytes and honours the document's XML encoding.
    tree = ElementTree.fromstring(payload)
    ns = {"atom": "http://www.w3.org/2005/Atom"}
    entry = tree.find("atom:entry", ns)
    title = entry.find("atom:title", ns) if entry is not None else None
    if title is None or not title.text:
        raise ValueError(f"arXiv returned no entry for paper ID {paper_id!r}")

    # arXiv signals a rejected ID with an entry whose id points at its
    # error documentation instead of at a paper.
    entry_id = entry.findtext("atom:id", default="", namespaces=ns)
    if "/api/errors" in entry_id:
        detail = entry.findtext("atom:summary", default="", namespaces=ns).strip()
        raise ValueError(f"arXiv rejected paper ID {paper_id!r}: {detail}")

    return " ".join(title.text.split())
    


# Top-level Gradio Blocks container; the UI is declared inside `with block:`
# further down and started by block.launch() at the end of the file.
block = gr.Blocks()

# Example rows for gr.Examples, one per paper:
# [arXiv link, number of images, guidance scale].
examples = [
    [paper_link, 2, 7.5]
    for paper_link in (
        "https://arxiv.org/abs/1706.03762",
        "https://arxiv.org/abs/1404.5997v2",
        "https://arxiv.org/abs/2010.11929",
        "https://arxiv.org/abs/1810.04805v2",
    )
]

with block:
    # Page header: title plus short usage instructions. Each instruction is
    # its own <p>, because blank lines inside a single <p> do not render as
    # line breaks in HTML.
    gr.HTML(
        """
            <div style="text-align: center; max-width: 650px; margin: 50px auto;">
              <div>
                <h1 style="font-weight: 900; font-size: 3rem;">
                  Paper to Pokémon
                </h1>
              </div>
              <p style="margin-bottom: 10px; margin-top: 30px; font-size: 94%">
              Generate new Pokémon from an arXiv link. Just paste the link to the overview, the pdf or just give the ID of the paper.
              </p>
              <p style="margin-bottom: 10px; font-size: 94%">
              It will create a prompt with the paper title, which you can then modify as you like or submit as it is.
              </p>
              <p style="margin-bottom: 10px; font-size: 94%">
              For generally better quality, increase the number of steps. (This will also increase the processing time.)
              </p>
            </div>
        """
    )
    with gr.Group():
        # Link/ID input with the generate button on the same row.
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Link or ID for paper",
                    show_label=False,
                    max_lines=1,
                    placeholder="Give arXiv link or ID for the paper",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        # Type selector: "None", "Random" or one of the concrete types.
        poke_type = gr.Radio(choices=type_choices, value="None", label="Pokemon Type")
        
        # Optional phrases that the prompt callbacks append to the prompt.
        prompt_ideas = gr.CheckboxGroup(choices=["as a bird", 
                                                 "with four legs", 
                                                 "with wings", 
                                                 "as a koala", 
                                                 "with a beak", 
                                                 "looking like a llama"],
                                        label="Additional prompt ideas")
        
        # Editable prompt: written by the link/type callbacks, read by the
        # generate button's infer() call.
        prompt_box = gr.Textbox(placeholder="Your prompt appears here", interactive=True, label="Prompt")

        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")


        # Generation settings passed to infer() by the generate button.
        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
            steps = gr.Slider(label="Steps", minimum=5, maximum=50, value=25, step=5)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )


        # Example rows fill [text, samples, scale]; with cache_examples=False
        # they presumably only populate those inputs without calling fn.
        # NOTE(review): fn=infer does not match these inputs (infer expects
        # prompt, n_samples, steps, scale), so enabling caching or
        # run-on-click here would break — confirm before changing either.
        ex = gr.Examples(examples=examples, fn=infer, inputs=[text, samples, scale], outputs=gallery, cache_examples=False)
        # Presumably blanks the example table's column header — confirm.
        ex.dataset.headers = [""]

        def resolve_poke_type(pok_type: str):
            """Map the Pokémon-type radio selection to the text used in the prompt.

            Args:
                pok_type: The selected choice: "None", "Random" or a concrete
                    type name. A missing selection (None or "") is treated
                    like "None".

            Returns:
                "" for "None" / no selection, a uniformly random entry of
                ``pokemon_types`` for "Random", otherwise *pok_type* unchanged.
            """
            if not pok_type or pok_type == "None":
                return ""
            if pok_type == "Random":
                return random.choice(pokemon_types)
            return pok_type
        
        def update_prompt_link(new_link: str, pok_type: str, prompt_ideas: List[str]):
            """Look up the paper behind *new_link* and build a fresh prompt for it.

            Used as the handler for the link textbox's change/submit events.

            Args:
                new_link: arXiv link or ID currently in the textbox.
                pok_type: Selected Pokémon-type choice (see resolve_poke_type).
                prompt_ideas: Extra prompt phrases ticked by the user.

            Returns:
                The prompt text, or "" when the textbox is blank.
            """
            global paper_name
            # A cleared textbox still triggers this handler; there is nothing
            # to look up, so reset the prompt instead of querying arXiv.
            if not new_link or not new_link.strip():
                paper_name = None
                return ""
            paper_name = get_paper_name(new_link)
            return build_prompt_text(paper_name, resolve_poke_type(pok_type), prompt_ideas)

        def update_prompt_type(paper_link: str, pok_type: str, prompt_ideas: List[str]):
            """Rebuild the prompt after the Pokémon type or prompt ideas change.

            Args:
                paper_link: arXiv link or ID currently in the textbox.
                pok_type: Selected Pokémon-type choice (see resolve_poke_type).
                prompt_ideas: Extra prompt phrases ticked by the user.

            Returns:
                The prompt text, or "" when no link has been entered.
            """
            global paper_name
            if not paper_link or not paper_link.strip():
                return ""
            # Resolve the title from the link this session actually shows.
            # The module-level paper_name is one variable shared by every
            # concurrent user of the app, so reusing it could splice another
            # user's paper (or a stale title) into this prompt.
            paper_name = get_paper_name(paper_link)
            return build_prompt_text(paper_name, resolve_poke_type(pok_type), prompt_ideas)
        
        def build_prompt_text(paper_name, pok_type, add_ideas):
            """Assemble the generation prompt from its parts.

            Args:
                paper_name: Paper title that forms the subject of the prompt.
                pok_type: Resolved Pokémon type, or "" to omit the type clause.
                add_ideas: Extra phrases appended in order; may be empty or None.

            Returns:
                "<paper_name> as <pok_type> type <idea> <idea> ...", with the
                type clause left out when *pok_type* is empty and no trailing
                space when there are no ideas.
            """
            subject = f"{paper_name} as {pok_type} type" if pok_type else f"{paper_name}"
            # Joining the parts (instead of always appending " " + ideas)
            # avoids a dangling space when no idea is selected.
            return " ".join([subject, *(add_ideas or [])])
        
        # Editing or submitting the link re-resolves the paper and rebuilds the prompt.
        for _link_event in (text.change, text.submit):
            _link_event(update_prompt_link, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)

        # Changing the type or the extra ideas rebuilds the prompt via update_prompt_type.
        for _option_event in (poke_type.change, prompt_ideas.change):
            _option_event(update_prompt_type, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)

        # Generate button: run the diffusion pipeline on the current prompt and settings.
        btn.click(infer, inputs=[prompt_box, samples, steps, scale], outputs=gallery)
        # Footer with attribution to the original Text-to-Pokémon space.
        gr.HTML(
            """
                <div class="footer" style="text-align: center; max-width: 650px; margin: 50px auto;">
                    <p>Inspired by and cloned from the great <a href="https://huggingface.co/spaces/lambdalabs/text-to-pokemon">
                    Text-to-Pokémon</a> space by Lambda labs</p>
                    <p> Gradio Demo by johko</p>
               </div>
           """
        )

# Start the Gradio app; this runs at module level, so importing the file
# also launches the server.
block.launch()