File size: 2,199 Bytes
40ce629
 
 
 
 
 
f57ed6a
40ce629
f57ed6a
119d5c2
f57ed6a
40ce629
f57ed6a
98a2239
40ce629
38a5e47
cc52c45
 
c405107
51f1e70
0025f06
 
 
18f9c41
 
0025f06
40ce629
092a462
fe030b4
b703853
 
40ce629
27f8154
40ce629
39868fe
28e8fb4
fd26580
ae8b571
 
 
 
9d9b9ef
ed89585
3c99a0a
b74c99a
 
3c99a0a
5266bb1
 
3c99a0a
c80e976
40e259d
b954cf5
 
 
 
 
 
b0dd76b
5266bb1
18cc678
2ed7e26
 
 
 
 
 
b42b8ed
2ed7e26
 
 
40e259d
7335ad6
ae8b571
40e259d
78f6e98
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python

from __future__ import annotations

import argparse
import functools
import os
import pickle
import sys
import subprocess

import gradio as gr
import numpy as np
import torch
import torch.nn as nn

sys.path.append('.')
sys.path.append('./Time_TravelRephotography')
from utils import torch_helpers as th
from argparse import Namespace
from projector import (
    ProjectorArguments,
    main,
    create_generator,
    make_image,
)

# Values consumed by image_create() when it builds ProjectorArguments:
# ``input_path`` is passed as the positional argument, and
# ``spectral_sensitivity`` is the suffix that picks the encoder checkpoint
# file ``checkpoint/encoder/checkpoint_<suffix>.pt``.
input_path = ""
spectral_sensitivity = "b"

# Title, description and footer HTML passed to the Gradio interface in main().
TITLE = "Time-TravelRephotography"
DESCRIPTION = (
    "This is an unofficial demo for "
    "https://github.com/Time-Travel-Rephotography.\n"
)
ARTICLE = (
    '<center><img src="https://visitor-badge.glitch.me/badge'
    '?page_id=Time-TravelRephotography" alt="visitor badge"/></center>'
)

   
@functools.lru_cache(maxsize=1)
def _load_generator():
    """Build the StyleGAN2 generator once and reuse it for every request.

    Returns:
        A ``(generator, device)`` tuple. The generator comes from
        ``create_generator`` using the checkpoint file name and repo id below
        (presumably a Hugging Face Hub repo id -- how it is fetched/loaded is
        defined in ``projector`` and not visible here).
    """
    args = ProjectorArguments().parse(
        args=[str(input_path)],
        namespace=Namespace(
            encoder_ckpt=f"checkpoint/encoder/checkpoint_{spectral_sensitivity}.pt",
            log_visual_freq=1000,
        ),
    )
    device = th.device()
    generator = create_generator(
        "stylegan2-ffhq-config-f.pt",
        "feng2022/Time-TravelRephotography_stylegan2-ffhq-config-f",
        args,
        device,
    )
    return generator, device


def image_create(seed: int, truncation_psi: float):
    """Sample one image from the StyleGAN2 generator for a given seed.

    Args:
        seed: Seed for the 512-dim latent vector; the same seed gives the same
            latent. Gradio's Number input may deliver a float (or ``None`` when
            the field is cleared), so it is coerced to ``int`` with ``None``
            treated as 0.
        truncation_psi: Currently unused. NOTE(review): the generator call
            below does not apply truncation; wire this in once the project
            Generator's truncation arguments are confirmed.

    Returns:
        The first image returned by ``make_image`` divided by 255
        (presumably an HxWxC array in [0, 1] -- confirm against make_image).
    """
    generator, device = _load_generator()
    seed = 0 if seed is None else int(seed)
    # Draw the latent from a locally-seeded CPU RNG so the sample depends
    # only on ``seed`` (not on global RNG state or the device), then move it
    # to wherever the generator lives.
    rng = torch.Generator().manual_seed(seed)
    latent = torch.randn((1, 512), generator=rng).to(device)
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        img_out, _, _ = generator([latent])
    imgs_arr = make_image(img_out)
    return imgs_arr[0] / 255
    
def main():
    """Build and launch the Gradio demo around ``image_create``.

    The interface takes a seed (Number) and a truncation psi (Slider, 0-2)
    and shows the returned NumPy image.

    NOTE: this definition shadows the ``main`` imported from ``projector``;
    the ``__main__`` guard below calls this one.
    """
    # Initialise CUDA eagerly only when a GPU is actually available, so the
    # demo can still start (and run on CPU) on hosts without CUDA.
    if torch.cuda.is_available():
        torch.cuda.init()
    # NOTE(review): gr.inputs / gr.outputs are the legacy Gradio namespaces
    # (removed in Gradio 4); this pins the app to an older gradio release.
    iface = gr.Interface(
        image_create,
        [
            gr.inputs.Number(default=0, label='Seed'),
            gr.inputs.Slider(
                0, 2, step=0.05, default=0.7, label='Truncation psi'),
        ],
        gr.outputs.Image(type='numpy', label='Output'),
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
    )
    iface.launch()


if __name__ == '__main__':
    main()