DA-2 / da2 /utils /model.py
haodongli's picture
update
d82e7f9
raw
history blame contribute delete
546 Bytes
import torch
from ..model.spherevit import SphereViT
def load_model(config, accelerator):
    """Load the pretrained SphereViT and ready it for the current accelerator.

    The weights are fetched via ``SphereViT.from_pretrained('haodongli/DA-2')``,
    the model is moved to ``accelerator.device`` and passed through
    ``accelerator.prepare``. The dtype of the model's first parameter is
    written to ``config['spherevit']['dtype']`` and, when
    ``config['env']['verbose']`` is truthy, also reported through
    ``config['env']['logger']``.

    Args:
        config: Nested mapping. Read: ``config['env']['verbose']`` and
            ``config['env']['logger']``. Written: ``config['spherevit']['dtype']``.
            It is also forwarded unchanged to ``SphereViT.from_pretrained``.
        accelerator: Object exposing ``device``, ``prepare`` and
            ``num_processes`` (presumably an ``accelerate.Accelerator`` --
            TODO confirm against the caller).

    Returns:
        The prepared model; when more than one process is in use, the object
        found under the prepared model's ``.module`` attribute instead.
    """
    net = SphereViT.from_pretrained('haodongli/DA-2', config=config).to(accelerator.device)
    torch.cuda.empty_cache()
    net = accelerator.prepare(net)

    # NOTE(review): assumes ``prepare`` always returns a wrapper exposing
    # ``.module`` when several processes are running -- confirm for this setup.
    running_distributed = accelerator.num_processes > 1
    if running_distributed:
        net = net.module

    # Read the dtype once; it is both logged (optionally) and stored in config.
    param_dtype = next(net.parameters()).dtype
    env_cfg = config['env']
    if env_cfg['verbose']:
        env_cfg['logger'].info(f'Model\'s dtype: {param_dtype}.')
    config['spherevit']['dtype'] = param_dtype
    return net