Spaces:
Runtime error
Runtime error
import torch.multiprocessing | |
import torchvision.transforms as T | |
from utils import transform_to_pil | |
import logging | |
# Per-image preprocessing applied by `inference`: convert the input to a PIL
# image, resize it to a fixed 320x320, turn it back into a float tensor and
# normalize each channel with the standard ImageNet mean/std values.
# (A 224 center crop used to sit after the resize; it is intentionally off.)
# NOTE(review): T.ToPILImage only accepts torch tensors or numpy arrays and
# raises TypeError on a PIL.Image — confirm what transform_ee_img returns.
preprocess = T.Compose([
    T.ToPILImage(),
    T.Resize((320, 320)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def inference(images, model):
    """Run the segmenter's linear probe on a batch of images.

    Each image is passed through the module-level ``preprocess`` pipeline,
    the results are stacked into one CPU batch, and the batch is fed to
    ``model.net`` and then ``model.linear_probe``. The predicted label map
    is the argmax over dimension 1 of the probe output.

    Args:
        images: Iterable of images accepted by ``preprocess`` (its first step
            is ``T.ToPILImage``, so torch tensors or numpy arrays).
        model: Object exposing ``net(x) -> (_, code)`` and
            ``linear_probe(x, code)``; presumably a ``LitUnsupervisedSegmenter``.

    Returns:
        A list with one dict per input image, in input order:
        ``"img"`` is the preprocessed (normalized) tensor and
        ``"linear_preds"`` is that image's argmax prediction. An empty input
        yields an empty list.
    """
    logging.info("Inference on Images")
    # Materialize first so any iterable (including generators) can be
    # checked for emptiness before stacking.
    batch = [preprocess(image) for image in images]
    if not batch:
        # torch.stack raises on an empty sequence; there is nothing to predict.
        return []
    x = torch.stack(batch).cpu()
    with torch.no_grad():
        _, code = model.net(x)
        linear_pred = model.linear_probe(x, code)
        linear_pred = linear_pred.argmax(1)
    return [
        {"img": img.detach().cpu(), "linear_preds": pred.detach().cpu()}
        for img, pred in zip(x, linear_pred)
    ]
if __name__ == "__main__":
    # Manual smoke test: fetch one Earth Engine image, load the trained
    # segmenter from disk and run `inference` on batches of one and two.
    import hydra
    from model import LitUnsupervisedSegmenter
    from utils_gee import extract_img, transform_ee_img

    # NOTE(review): 48.81 N / 2.98 E lies near Paris, so these two values look
    # swapped relative to their names. `location` goes straight to
    # extract_img, which presumably expects [lon, lat] (Earth Engine order)
    # — confirm against utils_gee before renaming anything.
    latitude = 2.98
    longitude = 48.81
    start_date = '2020-03-20'
    end_date = '2020-04-20'
    location = [float(latitude), float(longitude)]

    # Extract img numpy from earth engine and transform it to PIL img
    # NOTE(review): `preprocess` starts with T.ToPILImage, which raises
    # TypeError on a PIL.Image; if transform_ee_img really returns PIL,
    # inference() below will fail — confirm its return type.
    img = extract_img(location, start_date, end_date)
    image = transform_ee_img(
        img, max=0.3
    )  # max value is the value from numpy file that will be equal to 255
    print("image loaded")

    # Initialize hydra with configs
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")

    # Load the model weights onto the CPU
    model_path = "checkpoint/model/model.pt"
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    print("model initialized")
    model.load_state_dict(saved_state_dict)
    # nn.Module instances start in training mode; switch to eval mode so
    # layers such as dropout / batch-norm use their inference behavior.
    model.eval()
    print("model loaded")

    inference([image], model)
    inference([image, image], model)