feng2022 commited on
Commit
ae8b571
1 Parent(s): d603fa9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -34
app.py CHANGED
@@ -52,22 +52,18 @@ def parse_args() -> argparse.Namespace:
52
  action='store_false')
53
  parser.add_argument('--allow-flagging', type=str, default='never')
54
  return parser.parse_args()
55
-
56
- def load_model(file_name: str, path:str,device: torch.device) -> nn.Module:
57
- path = hf_hub_download(f'{path}',
58
- f'{file_name}',
59
- use_auth_token=TOKEN)
60
- with open(path, 'rb') as f:
61
- model = torch.load(f)
62
- model.eval()
63
- model.to(device)
64
- with torch.inference_mode():
65
- z = torch.zeros((1, model.z_dim)).to(device)
66
- label = torch.zeros([1, model.c_dim], device=device)
67
- model(z, label, force_fp32=True)
68
- return model
69
 
70
  def image_create(input_img):
 
 
 
 
 
 
 
 
 
 
71
  device = th.device()
72
  generator = create_generator("stylegan2-ffhq-config-f.pt","feng2022/Time-TravelRephotography_stylegan2-ffhq-config-f",args, device)
73
  latent = torch.randn((1, 512), device=device)
@@ -82,33 +78,29 @@ def main():
82
  #else:
83
  # ini = "False1"
84
  #result = subprocess.check_output(['nvidia-smi'])
85
- #load_model("stylegan2-ffhq-config-f","feng2022/Time-TravelRephotography_stylegan2-ffhq-config-f",device)
86
- args = ProjectorArguments().parse(
87
- args=[str(input_path)],
88
- namespace=Namespace(
89
- spectral_sensitivity=spectral_sensitivity,
90
- encoder_ckpt=f"checkpoint/encoder/checkpoint_{spectral_sensitivity}.pt",
91
- encoder_name=spectral_sensitivity,
92
- #gaussian=gaussian_radius,
93
- log_visual_freq=1000,
94
- input='text',
95
- ))
96
  iface = gr.Interface(
97
  fn=image_create,
98
- inputs='text',
99
- outputs='image',
 
 
 
 
100
  title=TITLE,
101
  description=DESCRIPTION,
102
  article=ARTICLE,
103
- #theme=args.theme,
104
- #allow_flagging=args.allow_flagging,
105
- #live=args.live,
106
  )
107
 
108
- iface.launch()
109
- #enable_queue=args.enable_queue,
110
- #server_port=args.port,
111
- #share=args.share,)
 
 
112
  if __name__ == '__main__':
113
  main()
114
 
 
52
  action='store_false')
53
  parser.add_argument('--allow-flagging', type=str, default='never')
54
  return parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def image_create(input_img):
57
+ args = ProjectorArguments().parse(
58
+ args=[str(input_path)],
59
+ namespace=Namespace(
60
+ spectral_sensitivity=spectral_sensitivity,
61
+ encoder_ckpt=f"checkpoint/encoder/checkpoint_{spectral_sensitivity}.pt",
62
+ encoder_name=spectral_sensitivity,
63
+ #gaussian=gaussian_radius,
64
+ log_visual_freq=1000,
65
+ input='text',
66
+ ))
67
  device = th.device()
68
  generator = create_generator("stylegan2-ffhq-config-f.pt","feng2022/Time-TravelRephotography_stylegan2-ffhq-config-f",args, device)
69
  latent = torch.randn((1, 512), device=device)
 
78
  #else:
79
  # ini = "False1"
80
  #result = subprocess.check_output(['nvidia-smi'])
81
+ args = parse_args()
 
 
 
 
 
 
 
 
 
 
82
  iface = gr.Interface(
83
  fn=image_create,
84
+ [
85
+ gr.inputs.Number(default=0, label='Seed'),
86
+ gr.inputs.Slider(
87
+ 0, 2, step=0.05, default=0.7, label='Truncation psi'),
88
+ ],
89
+ gr.outputs.Image(type='numpy', label='Output'),
90
  title=TITLE,
91
  description=DESCRIPTION,
92
  article=ARTICLE,
93
+ theme=args.theme,
94
+ allow_flagging=args.allow_flagging,
95
+ live=args.live,
96
  )
97
 
98
+ iface.launch(
99
+ enable_queue=args.enable_queue,
100
+ server_port=args.port,
101
+ share=args.share,
102
+ )
103
+
104
  if __name__ == '__main__':
105
  main()
106