Spaces:
Sleeping
Sleeping
File size: 1,597 Bytes
dfd33e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import argparse
import glob
import os
import cv2
import numpy
import torch
from PIL import Image
from Model import TRCaptionNet, clip_transform
def demo(opt):
    """Caption every ``*.jpg`` image in a directory and display each one.

    Builds a ``TRCaptionNet`` model, loads the weights stored under the
    ``"model"`` key of the checkpoint at ``opt.model_ckpt`` onto
    ``opt.device``, and then walks the JPEG files of ``opt.input_dir`` in
    sorted path order. For each image it prints ``<file name> : <caption>``
    and shows the picture, resized to a height of 800 pixels with the aspect
    ratio preserved, in an OpenCV window that blocks until a key is pressed.

    Args:
        opt: Parsed options exposing the attributes ``model_ckpt`` (checkpoint
            path), ``input_dir`` (directory searched for ``*.jpg`` files) and
            ``device`` (a string accepted by ``torch.device``).

    Returns:
        None.
    """
    preprocess = clip_transform(224)
    model = TRCaptionNet({
        "max_length": 35,
        "clip": "ViT-L/14",
        "bert": "dbmdz/bert-base-turkish-cased",
        "proj": True,
        "proj_num_head": 16
    })
    device = torch.device(opt.device)
    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted source should be passed here.
    checkpoint = torch.load(opt.model_ckpt, map_location=device)
    model.load_state_dict(checkpoint["model"], strict=True)
    model = model.to(device)
    model.eval()

    display_height = 800
    image_paths = sorted(glob.glob(os.path.join(opt.input_dir, '*.jpg')))
    for image_path in image_paths:
        # basename handles the platform's path separator, unlike split('/').
        img_name = os.path.basename(image_path)

        # Normalise to 3-channel RGB so grayscale/CMYK JPEGs do not break the
        # channel handling below; the context manager closes the file handle.
        with Image.open(image_path) as opened:
            img0 = opened.convert("RGB")

        batch = preprocess(img0).unsqueeze(0).to(device)
        # Pure inference: no autograd bookkeeping is needed.
        with torch.no_grad():
            caption = model.generate(batch, min_length=11, repetition_penalty=1.6)[0]
        print(f"{img_name} :", caption)

        # OpenCV displays images in BGR channel order.
        orj_img = cv2.cvtColor(numpy.array(img0), cv2.COLOR_RGB2BGR)
        h, w, _ = orj_img.shape
        new_h = display_height
        # Keep the aspect ratio; clamp so extremely tall images never get a
        # zero width, which cv2.resize rejects.
        new_w = max(1, int(new_h * (w / h)))
        orj_img = cv2.resize(orj_img, (new_w, new_h))

        cv2.imshow("image", orj_img)
        cv2.waitKey(0)

    if image_paths:
        # Release the preview window once the last image has been dismissed.
        cv2.destroyAllWindows()
    return
if __name__ == '__main__':
    # Script entry point: read the checkpoint path, image directory and
    # device from the command line, then run the captioning demo.
    cli = argparse.ArgumentParser(description='Turkish-Image-Captioning!')

    # (flag, default) pairs; every option is a plain string.
    string_options = (
        ('--model-ckpt', './checkpoints/TRCaptionNet_L14_berturk.pth'),
        ('--input-dir', './images/'),
        ('--device', 'cuda:0'),
    )
    for flag, fallback in string_options:
        cli.add_argument(flag, type=str, default=fallback)

    demo(cli.parse_args())
|