"""Segment an image, post-process the mask, and run text recognition on it.

Importing this module loads both models from ``weights/`` onto the CPU, so
the import itself performs file I/O.
"""

import albumentations as albu
import cv2
import segmentation_models_pytorch as smp
import torch
from PIL import Image
from strhub.data.module import SceneTextDataModule
from strhub.models.utils import load_from_checkpoint
from torchvision import transforms

from post import filter_mask

_DEVICE = "cpu"

# Recognition model and the input transform matching its configured image size.
model_recog = load_from_checkpoint("weights/parseq/last.ckpt").eval().to(_DEVICE)
img_transform = SceneTextDataModule.get_transform(model_recog.hparams.img_size)

# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# checkpoints from a trusted source.
# map_location lets a checkpoint that was saved from a GPU load on a CPU-only
# host instead of failing while deserialising CUDA tensors.
model = torch.load("weights/best_model.pth", map_location=_DEVICE).to(_DEVICE)
model.eval()
model.float()

# Input size fed to the segmentation model. SHAPE_X is passed as the resize
# height and SHAPE_Y as the width (see _seg_transform below).
SHAPE_X = 384
SHAPE_Y = 384

# Built once at import time: none of these depend on the image being processed.
_seg_transform = albu.Compose([
    albu.Lambda(image=smp.encoders.get_preprocessing_fn('resnet50')),
    albu.Resize(height=SHAPE_X, width=SHAPE_Y),
])
_to_tensor = transforms.ToTensor()


def prediction(image_path: str):
    """Run segmentation and text recognition on the image at *image_path*.

    The image is read with OpenCV, converted to RGB, preprocessed and resized
    for the segmentation model, and the model output is handed to
    ``filter_mask`` together with the RGB image. The first value returned by
    ``filter_mask`` is then fed to the recognition model and decoded with its
    tokenizer. The result line ``"<image_path>: <text>"`` is printed.

    Args:
        image_path: Path of the image file to process.

    Returns:
        A ``(img_vis, text)`` tuple: ``img_vis`` is the second value returned
        by ``filter_mask`` (presumably a visualisation image -- confirm
        against ``post.filter_mask``) and ``text`` is the first decoded
        prediction.

    Raises:
        cv2.error: If the image cannot be read or decoded.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread signals failure by returning None; without this check the
        # error only surfaces as an opaque assertion inside cvtColor. The same
        # exception type is raised so existing handlers keep working.
        raise cv2.error(f"could not read image: {image_path!r}")
    image_original = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Encoder preprocessing + resize, then HWC ndarray -> 1xCxHxW float32.
    seg_input = _seg_transform(image=image_original)["image"]
    batch = _to_tensor(seg_input).unsqueeze(0).float()
    output = model.predict(batch)

    result, img_vis = filter_mask(output, image_original)

    # NOTE(review): image_original is already RGB, so this swap may actually
    # yield BGR data -- kept as in the original because filter_mask's output
    # channel order is not visible here; confirm before changing.
    recog_rgb = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
    recog_input = img_transform(Image.fromarray(recog_rgb)).unsqueeze(0).to(_DEVICE)

    # Inference only: skip building an autograd graph for the recognizer.
    with torch.no_grad():
        probs = model_recog(recog_input).softmax(-1)
    pred, _ = model_recog.tokenizer.decode(probs)

    print(f'{image_path}: {pred[0]}')
    return img_vis, pred[0]