# Biomap / biomap/inference.py
# Last commit: 9fcd62f ("streamlit") by jeremyLE-Ekimetrics
# File size: 2.09 kB
import torch.multiprocessing
import torchvision.transforms as T
from utils import transform_to_pil
import logging
# Preprocessing applied to every input image before it reaches the model:
# convert to a PIL image, resize to a fixed 320x320, convert back to a float
# tensor, then normalize each channel with the standard ImageNet mean/std
# values used by torchvision's pretrained models.
# NOTE(review): T.ToPILImage accepts a CHW tensor or an HWC ndarray, not an
# already-PIL image - confirm what callers of `inference` pass in.
preprocess = T.Compose([
    T.ToPILImage(),
    T.Resize(size=(320, 320)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def inference(images, model):
    """Run the segmentation model's linear probe on a batch of images.

    Each image is passed through the module-level ``preprocess`` transform
    and the results are stacked into a single CPU batch. ``model.net`` is
    called on that batch, and the second element of its result (the
    "code") is fed to ``model.linear_probe`` together with the batch. The
    predicted class is the argmax over dimension 1 of the probe output.

    Args:
        images: Iterable of images accepted by ``preprocess`` (whose first
            step is ``T.ToPILImage``).
        model: Object exposing ``net`` (returning a 2-tuple) and
            ``linear_probe``.

    Returns:
        A list with one dict per input image, in input order, with keys:
        ``"img"``: the *preprocessed* (resized, normalized) image tensor,
        not the original image; ``"linear_preds"``: the argmax predictions
        for that image. Both tensors are on the CPU. An empty input yields
        an empty list.
    """
    logging.info("Inference on Images")
    batch = [preprocess(image) for image in images]
    if not batch:
        # torch.stack refuses an empty sequence; there is nothing to predict.
        return []
    x = torch.stack(batch).cpu()
    # NOTE(review): predictions depend on whether `model` is in train or eval
    # mode (e.g. dropout); callers presumably call model.eval() first - confirm.
    with torch.no_grad():
        _, code = model.net(x)
        linear_pred = model.linear_probe(x, code).argmax(1)
    return [
        {
            "img": img.detach().cpu(),
            "linear_preds": pred.detach().cpu(),
        }
        for img, pred in zip(x, linear_pred)
    ]
if __name__ == "__main__":
    # Smoke test: fetch one Earth Engine image, load the trained checkpoint
    # and run inference on a batch of one and a batch of two images.
    import hydra
    from model import LitUnsupervisedSegmenter
    from utils_gee import extract_img, transform_ee_img

    # NOTE(review): (48.81, 2.98) is the Paris area, so these names look
    # swapped (48.81 would be the latitude). The list below passes
    # [2.98, 48.81] to extract_img; confirm which order extract_img expects
    # before renaming anything.
    latitude = 2.98
    longitude = 48.81
    start_date = '2020-03-20'
    end_date = '2020-04-20'
    location = [float(latitude), float(longitude)]

    # Extract the image from Earth Engine and rescale it.
    # NOTE(review): the original comment said this yields a PIL image, but
    # `preprocess` starts with T.ToPILImage, which only accepts tensors or
    # ndarrays - presumably transform_ee_img returns an ndarray; confirm.
    img = extract_img(location, start_date, end_date)
    # max is the value from the numpy array that will be mapped to 255.
    image = transform_ee_img(img, max=0.3)
    print("image loaded")

    # Initialize hydra with configs
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")

    # Load the model
    model_path = "checkpoint/model/model.pt"
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    print("model initialized")
    model.load_state_dict(saved_state_dict)
    # A freshly constructed torch module starts in training mode. Switch
    # dropout / batch-norm layers to inference behaviour before predicting.
    # (This assumes LitUnsupervisedSegmenter is an nn.Module; it exposes
    # load_state_dict.)
    model.eval()
    print("model loaded")

    inference([image], model)
    inference([image, image], model)