# Author: rossellison — commit 8a860df ("Upload 159 files").
import functools
import os
import random
import re
import sys
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import streamlit as st

sys.path.append('stylegan3-fun') # change this to the path where dnnlib is located
import dnnlib
import legacy
def parse_range(s: Union[str, List]) -> List[int]:
    """Parse a comma-separated list of numbers or inclusive ranges into ints.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10].

    Args:
        s: A string such as '1,2,5-10', or a list. A list is returned
            unchanged, without validation.

    Returns:
        The expanded ints, in the order their items appear in ``s``.

    Raises:
        ValueError: If an item is neither an integer nor an ``a-b`` range.
    """
    if isinstance(s, list):
        return s
    range_re = re.compile(r'^(\d+)-(\d+)$')
    ranges = []
    for p in s.split(','):
        # Strip whitespace so inputs like '1, 5-10' parse; otherwise ' 5-10'
        # would miss the anchored regex and then fail in int().
        p = p.strip()
        m = range_re.match(p)
        if m:
            # Both endpoints are inclusive, hence the +1.
            ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            ranges.append(int(p))
    return ranges
def make_transform(translate: Tuple[float,float], angle: float):
    """Return a 3x3 homogeneous matrix combining a 2D rotation and translation.

    ``angle`` is given in degrees. The top two rows are
    ``[[cos, sin, tx], [-sin, cos, ty]]``; the bottom row stays ``[0, 0, 1]``.
    """
    theta = angle / 360.0 * np.pi * 2
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    mat = np.eye(3)
    # Overwrite only the affine part; the identity's last row is kept.
    mat[:2, :] = [
        [cos_t, sin_t, translate[0]],
        [-sin_t, cos_t, translate[1]],
    ]
    return mat
@functools.lru_cache(maxsize=4)
def _load_generator(network_pkl: str, device: torch.device):
    """Load the ``G_ema`` network from ``network_pkl`` onto ``device``, cached.

    Caching avoids re-reading and unpickling the file on every Streamlit rerun.
    The returned module is shared between calls; ``generate_image`` rewrites
    the mutable state it depends on (the input transform) each time.
    """
    print(f'Loading networks from "{network_pkl}"...')
    # NOTE(review): load_network_pkl unpickles the file — only load trusted pickles.
    with open(network_pkl, 'rb') as f:
        return legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore


def generate_image(network_pkl: str, seed: int, truncation_psi: float, noise_mode: str, translate: Tuple[float,float], rotate: float, class_idx: Optional[int]):
    """Generate a single image from the ``G_ema`` network stored in ``network_pkl``.

    Args:
        network_pkl: Path to the network pickle (loaded once per device, then cached).
        seed: Seed for ``np.random.RandomState`` that draws the latent ``z``.
        truncation_psi: Passed through to the generator call.
        noise_mode: Passed through to the generator call (e.g. 'const').
        translate: (x, y) translation, applied only if ``G.synthesis`` has an
            ``input`` layer.
        rotate: Rotation in degrees, applied under the same condition.
        class_idx: Class label index; required when the network is
            conditional (``G.c_dim != 0``), ignored otherwise.

    Returns:
        A ``PIL.Image`` in 'RGB' mode.

    Raises:
        ValueError: If the network is conditional and ``class_idx`` is None.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    G = _load_generator(network_pkl, device)

    # One-hot class label; stays all-zero (width 0) for unconditional networks.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise ValueError('Must specify class label when using a conditional network')
        label[:, class_idx] = 1

    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)

    if hasattr(G.synthesis, 'input'):
        # The inverted matrix is written into the synthesis input's transform
        # buffer on every call, so cached networks never keep a stale transform.
        m = np.linalg.inv(make_transform(translate, rotate))
        G.synthesis.input.transform.copy_(torch.from_numpy(m))

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
    # NCHW -> NHWC, then rescale to uint8 (assumes output roughly in [-1, 1] — confirm).
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
def main():
    """Render the Streamlit page.

    On each 'Generate' click, draw a random seed and display one image sampled
    from kpopGG.pkl (path resolved relative to the working directory).
    """
    st.title('Kpop Face Generator')
    st.write('Press the button below to generate a new image:')
    if not st.button('Generate'):
        return
    # Fixed sampling settings; only the seed varies between clicks.
    image = generate_image(
        network_pkl='kpopGG.pkl',
        seed=random.randint(0, 99999),
        truncation_psi=0.45,
        noise_mode='const',
        translate=(0.0, 0.0),
        rotate=0.0,
        class_idx=None,
    )
    st.image(image)


if __name__ == "__main__":
    main()