lopho commited on
Commit
810b825
1 Parent(s): 3b76971

debug code

Browse files
Files changed (2) hide show
  1. app.py +3 -0
  2. makeavid_sd/inference.py +26 -14
app.py CHANGED
@@ -23,6 +23,9 @@ _model = InferenceUNetPseudo3D(
23
  hf_auth_token = os.environ.get('HUGGING_FACE_HUB_TOKEN', None)
24
  )
25
 
 
 
 
26
  # gradio is illiterate. type hints make it go poopoo in pantsu.
27
  def generate(
28
  prompt = 'An elderly man having a great time in the park.',
 
23
  hf_auth_token = os.environ.get('HUGGING_FACE_HUB_TOKEN', None)
24
  )
25
 
26
+ if _model.failed == True:
27
+ exit()
28
+
29
  # gradio is illiterate. type hints make it go poopoo in pantsu.
30
  def generate(
31
  prompt = 'An elderly man having a great time in the park.',
makeavid_sd/inference.py CHANGED
@@ -62,20 +62,32 @@ class InferenceUNetPseudo3D:
62
  self.hf_auth_token = hf_auth_token
63
 
64
  self.params: Dict[str, FrozenDict[str, Any]] = {}
65
- unet, unet_params = UNetPseudo3DConditionModel.from_pretrained(
66
- self.model_path,
67
- subfolder = 'unet',
68
- from_pt = False,
69
- sample_size = (64, 64),
70
- dtype = self.dtype,
71
- param_dtype = dtypestr(self.dtype),
72
- use_memory_efficient_attention = True,
73
- use_auth_token = self.hf_auth_token
74
- )
75
- self.unet: UNetPseudo3DConditionModel = unet
76
- unet_params = castto(self.dtype, self.unet, unet_params)
77
- self.params['unet'] = FrozenDict(unet_params)
78
- del unet_params
 
 
 
 
 
 
 
 
 
 
 
 
79
  vae, vae_params = FlaxAutoencoderKL.from_pretrained(
80
  self.model_path,
81
  subfolder = 'vae',
 
62
  self.hf_auth_token = hf_auth_token
63
 
64
  self.params: Dict[str, FrozenDict[str, Any]] = {}
65
+ try:
66
+ import traceback
67
+ print('initializing unet')
68
+ unet, unet_params = UNetPseudo3DConditionModel.from_pretrained(
69
+ self.model_path,
70
+ subfolder = 'unet',
71
+ from_pt = False,
72
+ sample_size = (64, 64),
73
+ dtype = self.dtype,
74
+ param_dtype = dtypestr(self.dtype),
75
+ use_memory_efficient_attention = True,
76
+ use_auth_token = self.hf_auth_token
77
+ )
78
+ self.unet: UNetPseudo3DConditionModel = unet
79
+ print('casting unet params')
80
+ unet_params = castto(self.dtype, self.unet, unet_params)
81
+ print('storing unet params')
82
+ self.params['unet'] = FrozenDict(unet_params)
83
+ print('deleting unet params')
84
+ del unet_params
85
+ except Exception as e:
86
+ print(e)
87
+ traceback.print_exc()
88
+ self.failed = True
89
+ return
90
+ self.failed = False
91
  vae, vae_params = FlaxAutoencoderKL.from_pretrained(
92
  self.model_path,
93
  subfolder = 'vae',