yashonwu committed
Commit
e53c320
1 Parent(s): 7774775

modify app.py

Files changed (2)
  1. .gitignore +129 -0
  2. app.py +137 -4
.gitignore ADDED
@@ -0,0 +1,129 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
app.py CHANGED
@@ -1,7 +1,140 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import torch
+
+ # usersim_path_shoes = "http://www.dcs.gla.ac.uk/~craigm/fcrs/model_checkpoints/caption_model_shoes"
+ # usersim_path_dresses = "http://www.dcs.gla.ac.uk/~craigm/fcrs/captioners/dresses_cap_caption_models"
+
+ drive_path = '/content/drive/MyDrive/Datasets/mmir_usersim_resources/'
+
+ data_type= ["shoes", "dresses", "shirts", "tops&tees"]
+
+ usersim_path_shoes = drive_path + "checkpoints_usersim/shoes"
+ usersim_path_dresses = drive_path + "checkpoints_usersim/dresses"
+ usersim_path_shirts = drive_path + "checkpoints_usersim/shirts"
+ usersim_path_topstees = drive_path + "checkpoints_usersim/topstees"
+ usersim_path = [usersim_path_shoes, usersim_path_dresses, usersim_path_shirts, usersim_path_topstees]
+
+ import captioning.captioner as captioner
+ image_feat_params = {'model':'resnet101','model_root':drive_path + 'imagenet_weights','att_size':7}
+ # image_feat_params = {'model':'resnet101','model_root':'','att_size':7}
+
+ captioner_relative_shoes = captioner.Captioner(is_relative= True, model_path= usersim_path[0], image_feat_params=image_feat_params, data_type=data_type[0], load_resnet=True)
+ captioner_relative_dresses = captioner.Captioner(is_relative= True, model_path= usersim_path[1], image_feat_params=image_feat_params, data_type=data_type[1], load_resnet=True)
+ captioner_relative_shirts = captioner.Captioner(is_relative= True, model_path= usersim_path[2], image_feat_params=image_feat_params, data_type=data_type[2], load_resnet=True)
+ captioner_relative_topstees = captioner.Captioner(is_relative= True, model_path= usersim_path[3], image_feat_params=image_feat_params, data_type=data_type[3], load_resnet=True)
+
+ def generate_sentence_shoes(image_path_1, image_path_2):
+     fc_feat, att_feat = captioner_relative_shoes.get_img_feat(image_path_1)
+     fc_feat_ref, att_feat_ref = captioner_relative_shoes.get_img_feat(image_path_2)
+
+     fc_feat = torch.unsqueeze(fc_feat, dim=0)
+     att_feat = torch.unsqueeze(att_feat, dim=0)
+     fc_feat_ref = torch.unsqueeze(fc_feat_ref, dim=0)
+     att_feat_ref = torch.unsqueeze(att_feat_ref, dim=0)
+
+     seq, sents = captioner_relative_shoes.gen_caption_from_feat((fc_feat,att_feat), (fc_feat_ref,att_feat_ref))
+
+     sentence = sents[0]
+     return sentence
+
+ def generate_sentence_dresses(image_path_1, image_path_2):
+     fc_feat, att_feat = captioner_relative_dresses.get_img_feat(image_path_1)
+     fc_feat_ref, att_feat_ref = captioner_relative_dresses.get_img_feat(image_path_2)
+
+     fc_feat = torch.unsqueeze(fc_feat, dim=0)
+     att_feat = torch.unsqueeze(att_feat, dim=0)
+     fc_feat_ref = torch.unsqueeze(fc_feat_ref, dim=0)
+     att_feat_ref = torch.unsqueeze(att_feat_ref, dim=0)
+
+     seq, sents = captioner_relative_dresses.gen_caption_from_feat((fc_feat,att_feat), (fc_feat_ref,att_feat_ref))
+
+     sentence = sents[0]
+     return sentence
+
+ def generate_sentence_shirts(image_path_1, image_path_2):
+     fc_feat, att_feat = captioner_relative_shirts.get_img_feat(image_path_1)
+     fc_feat_ref, att_feat_ref = captioner_relative_shirts.get_img_feat(image_path_2)
+
+     fc_feat = torch.unsqueeze(fc_feat, dim=0)
+     att_feat = torch.unsqueeze(att_feat, dim=0)
+     fc_feat_ref = torch.unsqueeze(fc_feat_ref, dim=0)
+     att_feat_ref = torch.unsqueeze(att_feat_ref, dim=0)
+
+     seq, sents = captioner_relative_shirts.gen_caption_from_feat((fc_feat,att_feat), (fc_feat_ref,att_feat_ref))
+
+     sentence = sents[0]
+     return sentence
+
+ def generate_sentence_topstees(image_path_1, image_path_2):
+     fc_feat, att_feat = captioner_relative_topstees.get_img_feat(image_path_1)
+     fc_feat_ref, att_feat_ref = captioner_relative_topstees.get_img_feat(image_path_2)
+
+     fc_feat = torch.unsqueeze(fc_feat, dim=0)
+     att_feat = torch.unsqueeze(att_feat, dim=0)
+     fc_feat_ref = torch.unsqueeze(fc_feat_ref, dim=0)
+     att_feat_ref = torch.unsqueeze(att_feat_ref, dim=0)
+
+     seq, sents = captioner_relative_topstees.gen_caption_from_feat((fc_feat,att_feat), (fc_feat_ref,att_feat_ref))
+
+     sentence = sents[0]
+     return sentence
+
+ import numpy as np
  import gradio as gr

+ examples_shoes = [["images/shoes/img_womens_athletic_shoes_1223.jpg", "images/shoes/img_womens_athletic_shoes_830.jpg"],
+                   ["images/shoes/img_womens_athletic_shoes_830.jpg", "images/shoes/img_womens_athletic_shoes_1223.jpg"],
+                   ["images/shoes/img_womens_high_heels_559.jpg", "images/shoes/img_womens_high_heels_690.jpg"],
+                   ["images/shoes/img_womens_high_heels_690.jpg", "images/shoes/img_womens_high_heels_559.jpg"]]
+
+ examples_dresses = [["images/dresses/B007UZSPC8.jpg", "images/dresses/B006MPVW4U.jpg"],
+                     ["images/dresses/B005KMQQFQ.jpg", "images/dresses/B005QYY5W4.jpg"],
+                     ["images/dresses/B005OBAGD6.jpg", "images/dresses/B006U07GW4.jpg"],
+                     ["images/dresses/B0047Y0K0U.jpg", "images/dresses/B006TAM4CW.jpg"]]
+ examples_shirts = [["images/shirts/B00305G9I4.jpg", "images/shirts/B005BLUUJY.jpg"],
+                    ["images/shirts/B004WSVYX8.jpg", "images/shirts/B008TP27PY.jpg"],
+                    ["images/shirts/B003INE0Q6.jpg", "images/shirts/B0051D0X2Q.jpg"],
+                    ["images/shirts/B00EZUKCCM.jpg", "images/shirts/B00B88ZKXA.jpg"]]
+ examples_topstees = [["images/topstees/B0082993AO.jpg", "images/topstees/B008293HO2.jpg"],
+                      ["images/topstees/B006YN4J2C.jpg", "images/topstees/B0035EPUBW.jpg"],
+                      ["images/topstees/B00B5SKOMU.jpg", "images/topstees/B004H3XMYM.jpg"],
+                      ["images/topstees/B008DVXGO0.jpg", "images/topstees/B008JYNN30.jpg"]
+                      ]
+
+ with gr.Blocks() as demo:
+     gr.Markdown("Relative Captioning for Fashion.")
+     with gr.Tab("Shoes"):
+         with gr.Row():
+             target_shoes = gr.Image(source="upload", type="filepath", label="Target Image")
+             candidate_shoes = gr.Image(source="upload", type="filepath", label="Candidate Image")
+         output_text_shoes = gr.Textbox(label="Generated Sentence")
+         shoes_btn = gr.Button("Generate")
+         gr.Examples(examples_shoes, inputs=[target_shoes, candidate_shoes])
+     with gr.Tab("Dresses"):
+         with gr.Row():
+             target_dresses = gr.Image(source="upload", type="filepath", label="Target Image")
+             candidate_dresses = gr.Image(source="upload", type="filepath", label="Candidate Image")
+         output_text_dresses = gr.Textbox(label="Generated Sentence")
+         dresses_btn = gr.Button("Generate")
+         gr.Examples(examples_dresses, inputs=[target_dresses, candidate_dresses])
+     with gr.Tab("Shirts"):
+         with gr.Row():
+             target_shirts = gr.Image(source="upload", type="filepath", label="Target Image")
+             candidate_shirts = gr.Image(source="upload", type="filepath", label="Candidate Image")
+         output_text_shirts = gr.Textbox(label="Generated Sentence")
+         shirts_btn = gr.Button("Generate")
+         gr.Examples(examples_shirts, inputs=[target_shirts, candidate_shirts])
+     with gr.Tab("Tops&Tees"):
+         with gr.Row():
+             target_topstees = gr.Image(source="upload", type="filepath", label="Target Image")
+             candidate_topstees = gr.Image(source="upload", type="filepath", label="Candidate Image")
+         output_text_topstees = gr.Textbox(label="Generated Sentence")
+         topstees_btn = gr.Button("Generate")
+         gr.Examples(examples_topstees, inputs=[target_topstees, candidate_topstees])
+
+     shoes_btn.click(generate_sentence_shoes, inputs=[target_shoes, candidate_shoes], outputs=output_text_shoes)
+     dresses_btn.click(generate_sentence_dresses, inputs=[target_dresses, candidate_dresses], outputs=output_text_dresses)
+     shirts_btn.click(generate_sentence_shirts, inputs=[target_shirts, candidate_shirts], outputs=output_text_shirts)
+     topstees_btn.click(generate_sentence_topstees, inputs=[target_topstees, candidate_topstees], outputs=output_text_topstees)

+ demo.queue(concurrency_count=3)
+ demo.launch()
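
Note: the four generate_sentence_* functions added above differ only in which Captioner instance they call. A possible follow-up cleanup (not part of this commit; a sketch that reuses the names and the get_img_feat/gen_caption_from_feat calls already defined in app.py, with generate_sentence and captioners being hypothetical names) would route all tabs through one shared helper:

import torch

# Sketch only: map each category to the Captioner instance created earlier in app.py.
captioners = {
    "shoes": captioner_relative_shoes,
    "dresses": captioner_relative_dresses,
    "shirts": captioner_relative_shirts,
    "tops&tees": captioner_relative_topstees,
}

def generate_sentence(category, image_path_1, image_path_2):
    cap = captioners[category]
    # Extract features for the target and the candidate (reference) image.
    fc_feat, att_feat = cap.get_img_feat(image_path_1)
    fc_feat_ref, att_feat_ref = cap.get_img_feat(image_path_2)
    # Add a batch dimension before decoding.
    fc_feat = torch.unsqueeze(fc_feat, dim=0)
    att_feat = torch.unsqueeze(att_feat, dim=0)
    fc_feat_ref = torch.unsqueeze(fc_feat_ref, dim=0)
    att_feat_ref = torch.unsqueeze(att_feat_ref, dim=0)
    # Generate the relative caption and return the first sentence.
    seq, sents = cap.gen_caption_from_feat((fc_feat, att_feat), (fc_feat_ref, att_feat_ref))
    return sents[0]

# Each button could then bind a category-specific callable, e.g.
# shoes_btn.click(functools.partial(generate_sentence, "shoes"),
#                 inputs=[target_shoes, candidate_shoes], outputs=output_text_shoes)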