serdaryildiz commited on
Commit
863c93d
·
verified ·
1 Parent(s): 08e7633

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -11
app.py CHANGED
@@ -6,43 +6,74 @@ import torch
6
 
7
  from Model import TRCaptionNet, clip_transform
8
 
9
- model_ckpt = "./checkpoints/TRCaptionNet_L14_berturk_tasviret.pth"
10
 
11
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
- device = "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  preprocess = clip_transform(224)
15
  model = TRCaptionNet({
16
  "max_length": 35,
17
  "clip": "ViT-L/14",
18
- "bert": "bert.json",
19
  "proj": True,
20
  "proj_num_head": 16
21
  })
 
22
  model.load_state_dict(torch.load(model_ckpt, map_location=device)["model"], strict=True)
23
  model = model.to(device)
24
  model.eval()
25
 
26
 
 
27
  def inference(raw_image, min_length, repetition_penalty):
 
 
 
28
  batch = preprocess(raw_image).unsqueeze(0).to(device)
29
  caption = model.generate(batch, min_length=min_length, repetition_penalty=repetition_penalty)[0]
30
- return caption
 
31
 
32
 
33
  inputs = [gr.Image(type='pil', interactive=True,),
34
- gr.Slider(minimum=6, maximum=22, value=11, label="MINIMUM CAPTION LENGTH", step=1),
35
  gr.Slider(minimum=1, maximum=2, value=1.6, label="REPETITION PENALTY")]
36
- outputs = gr.components.Textbox(label="Caption")
37
- title = "TRCaptionNet"
 
38
  paper_link = ""
39
  github_link = "https://github.com/serdaryildiz/TRCaptionNet"
40
- description = f"<p style='text-align: center'><a href='{github_link}' target='_blank'>TRCaptionNet</a> : A novel and accurate deep Turkish image captioning model with vision transformer based image encoders and deep linguistic text decoders"
 
 
 
 
41
  examples = [
42
  ["images/test1.jpg"],
43
  ["images/test2.jpg"],
44
  ["images/test3.jpg"],
45
- ["images/test4.jpg"]
 
 
 
 
 
 
 
46
  ]
47
  article = f"<p style='text-align: center'><a href='{paper_link}' target='_blank'>Paper</a> | <a href='{github_link}' target='_blank'>Github Repo</a></p>"
48
  css = ".output-image, .input-image, .image-preview {height: 600px !important}"
@@ -56,4 +87,3 @@ iface = gr.Interface(fn=inference,
56
  article=article,
57
  css=css)
58
  iface.launch()
59
-
 
6
 
7
  from Model import TRCaptionNet, clip_transform
8
 
 
9
 
10
+
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ # device = "cpu"
13
+
14
+ preprocess_tasviret = clip_transform(336)
15
+ model_tasviret = TRCaptionNet({
16
+ "max_length": 35,
17
+ "clip": "ViT-L/14@336px",
18
+ "bert": "dbmdz/bert-base-turkish-cased",
19
+ "proj": True,
20
+ "proj_num_head": 16
21
+ })
22
+ model_ckpt = "./checkpoints/TRCaptionNet-TasvirEt_L14_334_berturk.pth"
23
+ model_tasviret.load_state_dict(torch.load(model_ckpt, map_location=device)["model"], strict=True)
24
+ model_tasviret = model_tasviret.to(device)
25
+ model_tasviret.eval()
26
 
27
  preprocess = clip_transform(224)
28
  model = TRCaptionNet({
29
  "max_length": 35,
30
  "clip": "ViT-L/14",
31
+ "bert": "dbmdz/bert-base-turkish-cased",
32
  "proj": True,
33
  "proj_num_head": 16
34
  })
35
+ model_ckpt = "./checkpoints/TRCaptionNet_L14_berturk.pth"
36
  model.load_state_dict(torch.load(model_ckpt, map_location=device)["model"], strict=True)
37
  model = model.to(device)
38
  model.eval()
39
 
40
 
41
+
42
  def inference(raw_image, min_length, repetition_penalty):
43
+ batch = preprocess_tasviret(raw_image).unsqueeze(0).to(device)
44
+ caption_tasviret = model_tasviret.generate(batch, min_length=min_length, repetition_penalty=repetition_penalty)[0]
45
+
46
  batch = preprocess(raw_image).unsqueeze(0).to(device)
47
  caption = model.generate(batch, min_length=min_length, repetition_penalty=repetition_penalty)[0]
48
+
49
+ return [caption, caption_tasviret]
50
 
51
 
52
  inputs = [gr.Image(type='pil', interactive=True,),
53
+ gr.Slider(minimum=4, maximum=22, value=8, label="MINIMUM CAPTION LENGTH", step=1),
54
  gr.Slider(minimum=1, maximum=2, value=1.6, label="REPETITION PENALTY")]
55
+
56
+ outputs = [gr.components.Textbox(label="Caption"), gr.components.Textbox(label="Caption-TasvirEt")]
57
+ title = "TRCaptionNet-TasvirEt"
58
  paper_link = ""
59
  github_link = "https://github.com/serdaryildiz/TRCaptionNet"
60
+ IEEE_link = "https://github.com/serdaryildiz/TRCaptionNet"
61
+
62
+ description = f"<p style='text-align: center'><a href='{IEEE_link}' target='_blank'> SIU2024: Turkish Image Captioning with Vision Transformer Based Encoders and Text Decoders</a> "
63
+ description += f"<p style='text-align: center'><a href='{github_link}' target='_blank'>TRCaptionNet</a> : A novel and accurate deep Turkish image captioning model with vision transformer based image encoders and deep linguistic text decoders"
64
+
65
  examples = [
66
  ["images/test1.jpg"],
67
  ["images/test2.jpg"],
68
  ["images/test3.jpg"],
69
+ ["images/test4.jpg"],
70
+ ["images/test5.jpg"],
71
+ ["images/test6.jpg"],
72
+ ["images/test7.jpg"],
73
+ ["images/test8.jpg"],
74
+ ["images/test9.jpg"],
75
+ ["images/test10.jpg"],
76
+ ["images/test11.jpg"],
77
  ]
78
  article = f"<p style='text-align: center'><a href='{paper_link}' target='_blank'>Paper</a> | <a href='{github_link}' target='_blank'>Github Repo</a></p>"
79
  css = ".output-image, .input-image, .image-preview {height: 600px !important}"
 
87
  article=article,
88
  css=css)
89
  iface.launch()