xinyu1205 committed
Commit 4a16ef8
1 parent: 8e6dc9f

Update app.py

Files changed (1): app.py (+57, −34)

app.py CHANGED
@@ -5,7 +5,7 @@ import torch
 import torchvision.transforms as transforms
 
 from PIL import Image
-from models.tag2text import tag2text_caption
+from models.tag2text import tag2text_caption, ram
 
 import gradio as gr
 
 
@@ -17,56 +17,79 @@ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
 transform = transforms.Compose([transforms.Resize((image_size, image_size)), transforms.ToTensor(), normalize])
 
+#######Tag2Text Model
+pretrained = '/home/notebook/data/group/huangxinyu/pretrain_model/tag2text/tag2text_swin_14m.pth'
+
+model_tag2text = tag2text_caption(pretrained=pretrained, image_size=image_size, vit='swin_b')
+
+model_tag2text.eval()
+model_tag2text = model_tag2text.to(device)
+
 
 #######Swin Version
-pretrained = 'tag2text_swin_14m.pth'
+pretrained = '/home/notebook/code/personal/S9049611/tag2text-v2/output/pretrain_tag2text_large_v2_14m_large_v14/new_coco_ori_finetune_384_v5_epoch03/checkpoint_01.pth'
 
-model = tag2text_caption(pretrained=pretrained, image_size=image_size, vit='swin_b')
+model_ram = ram(pretrained=pretrained, image_size=image_size, vit='swin_l')
 
-model.eval()
-model = model.to(device)
+model_ram.eval()
+model_ram = model_ram.to(device)
 
 
-def inference(raw_image, input_tag):
+def inference(raw_image, model_n, input_tag):
     raw_image = raw_image.resize((image_size, image_size))
 
-    image = transform(raw_image).unsqueeze(0).to(device)
-    model.threshold = 0.68
-    if input_tag == '' or input_tag == 'none' or input_tag == 'None':
-        input_tag_list = None
+    image = transform(raw_image).unsqueeze(0).to(device)
+    if model_n == 'Recognize Anything Model':
+        model = model_ram
+        tags, tags_chinese = model.generate_tag(image)
+        return tags[0], tags_chinese[0], 'none'
     else:
-        input_tag_list = []
-        input_tag_list.append(input_tag.replace(',', ' | '))
-    with torch.no_grad():
+        model = model_tag2text
+        model.threshold = 0.68
+        if input_tag == '' or input_tag == 'none' or input_tag == 'None':
+            input_tag_list = None
+        else:
+            input_tag_list = []
+            input_tag_list.append(input_tag.replace(',', ' | '))
+        with torch.no_grad():
 
 
-        caption, tag_predict = model.generate(image, tag_input=input_tag_list, max_length=50, return_tag_predict=True)
-        if input_tag_list is None:
-            tag_1 = tag_predict
-            tag_2 = ['none']
-        else:
-            _, tag_1 = model.generate(image, tag_input=None, max_length=50, return_tag_predict=True)
-            tag_2 = tag_predict
-
-    return tag_1[0], tag_2[0], caption[0]
+            caption, tag_predict = model.generate(image, tag_input=input_tag_list, max_length=50, return_tag_predict=True)
+            if input_tag_list is None:
+                tag_1 = tag_predict
+                tag_2 = ['none']
+            else:
+                _, tag_1 = model.generate(image, tag_input=None, max_length=50, return_tag_predict=True)
+                tag_2 = tag_predict
+
+        return tag_1[0], 'none', caption[0]
 
 
-inputs = [gr.inputs.Image(type='pil'), gr.inputs.Textbox(lines=2, label="User Specified Tags (Optional, Enter with commas)")]
+inputs = [
+    gr.inputs.Image(type='pil'),
+    gr.inputs.Radio(choices=['Recognize Anything Model', "Tag2Text Model"],
+                    type="value",
+                    default="Recognize Anything Model",
+                    label="Model"),
+    gr.inputs.Textbox(lines=2, label="User Specified Tags (Optional, Enter with commas, Currently only Tag2Text is supported)")
+]
 
-outputs = [gr.outputs.Textbox(label="Model Identified Tags"), gr.outputs.Textbox(label="User Specified Tags"), gr.outputs.Textbox(label="Image Caption")]
+outputs = [gr.outputs.Textbox(label="Tags"), gr.outputs.Textbox(label="标签"), gr.outputs.Textbox(label="Caption (currently only Tag2Text is supported)")]
 
-title = "Tag2Text"
-description = "Welcome to Tag2Text demo! (Supported by Fudan University, OPPO Research Institute, International Digital Economy Academy) <br/> Upload your image to get the <b>tags</b> and <b>caption</b> of the image. Optional: You can also input specified tags to get the corresponding caption."
+# title = "Recognize Anything Model"
+title = "<font size='10'> Recognize Anything Model</font>"
 
-article = "<p style='text-align: center'>Tag2Text is trained on open-source datasets, and we are persistently refining and iterating on it.<br/><a href='https://arxiv.org/abs/2303.05657' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='https://github.com/xinyu1205/Tag2Text' target='_blank'>Github Repo</a></p>"
+description = "Welcome to the Recognize Anything Model (RAM) and Tag2Text Model demo! <li><b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese outputs of the image tags</b>!</li><li><b>Tag2Text Model:</b> Upload your image to get the <b>tags</b> and <b>caption</b> of the image. Optional: You can also input specified tags to get the corresponding caption.</li>"
+
+article = "<p style='text-align: center'>RAM and Tag2Text are trained on open-source datasets, and we are persistently refining and iterating on them.<br/><a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a> | <a href='https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='https://github.com/xinyu1205/Tag2Text' target='_blank'>Github Repo</a></p>"
 
-demo = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['images/COCO_val2014_000000483108.jpg', "none"],
-                    ['images/COCO_val2014_000000483108.jpg', "power line"],
-                    ['images/COCO_val2014_000000483108.jpg', "track, train"],
-                    ['images/bdf391a6f4b1840a.jpg', "none"],
-                    ['images/64891_194270823.jpg', "none"],
-                    ['images/2800737_834897251.jpg', "none"],
-                    ['images/1641173_2291260800.jpg', "none"],
+demo = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[
+                    ['images/demo1.jpg', "Recognize Anything Model", "none"],
+                    ['images/demo2.jpg', "Recognize Anything Model", "none"],
+                    ['images/demo4.jpg', "Recognize Anything Model", "none"],
+                    ['images/demo4.jpg', "Tag2Text Model", "power line"],
+                    ['images/demo4.jpg', "Tag2Text Model", "track, train"],
                     ])
 
-demo.launch(enable_queue=True)
+demo.launch(enable_queue=True)
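The updated inference() can be smoke-tested outside the Gradio UI. A minimal sketch, assuming the two checkpoints above actually load and using one of the demo images the commit references; the return order is (tags, Chinese tags, caption):

from PIL import Image

img = Image.open('images/demo4.jpg').convert('RGB')

# RAM branch: English tags, Chinese tags, and 'none' in the caption slot.
tags, tags_chinese, _ = inference(img, 'Recognize Anything Model', 'none')

# Tag2Text branch: user tags are comma-separated; inference() rewrites them
# to the model's ' | ' separator ('track, train' -> 'track | train').
model_tags, _, caption = inference(img, 'Tag2Text Model', 'track, train')

print(tags)
print(model_tags)
print(caption)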
 
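Note that both pretrained paths now point at machine-specific locations under /home/notebook/, which a deployed Space cannot resolve. One alternative is pulling the checkpoint from the Hugging Face Hub at startup; a sketch, where the repo id and filename are placeholders rather than anything taken from this commit:

from huggingface_hub import hf_hub_download

# Hypothetical repo id and filename; substitute wherever the checkpoint is hosted.
pretrained = hf_hub_download(repo_id='xinyu1205/tag2text',
                             filename='tag2text_swin_14m.pth')
model_tag2text = tag2text_caption(pretrained=pretrained, image_size=image_size, vit='swin_b')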
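The gr.inputs.*/gr.outputs.* namespaces, default=, and launch(enable_queue=True) are Gradio 2.x/3.x-era APIs that newer releases deprecate or remove. If the Space is ever upgraded, the same interface could look roughly like this sketch with top-level components (not part of the commit):

import gradio as gr

inputs = [
    gr.Image(type='pil'),
    gr.Radio(choices=['Recognize Anything Model', 'Tag2Text Model'],
             value='Recognize Anything Model',  # 'default=' became 'value='
             label='Model'),
    gr.Textbox(lines=2, label='User Specified Tags (Optional, Enter with commas, Currently only Tag2Text is supported)'),
]
outputs = [gr.Textbox(label='Tags'),
           gr.Textbox(label='标签'),
           gr.Textbox(label='Caption (currently only Tag2Text is supported)')]

demo = gr.Interface(inference, inputs, outputs, title=title,
                    description=description, article=article)
demo.queue().launch()  # queueing moved from launch(enable_queue=True) to .queue()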