Capx
/

WhereAmAt / load.py
Alyosha11's picture
Upload 8 files
5e83696 verified
raw
history blame contribute delete
584 Bytes
from main import *
def get_satclip(ckpt_path, device, return_all=False):
    """Load a checkpoint and return its inner model in evaluation mode.

    The checkpoint is read with ``torch.load``, a ``SatCLIPLightningModule``
    is rebuilt from the stored ``hyper_parameters``, the stored
    ``state_dict`` is loaded into it, and the module is switched to
    ``eval()`` before its ``.model`` attribute is extracted.

    Args:
        ckpt_path: Path (or file-like object) handed to ``torch.load``.
        device: Target device; used both as ``map_location`` when loading
            and as the argument to ``.to(...)`` on the rebuilt module.
        return_all: When true, return the whole ``lightning_model.model``;
            otherwise return only its ``.location`` attribute.

    Returns:
        ``lightning_model.model`` if ``return_all`` is true, else
        ``lightning_model.model.location``.

    Raises:
        KeyError: If the checkpoint has no ``'hyper_parameters'`` or
            ``'state_dict'`` entry.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only call this on checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device)

    # These entries are removed before the hyper-parameters are forwarded to
    # the constructor -- presumably because SatCLIPLightningModule does not
    # accept them (TODO confirm against its signature). A default is passed
    # to pop() so checkpoints saved without one of these keys still load
    # instead of failing with KeyError.
    hparams = ckpt['hyper_parameters']
    for key in ('eval_downstream', 'air_temp_data_path', 'election_data_path'):
        hparams.pop(key, None)

    lightning_model = SatCLIPLightningModule(**hparams).to(device)
    lightning_model.load_state_dict(ckpt['state_dict'])
    lightning_model.eval()

    geo_model = lightning_model.model
    return geo_model if return_all else geo_model.location