File size: 2,093 Bytes
78ae3cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import functools
import pickle
import random
from typing import List

import numpy as np
import streamlit as st
import torch
from huggingface_hub import hf_hub_url, cached_download
from huggingface_hub import hf_hub_download
# Icon type display name -> class index that is set to 1 in the generator's
# one-hot label. Insertion order matters: the select box lists the choices in
# this order, so "Fire" comes first.
ICON_CLASS_MAPPING = dict(
    Fire=8,
    Magic=7,
    Nature=6,
    Lightning=5,
    Ice=4,
    Shadow=3,
    Unholy=2,
    Battle=1,
    Holy=0,
)

# Inclusive upper bound for randomly drawn seeds (used with random.randint).
MAX_SEED = 100_000_000
st.title("RPG Icon Generator")

# Fetch the pretrained generator pickle from the Hugging Face Hub (the Hub
# client reuses its local cache when the file is already present).
# `hf_hub_download` supersedes the deprecated
# `cached_download(hf_hub_url(...))` combination.
model_path = hf_hub_download(
    repo_id="gylleus/rpg-icongen", filename="icongen-model.pkl"
)
# SECURITY NOTE(review): unpickling can execute arbitrary code, so this fully
# trusts whatever is currently published in the "gylleus/rpg-icongen" repo.
# Consider pinning a reviewed commit via `revision=`.
with open(model_path, "rb") as f:
    G = pickle.load(f)["G_ema"]  # torch.nn.Module

# Use the GPU when one is available; otherwise stay on the CPU and force
# full-precision execution of the generator.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
    G = G.to(device)
else:
    # Bind force_fp32=True into every forward call. Presumably the generator
    # would otherwise take reduced-precision paths that are unsuitable on CPU
    # (NOTE(review): confirm against the generator implementation).
    G.forward = functools.partial(G.forward, force_fp32=True)
def randomize_seed() -> int:
    """Draw a new random seed, store it, and return it.

    The seed is written both to the module-level ``random_seed`` (which the
    page reads when generating images) and to ``st.session_state``. Streamlit
    re-executes the whole script on every interaction, so a plain global is
    reset each time. The session-state copy is what lets a seed persist until
    this function is called again (e.g. from the "Generate" button callback).

    Returns:
        The new seed, an int in ``[0, MAX_SEED]`` (``random.randint`` is
        inclusive on both ends).
    """
    global random_seed
    random_seed = random.randint(0, MAX_SEED)
    st.session_state["random_seed"] = random_seed
    return random_seed


# Roll an initial seed only on the first run of a session; later reruns reuse
# the stored seed until randomize_seed() is invoked again.
if "random_seed" not in st.session_state:
    randomize_seed()
random_seed = st.session_state["random_seed"]
def get_class_id(class_name: str):
    """Return the generator class index for the icon type ``class_name``.

    Names that are not keys of ``ICON_CLASS_MAPPING`` fall back to the index
    of the "Fire" class instead of raising.
    """
    try:
        return ICON_CLASS_MAPPING[class_name]
    except KeyError:
        # Unknown icon type: use the "Fire" class as the default.
        return ICON_CLASS_MAPPING["Fire"]
def generate(
    seed: int,
    class_name: str,
    truncation_psi: float = 1,
    noise_mode: str = "const",
) -> np.ndarray:
    """Generate a single icon of the requested class from a seed.

    Args:
        seed: Seed for the NumPy ``RandomState`` that draws the latent vector
            ``z``, so a given seed always yields the same latent input.
        class_name: Icon type name. It is mapped through ``get_class_id``,
            which falls back to the "Fire" class for unknown names.
        truncation_psi: Passed through to the generator. Defaults to 1, the
            value that was previously hard-coded.
        noise_mode: Passed through to the generator. Defaults to "const",
            the value that was previously hard-coded.

    Returns:
        A uint8 NumPy array. The generator output is permuted with
        ``(0, 2, 3, 1)`` (presumably NCHW -> NHWC, i.e. shape (1, H, W, C)),
        scaled by ``* 127.5 + 128`` (assumes outputs roughly in [-1, 1]),
        and clamped to [0, 255].
    """
    # One-hot conditioning label selecting the chosen class.
    label = torch.zeros([1, G.c_dim], device=device)
    label[:, get_class_id(class_name)] = 1

    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)

    # Pure inference: disable autograd so no graph or activations are kept.
    with torch.no_grad():
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)

    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()
def generate_images(seed: int, amount: int, class_name: str) -> List[np.ndarray]:
    """Generate ``amount`` icons of ``class_name`` from consecutive seeds.

    The k-th image uses seed ``seed + k``, so the same starting seed always
    produces the same sequence of per-image seeds. A non-positive ``amount``
    yields an empty list.
    """
    consecutive_seeds = range(seed, seed + amount)
    return [generate(image_seed, class_name) for image_seed in consecutive_seeds]
# Pass the callback itself, not its result, so Streamlit invokes it when the
# button is clicked (calling it here would run it while the page is built and
# hand Streamlit `None` as the callback).
st.button("Generate", on_click=randomize_seed)
chosen_class = st.selectbox("Choose icon type", tuple(ICON_CLASS_MAPPING))
image_amount = st.slider("Images to generate", 1, 9, 3)

# Lay the generated images out row by row across a fixed three-column grid.
columns = st.columns(3)
for column_index, img in enumerate(
    generate_images(random_seed, image_amount, chosen_class)
):
    columns[column_index % len(columns)].image(img)
|