YiftachEde commited on
Commit
2fc2bf3
·
1 Parent(s): 49f568d
Files changed (1) hide show
  1. app.py +62 -14
app.py CHANGED
@@ -16,6 +16,8 @@ from shap_e.util.notebooks import create_pan_cameras, decode_latent_images
16
  import spaces
17
  from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
18
  import math
 
 
19
 
20
  from src.utils.train_util import instantiate_from_config
21
  from src.utils.camera_util import (
@@ -83,13 +85,29 @@ def load_models():
83
 
84
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
85
 
86
- # Load diffusion pipeline
87
  print('Loading diffusion pipeline...')
88
- pipeline = DiffusionPipeline.from_pretrained(
89
- "sudo-ai/zero123plus-v1.2",
90
- custom_pipeline="zero123plus",
91
- torch_dtype=torch.float16
92
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
94
  pipeline.scheduler.config, timestep_spacing='trailing'
95
  )
@@ -107,19 +125,49 @@ def load_models():
107
  new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
108
  pipeline.unet.conv_in = new_conv_in
109
 
110
- # Load custom UNet
111
  print('Loading custom UNet...')
112
- pipeline.unet = pipeline.unet.from_pretrained("YiftachEde/Sharp-It").to(torch.float16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  pipeline = pipeline.to(device).to(torch_dtype=torch.float16)
114
 
115
- # Load reconstruction model
116
  print('Loading reconstruction model...')
117
  model = instantiate_from_config(model_config)
118
- model_path = hf_hub_download(
119
- repo_id="TencentARC/InstantMesh",
120
- filename="instant_nerf_large.ckpt",
121
- repo_type="model"
122
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  state_dict = torch.load(model_path, map_location='cpu')['state_dict']
124
  state_dict = {k[14:]: v for k, v in state_dict.items()
125
  if k.startswith('lrm_generator.') and 'source_camera' not in k}
 
16
  import spaces
17
  from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
18
  import math
19
+ import time
20
+ from requests.exceptions import ReadTimeout, ConnectionError
21
 
22
  from src.utils.train_util import instantiate_from_config
23
  from src.utils.camera_util import (
 
85
 
86
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
87
 
88
+ # Load diffusion pipeline with retry logic
89
  print('Loading diffusion pipeline...')
90
+ max_retries = 3
91
+ retry_delay = 5
92
+
93
+ for attempt in range(max_retries):
94
+ try:
95
+ pipeline = DiffusionPipeline.from_pretrained(
96
+ "sudo-ai/zero123plus-v1.2",
97
+ custom_pipeline="zero123plus",
98
+ torch_dtype=torch.float16,
99
+ local_files_only=False,
100
+ resume_download=True,
101
+ token=True # Use token-based auth
102
+ )
103
+ break
104
+ except (ReadTimeout, ConnectionError) as e:
105
+ if attempt == max_retries - 1:
106
+ raise Exception(f"Failed to download pipeline after {max_retries} attempts: {str(e)}")
107
+ print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
108
+ time.sleep(retry_delay)
109
+ retry_delay *= 2 # Exponential backoff
110
+
111
  pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
112
  pipeline.scheduler.config, timestep_spacing='trailing'
113
  )
 
125
  new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
126
  pipeline.unet.conv_in = new_conv_in
127
 
128
+ # Load custom UNet with retry logic
129
  print('Loading custom UNet...')
130
+ for attempt in range(max_retries):
131
+ try:
132
+ pipeline.unet = pipeline.unet.from_pretrained(
133
+ "YiftachEde/Sharp-It",
134
+ local_files_only=False,
135
+ resume_download=True,
136
+ token=True # Use token-based auth
137
+ ).to(torch.float16)
138
+ break
139
+ except (ReadTimeout, ConnectionError) as e:
140
+ if attempt == max_retries - 1:
141
+ raise Exception(f"Failed to download UNet after {max_retries} attempts: {str(e)}")
142
+ print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
143
+ time.sleep(retry_delay)
144
+ retry_delay *= 2
145
+
146
  pipeline = pipeline.to(device).to(torch_dtype=torch.float16)
147
 
148
+ # Load reconstruction model with retry logic
149
  print('Loading reconstruction model...')
150
  model = instantiate_from_config(model_config)
151
+
152
+ for attempt in range(max_retries):
153
+ try:
154
+ model_path = hf_hub_download(
155
+ repo_id="TencentARC/InstantMesh",
156
+ filename="instant_nerf_large.ckpt",
157
+ repo_type="model",
158
+ local_files_only=False,
159
+ resume_download=True,
160
+ token=True, # Use token-based auth
161
+ cache_dir="model_cache" # Use a specific cache directory
162
+ )
163
+ break
164
+ except (ReadTimeout, ConnectionError) as e:
165
+ if attempt == max_retries - 1:
166
+ raise Exception(f"Failed to download model after {max_retries} attempts: {str(e)}")
167
+ print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
168
+ time.sleep(retry_delay)
169
+ retry_delay *= 2
170
+
171
  state_dict = torch.load(model_path, map_location='cpu')['state_dict']
172
  state_dict = {k[14:]: v for k, v in state_dict.items()
173
  if k.startswith('lrm_generator.') and 'source_camera' not in k}