# File size: 2,093 Bytes
# 78ae3cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import functools
import pickle
import random
from typing import List

import numpy as np
import streamlit as st
import torch

from huggingface_hub import hf_hub_url, cached_download

# Icon category -> index of the generator's one-hot conditioning label.
# Insertion order matters: it is the order the categories appear in the UI.
ICON_CLASS_MAPPING = dict(
    Fire=8,
    Magic=7,
    Nature=6,
    Lightning=5,
    Ice=4,
    Shadow=3,
    Unholy=2,
    Battle=1,
    Holy=0,
)

# Largest value (inclusive) a randomly drawn seed may take.
MAX_SEED = 100_000_000

st.title("RPG Icon Generator")


def _load_generator():
    """Fetch the icon GAN checkpoint and prepare its generator for inference.

    Returns:
        A ``(generator, device)`` tuple: the ``G_ema`` entry of the checkpoint
        and the ``torch.device`` it runs on (CUDA when available, else CPU).
    """
    checkpoint = cached_download(hf_hub_url("gylleus/rpg-icongen", "icongen-model.pkl"))
    with open(checkpoint, "rb") as f:
        # NOTE(security): pickle.load can execute arbitrary code embedded in the
        # file; only load checkpoints from a repository you trust.
        generator = pickle.load(f)["G_ema"]  # torch.nn.Module

    if torch.cuda.is_available():
        target = torch.device("cuda")
        generator = generator.to(target)
    else:
        target = torch.device("cpu")
        # Force full precision when running on CPU; presumably the model's
        # forward otherwise uses reduced-precision paths — TODO confirm.
        generator.forward = functools.partial(generator.forward, force_fp32=True)
    return generator, target


# Streamlit re-executes this whole script on every widget interaction, so keep
# the loaded model across reruns when a resource cache is available
# (``st.cache_resource`` on newer Streamlit, ``st.experimental_singleton`` on
# older releases). Without either, fall back to loading on every run.
_resource_cache = getattr(st, "cache_resource", None) or getattr(
    st, "experimental_singleton", None
)
if _resource_cache is not None:
    _load_generator = _resource_cache(_load_generator)

G, device = _load_generator()


# Seed of the first image in the current batch; replaced by randomize_seed().
random_seed = 0


def randomize_seed() -> None:
    """Store a fresh random seed in the module-level ``random_seed``.

    The seed is drawn uniformly from ``[0, MAX_SEED]`` (both ends inclusive).
    Nothing is returned; callers read the new value from ``random_seed``.
    """
    global random_seed
    random_seed = random.randint(0, MAX_SEED)


# Pick an initial seed. Streamlit re-executes the script on each interaction,
# so this line also draws a new seed on every rerun.
randomize_seed()


def get_class_id(class_name: str) -> int:
    """Return the generator class index for ``class_name``.

    Names missing from ``ICON_CLASS_MAPPING`` fall back to the "Fire" class
    instead of raising.
    """
    return ICON_CLASS_MAPPING.get(class_name, ICON_CLASS_MAPPING["Fire"])


def generate(
    seed: int,
    class_name: str,
    *,
    truncation_psi: float = 1,
    noise_mode: str = "const",
) -> np.ndarray:
    """Render one icon of ``class_name`` from the latent drawn with ``seed``.

    Args:
        seed: Seed for ``np.random.RandomState``; it fixes the latent vector ``z``.
        class_name: Icon category; unknown names fall back to "Fire" via
            ``get_class_id``.
        truncation_psi: Forwarded unchanged to the generator (default 1).
        noise_mode: Forwarded unchanged to the generator (default ``"const"``).

    Returns:
        A ``uint8`` NumPy array in channels-last layout (from
        ``permute(0, 2, 3, 1)``); batch size presumably 1 since ``z`` has a
        single row — TODO confirm against the checkpoint.
    """
    # One-hot conditioning label selecting the requested icon class.
    label = torch.zeros([1, G.c_dim], device=device)
    label[:, get_class_id(class_name)] = 1

    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
    # Pure inference: no autograd graph is needed, so skip building one.
    with torch.no_grad():
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
    # Rescale (presumably from about [-1, 1]) into [0, 255] and move channels last.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()


def generate_images(seed: int, amount: int, class_name: str) -> List[np.ndarray]:
    """Render ``amount`` icons of ``class_name`` from consecutive seeds.

    Image ``i`` (0-based) is produced with seed ``seed + i``; a non-positive
    ``amount`` yields an empty list.
    """
    return [generate(seed + offset, class_name) for offset in range(amount)]


# ``on_click`` takes a callable that Streamlit invokes when the button is
# pressed, so pass the function itself rather than the result of calling it.
st.button("Generate", on_click=randomize_seed)

chosen_class = st.selectbox("Choose icon type", tuple(ICON_CLASS_MAPPING))

image_amount = st.slider("Images to generate", 1, 9, 3)

# Fill a fixed three-column grid row by row with the generated icons.
columns = st.columns(3)
images = generate_images(random_seed, image_amount, chosen_class)
for index, img in enumerate(images):
    columns[index % len(columns)].image(img)