hysts HF staff commited on
Commit
765072b
1 Parent(s): c0a7c3c
Files changed (1) hide show
  1. inference.py +7 -4
inference.py CHANGED
@@ -55,12 +55,15 @@ class InferencePipeline:
55
  if model_id == self.model_id:
56
  return
57
  base_model_id = self.get_base_model_info(model_id, self.hf_token)
58
- unet = UNet3DConditionModel.from_pretrained(model_id,
59
- subfolder='unet',
60
- torch_dtype=torch.float16)
 
 
61
  pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
62
  unet=unet,
63
- torch_dtype=torch.float16)
 
64
  pipe = pipe.to(self.device)
65
  self.pipe = pipe
66
  self.model_id = model_id # type: ignore
 
55
  if model_id == self.model_id:
56
  return
57
  base_model_id = self.get_base_model_info(model_id, self.hf_token)
58
+ unet = UNet3DConditionModel.from_pretrained(
59
+ model_id,
60
+ subfolder='unet',
61
+ torch_dtype=torch.float16,
62
+ use_auth_token=self.hf_token)
63
  pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
64
  unet=unet,
65
+ torch_dtype=torch.float16,
66
+ use_auth_token=self.hf_token)
67
  pipe = pipe.to(self.device)
68
  self.pipe = pipe
69
  self.model_id = model_id # type: ignore