cocktailpeanut committed on
Commit
57e2fd5
1 Parent(s): 0c47cec
Files changed (2)
  1. app.py +13 -6
  2. requirements.txt +6 -6
app.py CHANGED
@@ -22,6 +22,13 @@ magic_adapter_s_path = "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s
 magic_adapter_t_path = "./ckpts/Magic_Weights/magic_adapter_t"
 magic_text_encoder_path = "./ckpts/Magic_Weights/magic_text_encoder"
 
+if torch.cuda.is_available():
+    device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
+else:
+    device = "cpu"
+
 css = """
 .toolbutton {
     margin-buttom: 0em 0em 0em 0em;
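The heart of the commit is this device-selection block: instead of assuming CUDA, the app now probes for an NVIDIA GPU, then Apple Silicon's Metal backend, then falls back to CPU. A minimal standalone sketch of the same fallback (the pick_device helper name is illustrative, not part of the commit; the hasattr guard is an extra assumption for PyTorch builds older than 1.12, where torch.backends.mps does not exist):

    import torch

    def pick_device() -> str:
        # Prefer NVIDIA GPUs, then Apple Silicon (MPS/Metal), then CPU.
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    device = pick_device()
    x = torch.ones(2, 2, device=device)  # allocated on whichever backend was found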
@@ -87,9 +94,9 @@ class MagicTimeController:
         self.inference_config = OmegaConf.load(inference_config_path)[1]
 
         self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
-        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").cuda()
-        self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
-        self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
+        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
+        self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
+        self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).to(device)
         self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
 
         # self.tokenizer = tokenizer
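Swapping .cuda() for .to(device) is what makes the weight loading portable: .cuda() raises immediately on a machine without an NVIDIA GPU, while .to(device) accepts whichever backend string was chosen above. A hedged sketch of the pattern with a stock Hugging Face checkpoint (the id matches the text_model line above; the surrounding variable names are illustrative):

    import torch
    from transformers import CLIPTextModel

    device = "cuda" if torch.cuda.is_available() else "cpu"  # simplified fallback

    # .cuda() would hard-fail on CPU-only or Apple Silicon machines;
    # .to(device) places the weights on whatever backend is available.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    print(next(text_encoder.parameters()).device)

The same substitution is applied to the assembled MagicTimePipeline in the next hunk.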
@@ -162,7 +169,7 @@ class MagicTimeController:
         pipeline = MagicTimePipeline(
             vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
             scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
-        ).to("cuda")
+        ).to(device)
 
         if int(seed_textbox) > 0: seed = int(seed_textbox)
         else: seed = random.randint(1, 1e16)
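Calling .to(device) on the assembled pipeline relocates all of its registered submodules (vae, text_encoder, unet) in one step. A sketch with a generic diffusers pipeline standing in for MagicTimePipeline, which is local to this repo (the model id is only illustrative):

    import torch
    from diffusers import DiffusionPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # One .to(device) call moves every module the pipeline registered.
    pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipe = pipe.to(device)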
@@ -171,7 +178,7 @@ class MagicTimeController:
         assert seed == torch.initial_seed()
         print(f"### seed: {seed}")
 
-        generator = torch.Generator(device="cuda")
+        generator = torch.Generator(device=device)
         generator.manual_seed(seed)
 
         sample = pipeline(
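The seeded generator has to live on the same device as the tensors the sampler draws, otherwise pipeline calls that take a generator= argument can fail with a device mismatch. A sketch of the seeded-generator pattern (shape and seed are arbitrary; as an aside not covered by this commit, "mps" generators need a fairly recent PyTorch, and falling back to a CPU generator is a common workaround):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed = 42

    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    # Reproducible noise for the diffusion sampler:
    noise = torch.randn(1, 4, 64, 64, generator=generator, device=device)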
@@ -256,4 +263,4 @@ def ui():
 if __name__ == "__main__":
     demo = ui()
     demo.queue(max_size=20)
-    demo.launch()
+    demo.launch()
 
requirements.txt CHANGED
@@ -1,17 +1,17 @@
-torch==2.2.2
-torchvision==0.17.2
-torchaudio==2.2.2
-xformers==0.0.25.post1
+#torch==2.2.2
+#torchvision==0.17.2
+#torchaudio==2.2.2
+#xformers==0.0.25.post1
 imageio==2.27.0
 gdown
 einops
 omegaconf
 safetensors
 gradio
-triton
+#triton
 imageio[ffmpeg]
 imageio[pyav]
 ms-swift
 accelerate==0.28.0
 diffusers==0.11.1
-transformers==4.38.2
+transformers==4.38.2
 
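With the torch, torchvision, torchaudio, and xformers pins commented out (and triton, which only ships wheels for Linux/CUDA), the user is now expected to install a PyTorch build matching their own platform before launching the app. A quick sanity check after installing, to confirm which backend the device-selection block in app.py will pick (the hasattr guard is an assumption for older PyTorch builds without the mps backend):

    import torch

    print(torch.__version__)
    print("cuda available:", torch.cuda.is_available())
    print("mps available:", hasattr(torch.backends, "mps")
          and torch.backends.mps.is_available())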