FrozenBurning commited on
Commit
8b9a803
1 Parent(s): 6faf56d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -6
app.py CHANGED
@@ -9,7 +9,7 @@ from tqdm import tqdm
9
 
10
  os.system("git clone https://github.com/FrozenBurning/SceneDreamer.git")
11
  sys.path.append("SceneDreamer")
12
- os.system("bash install.sh")
13
  pretrained_model = dict(file_url='https://drive.google.com/uc?id=1IFu1vNrgF1EaRqPizyEgN_5Vt7Fyg0Mj',
14
  alt_url='', file_size=330571863,
15
  file_path='./scenedreamer_released.pt',)
@@ -88,15 +88,27 @@ with requests.Session() as session:
88
 
89
  import os
90
  import torch
 
 
91
  import argparse
92
  from imaginaire.config import Config
93
  from imaginaire.utils.cudnn import init_cudnn
94
- from imaginaire.utils.io import get_checkpoint as get_checkpoint
95
- from imaginaire.utils.trainer import \
96
- (get_model_optimizer_and_scheduler, set_random_seed)
97
  import gradio as gr
98
  from PIL import Image
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def parse_args():
101
  parser = argparse.ArgumentParser(description='Training')
102
  parser.add_argument('--config', type=str, default='./configs/scenedreamer_inference.yaml', help='Path to the training config file.')
@@ -111,14 +123,17 @@ def parse_args():
111
 
112
 
113
  args = parse_args()
114
- set_random_seed(args.seed, by_rank=False)
115
  cfg = Config(args.config)
116
 
117
  # Initialize cudnn.
118
  init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)
119
 
120
  # Initialize data loaders and models.
121
- net_G = get_model_optimizer_and_scheduler(cfg, seed=args.seed, generator_only=True)
 
 
 
 
122
 
123
  if args.checkpoint == '':
124
  raise NotImplementedError("No checkpoint is provided for inference!")
 
9
 
10
  os.system("git clone https://github.com/FrozenBurning/SceneDreamer.git")
11
  sys.path.append("SceneDreamer")
12
+
13
  pretrained_model = dict(file_url='https://drive.google.com/uc?id=1IFu1vNrgF1EaRqPizyEgN_5Vt7Fyg0Mj',
14
  alt_url='', file_size=330571863,
15
  file_path='./scenedreamer_released.pt',)
 
88
 
89
  import os
90
  import torch
91
+ import torch.nn as nn
92
+ import importlib
93
  import argparse
94
  from imaginaire.config import Config
95
  from imaginaire.utils.cudnn import init_cudnn
 
 
 
96
  import gradio as gr
97
  from PIL import Image
98
 
99
+
100
+ class WrappedModel(nn.Module):
101
+ r"""Dummy wrapping the module.
102
+ """
103
+
104
+ def __init__(self, module):
105
+ super(WrappedModel, self).__init__()
106
+ self.module = module
107
+
108
+ def forward(self, *args, **kwargs):
109
+ r"""PyTorch module forward function overload."""
110
+ return self.module(*args, **kwargs)
111
+
112
  def parse_args():
113
  parser = argparse.ArgumentParser(description='Training')
114
  parser.add_argument('--config', type=str, default='./configs/scenedreamer_inference.yaml', help='Path to the training config file.')
 
123
 
124
 
125
  args = parse_args()
 
126
  cfg = Config(args.config)
127
 
128
  # Initialize cudnn.
129
  init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)
130
 
131
  # Initialize data loaders and models.
132
+
133
+ lib_G = importlib.import_module(cfg.gen.type)
134
+ net_G = lib_G.Generator(cfg.gen, cfg.data)
135
+ net_G = net_G.to('cuda')
136
+ net_G = WrappedModel(net_G)
137
 
138
  if args.checkpoint == '':
139
  raise NotImplementedError("No checkpoint is provided for inference!")