import functools
import os
import sys
import re
from typing import List, Optional, Tuple, Union
import random

# Make the StyleGAN helper modules (dnnlib, legacy) importable.
sys.path.append('stylegan3-fun')  # change this to the path where dnnlib is located

import numpy as np
import PIL.Image
import torch
import streamlit as st

import dnnlib
import legacy


def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

    Ranges are inclusive of both endpoints. Whitespace around each entry is
    ignored and empty entries (e.g. from a trailing comma) are skipped. A list
    argument is returned unchanged.
    '''
    if isinstance(s, list):
        return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        # Strip first: the anchored regex would otherwise reject ' 5-10'.
        p = p.strip()
        if not p:
            continue
        m = range_re.match(p)
        if m:
            ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            ranges.append(int(p))
    return ranges


def make_transform(translate: Tuple[float, float], angle: float) -> np.ndarray:
    '''Build a 3x3 matrix combining a rotation and a translation.

    Args:
        translate: (tx, ty), stored in the last column of the first two rows.
        angle: rotation angle in degrees.

    Returns:
        The array [[c, s, tx], [-s, c, ty], [0, 0, 1]] with c = cos(angle),
        s = sin(angle).
    '''
    m = np.eye(3)
    s = np.sin(angle / 360.0 * np.pi * 2)  # degrees -> radians
    c = np.cos(angle / 360.0 * np.pi * 2)
    m[0][0] = c
    m[0][1] = s
    m[0][2] = translate[0]
    m[1][0] = -s
    m[1][1] = c
    m[1][2] = translate[1]
    return m


def _cache_across_reruns(func):
    '''Memoize func with st.cache_resource when available, else functools.lru_cache.

    Streamlit re-executes this script on every interaction, so a cache that
    only lives in this module's globals would be rebuilt on each rerun.
    '''
    cache_resource = getattr(st, 'cache_resource', None)
    if cache_resource is not None:
        return cache_resource(func)
    # Older Streamlit releases: fall back to an in-process cache.
    return functools.lru_cache(maxsize=None)(func)


@_cache_across_reruns
def _load_generator(network_pkl: str, device_type: str):
    '''Load the 'G_ema' network from network_pkl and move it to device_type.'''
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device(device_type)
    with open(network_pkl, 'rb') as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
    return G


def generate_image(network_pkl: str, seed: int, truncation_psi: float, noise_mode: str,
                   translate: Tuple[float, float], rotate: float,
                   class_idx: Optional[int]) -> PIL.Image.Image:
    '''Generate a single RGB image from the generator stored in network_pkl.

    The network is loaded once per (network_pkl, device) and reused on later calls.
    Runs on CUDA when available, otherwise on CPU.

    Args:
        network_pkl: path to the network pickle; its 'G_ema' entry is used.
        seed: seed for the latent vector z (numpy RandomState).
        truncation_psi: passed through to the generator.
        noise_mode: passed through to the generator.
        translate: (tx, ty) for the synthesis input transform, if the network has one.
        rotate: rotation in degrees for the synthesis input transform.
        class_idx: class label index; required when the network is conditional.

    Returns:
        The first generated image as a PIL 'RGB' image.

    Raises:
        ValueError: if the network is conditional (G.c_dim != 0) and class_idx is None.
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    G = _load_generator(network_pkl, device.type)

    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise ValueError('Must specify class label when using a conditional network')
        label[:, class_idx] = 1

    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)

    if hasattr(G.synthesis, 'input'):
        # The inverse of the transform is written into the network, following
        # StyleGAN3's gen_images.py convention (assumed; confirm for this model).
        m = np.linalg.inv(make_transform(translate, rotate))
        # NOTE(review): G is cached and therefore shared between sessions; this
        # mutates it right before the forward pass below.
        G.synthesis.input.transform.copy_(torch.from_numpy(m))

    synthesis_kwargs = {'noise_mode': noise_mode}
    if device.type != 'cuda':
        # Half-precision layers are not reliably supported on CPU.
        synthesis_kwargs['force_fp32'] = True

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        img = G(z, label, truncation_psi=truncation_psi, **synthesis_kwargs)

    # Map the generator output to uint8 pixels in NHWC layout.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')


def main():
    '''Streamlit entry point: show a button that generates a new image from a random seed.'''
    st.title('Kpop Face Generator')
    st.write('Press the button below to generate a new image:')
    if st.button('Generate'):
        network_pkl = 'kpopGG.pkl'
        seed = random.randint(0, 99999)
        truncation_psi = 0.45
        noise_mode = 'const'
        translate = (0.0, 0.0)
        rotate = 0.0
        class_idx = None
        image = generate_image(network_pkl, seed, truncation_psi, noise_mode,
                               translate, rotate, class_idx)
        st.image(image)


if __name__ == "__main__":
    main()