rinong commited on
Commit
5479d05
1 Parent(s): db3750b

Updated to use arbitrary model paths

Browse files
Files changed (1) hide show
  1. app.py +17 -18
app.py CHANGED
@@ -31,25 +31,22 @@ from generate_videos import generate_frames, video_from_interpolations, vid_to_g
31
  model_dir = "models"
32
  os.makedirs(model_dir, exist_ok=True)
33
 
34
- models_and_paths = {"akhaliq/JoJoGAN_e4e_ffhq_encode": "e4e_ffhq_encode.pt",
35
- "akhaliq/jojogan_dlib": "shape_predictor_68_face_landmarks.dat",
36
- "akhaliq/jojogan-stylegan2-ffhq-config-f": "stylegan2-ffhq-config-f.pt"}
37
 
38
  def get_models():
39
  os.makedirs(model_dir, exist_ok=True)
40
 
41
- for repo_id, file_path in models_and_paths.items():
42
- hf_hub_download(repo_id=repo_id, filename=file_path)
43
- if not "akhaliq" in repo_id:
44
- shutil.move(file_path, os.path.join(model_dir, file_path))
45
- elif "stylegan2" in file_path:
46
- shutil.move(file_path, os.path.join(model_dir, "base.pt"))
47
 
48
- model_list = [Path(model_ckpt).stem for model_ckpt in os.listdir(model_dir)]
 
 
49
 
50
- return model_list
51
 
52
- model_list = get_models()
53
 
54
  class ImageEditor(object):
55
  def __init__(self):
@@ -62,18 +59,20 @@ class ImageEditor(object):
62
 
63
  self.generators = {}
64
 
65
- for model in model_list:
 
 
66
  g_ema = Generator(
67
  model_size, latent_size, n_mlp, channel_multiplier=channel_mult
68
  ).to(self.device)
69
 
70
- checkpoint = torch.load(f"models/{model}.pt")
71
 
72
  g_ema.load_state_dict(checkpoint['g_ema'])
73
 
74
  self.generators[model] = g_ema
75
 
76
- self.experiment_args = {"model_path": "e4e_ffhq_encode.pt"}
77
  self.experiment_args["transform"] = transforms.Compose(
78
  [
79
  transforms.Resize((256, 256)),
@@ -96,7 +95,7 @@ class ImageEditor(object):
96
  self.e4e_net.cuda()
97
 
98
  self.shape_predictor = dlib.shape_predictor(
99
- models_and_paths["akhaliq/jojogan_dlib"]
100
  )
101
 
102
  print("setup complete")
@@ -120,11 +119,11 @@ class ImageEditor(object):
120
  ):
121
 
122
  if output_style == 'all':
123
- styles = model_list
124
  elif output_style == 'list - enter below':
125
  styles = style_list.split(",")
126
  for style in styles:
127
- if style not in model_list:
128
  raise ValueError(f"Encountered style '{style}' in the style_list which is not an available option.")
129
  else:
130
  styles = [output_style]
 
31
  model_dir = "models"
32
  os.makedirs(model_dir, exist_ok=True)
33
 
34
+ model_repos = {"e4e": ("akhaliq/JoJoGAN_e4e_ffhq_encode", "e4e_ffhq_encode.pt"),
35
+ "dlib": ("akhaliq/jojogan_dlib", "shape_predictor_68_face_landmarks.dat"),
36
+ "base": ("akhaliq/jojogan-stylegan2-ffhq-config-f", "stylegan2-ffhq-config-f.pt")}
37
 
38
  def get_models():
39
  os.makedirs(model_dir, exist_ok=True)
40
 
41
+ model_paths = {}
 
 
 
 
 
42
 
43
+ for model_name, repo_details in model_repos.items():
44
+ download_path = hf_hub_download(repo_id=repo_details[0], filename=repo_details[1])
45
+ model_paths[model_name] = download_path
46
 
47
+ return model_paths
48
 
49
+ model_paths = get_models()
50
 
51
  class ImageEditor(object):
52
  def __init__(self):
 
59
 
60
  self.generators = {}
61
 
62
+ self.model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib"]]
63
+
64
+ for model in self.model_list:
65
  g_ema = Generator(
66
  model_size, latent_size, n_mlp, channel_multiplier=channel_mult
67
  ).to(self.device)
68
 
69
+ checkpoint = torch.load(model_paths[model])
70
 
71
  g_ema.load_state_dict(checkpoint['g_ema'])
72
 
73
  self.generators[model] = g_ema
74
 
75
+ self.experiment_args = {"model_path": model_paths["e4e"]}
76
  self.experiment_args["transform"] = transforms.Compose(
77
  [
78
  transforms.Resize((256, 256)),
 
95
  self.e4e_net.cuda()
96
 
97
  self.shape_predictor = dlib.shape_predictor(
98
+ model_paths["dlib"]
99
  )
100
 
101
  print("setup complete")
 
119
  ):
120
 
121
  if output_style == 'all':
122
+ styles = self.model_list
123
  elif output_style == 'list - enter below':
124
  styles = style_list.split(",")
125
  for style in styles:
126
+ if style not in self.model_list:
127
  raise ValueError(f"Encountered style '{style}' in the style_list which is not an available option.")
128
  else:
129
  styles = [output_style]