# system

import os
from pathlib import Path

# download the pre-trained checkpoints from Google Drive if they are not already present
if not Path('./Text2Punk-final-7.pt').exists() or not Path('./clip-final.pt').exists():
    os.system("gdown https://drive.google.com/uc?id=1--27E5dk8GzgvpVL0ofr-m631iymBpUH")
    os.system("gdown https://drive.google.com/uc?id=191a5lTsUPQ1hXaeo6kVNbo_W3WYuXsmF")

# plot

from PIL import Image

# gradio

import gradio as gr

# text2punks utils

from text2punks.utils import resize, to_pil_image, model_loader, generate_image

# knobs to tune

top_k = 0.8
temperature = 1.25
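# sampling settings passed straight through to generate_image; a higher temperature (and a
# looser top_k cutoff) should give more varied punks at the cost of fidelity -- the exact
# semantics depend on text2punks.utils.generate_image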

# helper functions

def compose_predictions(images):

    # lay the generated punks out side by side in a single-row grid
    increased_h = 0
    b, c, h, w = images.shape
    image_grid = Image.new("RGB", (b*w*4, h*4), color=0)

    for i in range(b):
        # upscale each punk (presumably 24x24) to 96x96 before pasting it into the grid
        img_ = to_pil_image(resize(images[i], 96))
        image_grid.paste(img_, (i*w*4, increased_h))

    return image_grid


def run_inference(prompt, num_images=32, batch_size=32, num_preds=8):

    # load the Text2Punk generator and the CLIP model used to pick the top predictions
    t2p_path, clip_path = './Text2Punk-final-7.pt', './clip-final.pt'
    text2punk, clip = model_loader(t2p_path, clip_path)

    # sample candidate punks and keep the top `num_preds` predictions
    images, _ = generate_image(prompt_text=prompt, top_k=top_k, temperature=temperature, num_images=num_images, batch_size=batch_size, top_prediction=num_preds, text2punk_model=text2punk, clip_model=clip)
    predictions = compose_predictions(images)

    output_title = f"""
    <b>{prompt}</b>
    """

    return (output_title, predictions)
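
# minimal local sanity check (hypothetical usage, not part of the app; assumes the
# checkpoints above have already been downloaded):
#
#   title_html, grid = run_inference("An Ape CryptoPunk that has 2 Attributes, a Pigtails and a Medical Mask.",
#                                    num_images=8, batch_size=8, num_preds=4)
#   grid.save("preview.png")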


outputs = [
    gr.outputs.HTML(label="Output Title"),      # To be used as title
    gr.outputs.Image(type="pil", label="Output Image"),
]

description = """
Text2Cryptopunks is an AI model that generates Cryptopunks images from text prompt:
"""

gr.Interface(
    fn=run_inference, 
    inputs=[gr.inputs.Textbox(label='Type something like this: "An Ape CryptoPunk that has 2 Attributes, a Pigtails and a Medical Mask."')],
    outputs=outputs,
    title='Text2Cryptopunks',
    description=description,
    article="<p style='text-align: center'> Created by kTonpa | <a href='https://github.com/kTonpa/Text2CryptoPunks'>GitHub</a>",
    layout='vertical',
    theme='huggingface',
    examples=[['Cute Alien cryptopunk that has 2 Attributes, a Pipe, and a Beanie.'], ['A low resolution photo of punky-looking Ape that has 2 Attributes, a Beanie, and a Medical Mask.']],
    # allow_flagging=False,
    # live=False,
    # verbose=True,
).launch(enable_queue=True)