"""Gradio demo app for the TRCaptionNet++ Turkish image captioning model.

At import time this script fetches the model checkpoint if it is not already
on disk, builds the model, loads the weights and defines the Gradio
interface ``iface``. Running it as a script launches the web UI.
"""
import glob
import os

import gdown
import gradio as gr
import torch
from torchvision import transforms

from Model import TRCaptionNetpp

model_ckpt = "./checkpoints/TRCaptionNetpp_Large.pth"
os.makedirs("./checkpoints/", exist_ok=True)
url = "https://drive.google.com/uc?id=1tOiRtIpe99gQWnpGfy_W5xgtsHFhvU3F"

# Only download when the checkpoint is missing, so a restart does not fetch
# the whole file again.
if not os.path.isfile(model_ckpt):
    gdown.download(url, model_ckpt, quiet=False)
if not os.path.isfile(model_ckpt):
    # Fail with a clear message instead of an opaque error from torch.load.
    raise FileNotFoundError(
        f"Checkpoint not found at {model_ckpt!r} and download from {url!r} failed."
    )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Normalize is configured with three channel statistics, so every input must
# be converted to a 3-channel image before it reaches this pipeline.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

model = TRCaptionNetpp(
    {
        "max_length": 35,
        "dino2": "dinov2_vitl14",
        "bert": "dbmdz/electra-base-turkish-mc4-cased-discriminator",
        "proj": True,
        "proj_num_head": 16,
    }
)

# NOTE(security): torch.load unpickles the file. The checkpoint comes from an
# external URL, so only point `url` at a source you trust.
ckpt = torch.load(model_ckpt, map_location=device)
model.load_state_dict(ckpt["model"], strict=True)
model = model.to(device)
model.eval()


def inference(raw_image, min_length, repetition_penalty):
    """Generate a caption for a single image.

    Args:
        raw_image: PIL image delivered by the ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits without selecting an image.
        min_length: Minimum caption length; cast to ``int`` and forwarded to
            ``model.generate``.
        repetition_penalty: Repetition penalty; cast to ``float`` and
            forwarded to ``model.generate``.

    Returns:
        The first element of ``model.generate``'s result for the one-image
        batch.

    Raises:
        gr.Error: If no image was provided.
    """
    if raw_image is None:
        raise gr.Error("Please provide an input image.")

    # RGBA, grayscale or palette images would not match the 3-channel
    # normalization above, so force RGB first.
    rgb_image = raw_image.convert("RGB")
    batch = preprocess(rgb_image).unsqueeze(0).to(device)

    # Inference only: gradients are never needed here.
    with torch.no_grad():
        caption = model.generate(
            batch,
            min_length=int(min_length),
            repetition_penalty=float(repetition_penalty),
        )[0]
    return caption


# ----- UI -----
img_input = gr.Image(type="pil", interactive=True, label="Input Image")
minlen_slider = gr.Slider(
    minimum=6, maximum=22, value=11, step=1, label="MINIMUM CAPTION LENGTH"
)
rep_slider = gr.Slider(
    minimum=1.0, maximum=3.0, value=2.5, step=0.1, label="REPETITION PENALTY"
)
outputs = gr.Textbox(label="Caption")

title = "TRCaptionNet"
paper_link = ""  # add if available
github_link = "https://github.com/serdaryildiz/TRCaptionNetpp"

# The line breaks are written as explicit "\n" escapes; a raw line break
# inside a single-quoted string literal is a syntax error.
description = (
    "\n"
    "TRCaptionNet++: "
    "A high-performance encoder–decoder based Turkish image captioning model "
    "fine-tuned with a large-scale pretrain dataset.\n"
)

# Turn the labels into real links. "Paper" stays plain text until
# `paper_link` is filled in, so the visible text is the same either way.
paper_html = (
    f"<a href='{paper_link}' target='_blank'>Paper</a>" if paper_link else "Paper"
)
github_html = f"<a href='{github_link}' target='_blank'>Github Repo</a>"
article = f"{paper_html} | {github_html}\n"

css = ".output-image, .input-image, .image-preview {height: 600px !important}"

# Build examples with full rows (image, min_length, repetition_penalty).
# Sorted for a stable order across runs; directories are skipped because they
# cannot be opened as images when examples are cached.
imgs = sorted(p for p in glob.glob("images/*") if os.path.isfile(p))
if imgs:
    examples = [[p, 11, 2.0] for p in imgs]
    cache_examples = True
else:
    examples = None
    cache_examples = False  # avoid startup caching when there are no examples

iface = gr.Interface(
    fn=inference,
    inputs=[img_input, minlen_slider, rep_slider],
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
    cache_examples=cache_examples,
    article=article,
    css=css,
)

if __name__ == "__main__":
    # If you still hit caching issues, you can also set: ssr_mode=False
    iface.launch(server_name="0.0.0.0", server_port=7860)