serdaryildiz committed on
Commit 3bc9d68 · verified · 1 Parent(s): eb1fbaf

Update app.py

Files changed (1):
  app.py (+69, -29)
app.py CHANGED
@@ -11,53 +11,93 @@ from Model import TRCaptionNetpp
 model_ckpt = "./checkpoints/TRCaptionNetpp_Large.pth"
 
 os.makedirs("./checkpoints/", exist_ok=True)
-url = 'https://drive.google.com/uc?id=1tOiRtIpe99gQWnpGfy_W5xgtsHFhvU3F'
+url = "https://drive.google.com/uc?id=1tOiRtIpe99gQWnpGfy_W5xgtsHFhvU3F"
 gdown.download(url, model_ckpt, quiet=False)
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-preprocess = transforms.Compose([transforms.Resize((224, 224)),
-                                 transforms.ToTensor(),
-                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                                      std=[0.229, 0.224, 0.225])])
+preprocess = transforms.Compose(
+    [
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ]
+)
 
-model = TRCaptionNetpp({
+model = TRCaptionNetpp(
+    {
         "max_length": 35,
         "dino2": "dinov2_vitl14",
         "bert": "dbmdz/electra-base-turkish-mc4-cased-discriminator",
         "proj": True,
-        "proj_num_head": 16
-})
-model.load_state_dict(torch.load(model_ckpt, map_location=device)["model"], strict=True)
+        "proj_num_head": 16,
+    }
+)
+ckpt = torch.load(model_ckpt, map_location=device)
+model.load_state_dict(ckpt["model"], strict=True)
 model = model.to(device)
 model.eval()
 
 
 def inference(raw_image, min_length, repetition_penalty):
     batch = preprocess(raw_image).unsqueeze(0).to(device)
-    caption = model.generate(batch, min_length=min_length, repetition_penalty=repetition_penalty)[0]
+    caption = model.generate(
+        batch,
+        min_length=int(min_length),
+        repetition_penalty=float(repetition_penalty),
+    )[0]
     return caption
 
 
-inputs = [gr.Image(type='pil', interactive=True,),
-          gr.Slider(minimum=6, maximum=22, value=11, label="MINIMUM CAPTION LENGTH", step=1),
-          gr.Slider(minimum=1, maximum=2, value=2.5, label="REPETITION PENALTY")]
-outputs = gr.components.Textbox(label="Caption")
+# ----- UI -----
+img_input = gr.Image(type="pil", interactive=True, label="Input Image")
+minlen_slider = gr.Slider(
+    minimum=6, maximum=22, value=11, step=1, label="MINIMUM CAPTION LENGTH"
+)
+rep_slider = gr.Slider(
+    minimum=1.0, maximum=3.0, value=2.5, step=0.1, label="REPETITION PENALTY"
+)
+
+outputs = gr.Textbox(label="Caption")
+
 title = "TRCaptionNet"
-paper_link = ""
+paper_link = ""  # add if available
 github_link = "https://github.com/serdaryildiz/TRCaptionNetpp"
-description = f"<p style='text-align: center'><a href='{github_link}' target='_blank'>TRCaptionNet++: A high-performance encoder-decoder based deep Turkish image captioning model fine-tuned with a large-scale set of pretrain data"
-examples = [[p] for p in glob.glob("images/*")]
-
-article = f"<p style='text-align: center'><a href='{paper_link}' target='_blank'>Paper</a> | <a href='{github_link}' target='_blank'>Github Repo</a></p>"
+description = (
+    f"<p style='text-align: center'>"
+    f"<a href='{github_link}' target='_blank'>TRCaptionNet++</a>: "
+    f"A high-performance encoder–decoder based Turkish image captioning model "
+    f"fine-tuned with a large-scale pretrain dataset.</p>"
+)
+article = (
+    f"<p style='text-align: center'>"
+    f"<a href='{paper_link}' target='_blank'>Paper</a> | "
+    f"<a href='{github_link}' target='_blank'>Github Repo</a></p>"
+)
 css = ".output-image, .input-image, .image-preview {height: 600px !important}"
 
-iface = gr.Interface(fn=inference,
-                     inputs=inputs,
-                     outputs=outputs,
-                     title=title,
-                     description=description,
-                     examples=examples,
-                     article=article,
-                     css=css)
-iface.launch()
+# Build examples with full rows (image, min_length, repetition_penalty)
+imgs = glob.glob("images/*")
+if imgs:
+    examples = [[p, 11, 2.0] for p in imgs]
+    cache_examples = True
+else:
+    examples = None
+    cache_examples = False  # avoid startup caching when there are no examples
+
+iface = gr.Interface(
+    fn=inference,
+    inputs=[img_input, minlen_slider, rep_slider],
+    outputs=outputs,
+    title=title,
+    description=description,
+    examples=examples,
+    cache_examples=cache_examples,
+    article=article,
+    css=css,
+)
+
+if __name__ == "__main__":
+    # If you still hit caching issues, you can also set: ssr_mode=False
+    iface.launch(server_name="0.0.0.0", server_port=7860)
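A minimal sketch of a local smoke test for the updated inference path, assuming the diff above has been applied and app.py is importable from the working directory; the sample image path and the specific slider values are hypothetical, and importing app still runs the checkpoint download and model setup at import time (only the launch is guarded by __main__):

# smoke_test.py -- sketch, not part of the commit
import torch
from PIL import Image

from app import model, preprocess, device  # assumes app.py is on the path

# Hypothetical sample image; any RGB image under images/ would do.
img = Image.open("images/example.jpg").convert("RGB")

with torch.no_grad():
    batch = preprocess(img).unsqueeze(0).to(device)
    caption = model.generate(batch, min_length=11, repetition_penalty=2.0)[0]

print(caption)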