Spaces:
Runtime error
Runtime error
import functools
import os
import random
import re
import sys
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import streamlit as st

sys.path.append('stylegan3-fun') # change this to the path where dnnlib is located
import dnnlib
import legacy
def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Ranges are written ``a-b`` and are inclusive on both ends. Whitespace
    around each comma-separated item is ignored.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

    Args:
        s: A string such as ``'1,2,5-10'``, or an already-parsed list, which
            is returned unchanged.

    Returns:
        The expanded integers, in the order they appear in ``s``.

    Raises:
        ValueError: If an item is neither an integer nor an ``a-b`` range.
    '''
    if isinstance(s, list):
        return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        # Strip so that inputs like '1, 5-10' work; the regex is anchored.
        p = p.strip()
        m = range_re.match(p)
        if m:
            # +1 makes the upper bound inclusive.
            ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            ranges.append(int(p))
    return ranges
def make_transform(translate: Tuple[float,float], angle: float):
    '''Build a 3x3 homogeneous 2-D rotation-plus-translation matrix.

    Args:
        translate: Offsets; ``translate[0]`` and ``translate[1]`` fill the
            last column of the first two rows.
        angle: Rotation angle in degrees.

    Returns:
        A (3, 3) float64 array laid out as
        ``[[cos, sin, tx], [-sin, cos, ty], [0, 0, 1]]``.
    '''
    theta = angle / 360.0 * np.pi * 2  # degrees -> radians
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    tx = translate[0]
    ty = translate[1]
    return np.array(
        [
            [cos_t, sin_t, tx],
            [-sin_t, cos_t, ty],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float64,
    )
@functools.lru_cache(maxsize=4)
def _load_generator(network_pkl: str, device_type: str):
    '''Load the ``G_ema`` network from ``network_pkl`` onto ``device_type`` (memoised).

    Unpickling the network is slow, so the loaded module is cached per
    ``(network_pkl, device_type)`` instead of being reloaded on every call.
    SECURITY: loading a pickle can execute arbitrary code; only point this at
    trusted files.
    '''
    print('Loading networks from "%s"...' % network_pkl)
    with open(network_pkl, 'rb') as f:
        return legacy.load_network_pkl(f)['G_ema'].to(torch.device(device_type))  # type: ignore


def generate_image(network_pkl: str, seed: int, truncation_psi: float, noise_mode: str, translate: Tuple[float,float], rotate: float, class_idx: Optional[int]):
    '''Generate one RGB image from the generator stored in ``network_pkl``.

    Args:
        network_pkl: Path to the network pickle; its ``'G_ema'`` entry is used.
        seed: Seed for the ``np.random.RandomState`` that draws the latent ``z``.
        truncation_psi: Forwarded unchanged to the generator call.
        noise_mode: Forwarded unchanged to the generator call.
        translate: Translation applied only when ``G.synthesis`` has an
            ``input`` attribute.
        rotate: Rotation in degrees, applied under the same condition as
            ``translate``.
        class_idx: Class label index; required when ``G.c_dim != 0``.

    Returns:
        A ``PIL.Image.Image`` in ``'RGB'`` mode built from the first batch item.

    Raises:
        ValueError: If the network is conditional and ``class_idx`` is None.
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): the cached G is shared between calls. The transform buffer
    # below is overwritten on every call, so concurrent calls with different
    # translate/rotate could interfere.
    G = _load_generator(network_pkl, device.type)

    # Labels: one-hot vector, all zeros for unconditional networks.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise ValueError('Must specify class label when using a conditional network')
        label[:, class_idx] = 1

    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)

    with torch.no_grad():  # inference only: no autograd graph needed
        if hasattr(G.synthesis, 'input'):
            # The generator's input transform is the inverse of the requested
            # image-space translation/rotation.
            m = make_transform(translate, rotate)
            m = np.linalg.inv(m)
            G.synthesis.input.transform.copy_(torch.from_numpy(m))
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)

    # NCHW -> NHWC, then map to [0, 255]. This presumably assumes G outputs
    # values roughly in [-1, 1] — confirm against the network.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
def main():
    '''Draw the Streamlit page and, when the button is pressed, show one generated face.

    Each press draws a fresh seed with ``random.randint(0, 99999)`` and renders
    it from ``kpopGG.pkl``. Settings are fixed: truncation 0.45, ``'const'``
    noise, no translation or rotation, and no class label.
    '''
    st.title('Kpop Face Generator')
    st.write('Press the button below to generate a new image:')
    if not st.button('Generate'):
        return
    image = generate_image(
        network_pkl='kpopGG.pkl',
        seed=random.randint(0, 99999),
        truncation_psi=0.45,
        noise_mode='const',
        translate=(0.0, 0.0),
        rotate=0.0,
        class_idx=None,
    )
    st.image(image)
if __name__ == "__main__":
    # Entry point: build the page when this file is executed as a script
    # (e.g. via `streamlit run`).
    main()