# mmir_usersim/app.py
import gradio as gr
import torch

import captioning.captioner as captioner

# usersim_path_shoes = "http://www.dcs.gla.ac.uk/~craigm/fcrs/model_checkpoints/caption_model_shoes"
# usersim_path_dresses = "http://www.dcs.gla.ac.uk/~craigm/fcrs/captioners/dresses_cap_caption_models"
drive_path = 'mmir_usersim_resources/'
data_type = ["shoes", "dresses", "shirts", "tops&tees"]
usersim_path_shoes = drive_path + "checkpoints_usersim/shoes"
usersim_path_dresses = drive_path + "checkpoints_usersim/dresses"
usersim_path_shirts = drive_path + "checkpoints_usersim/shirts"
usersim_path_topstees = drive_path + "checkpoints_usersim/topstees"
usersim_path = [usersim_path_shoes, usersim_path_dresses, usersim_path_shirts, usersim_path_topstees]
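# Checkpoint layout: one user-simulator model directory per category under
# checkpoints_usersim/ (the commented URLs above appear to be the originally
# hosted shoes/dresses checkpoints).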
image_feat_params = {'model': 'resnet101', 'model_root': drive_path + 'imagenet_weights', 'att_size': 7}
# image_feat_params = {'model':'resnet101','model_root':'','att_size':7}
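# Feature-extractor settings for the ResNet-101 backbone, with ImageNet weights
# under imagenet_weights/; att_size=7 presumably selects a 7x7 grid of spatial
# attention features alongside the global fc feature.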
# One relative captioner per product category; load_resnet=True also loads the
# ResNet image encoder so features can be extracted from raw image files.
captioner_relative_shoes = captioner.Captioner(is_relative=True, model_path=usersim_path[0], image_feat_params=image_feat_params, data_type=data_type[0], load_resnet=True)
captioner_relative_dresses = captioner.Captioner(is_relative=True, model_path=usersim_path[1], image_feat_params=image_feat_params, data_type=data_type[1], load_resnet=True)
captioner_relative_shirts = captioner.Captioner(is_relative=True, model_path=usersim_path[2], image_feat_params=image_feat_params, data_type=data_type[2], load_resnet=True)
captioner_relative_topstees = captioner.Captioner(is_relative=True, model_path=usersim_path[3], image_feat_params=image_feat_params, data_type=data_type[3], load_resnet=True)
def generate_sentence(cap, image_path_1, image_path_2):
    """Generate a relative caption describing the target image (image_path_1)
    against the candidate/reference image (image_path_2)."""
    fc_feat, att_feat = cap.get_img_feat(image_path_1)
    fc_feat_ref, att_feat_ref = cap.get_img_feat(image_path_2)
    # Add a batch dimension before decoding.
    fc_feat = torch.unsqueeze(fc_feat, dim=0)
    att_feat = torch.unsqueeze(att_feat, dim=0)
    fc_feat_ref = torch.unsqueeze(fc_feat_ref, dim=0)
    att_feat_ref = torch.unsqueeze(att_feat_ref, dim=0)
    _, sents = cap.gen_caption_from_feat((fc_feat, att_feat), (fc_feat_ref, att_feat_ref))
    return sents[0]

# Per-category wrappers keep distinct function names for the Gradio handlers.
def generate_sentence_shoes(image_path_1, image_path_2):
    return generate_sentence(captioner_relative_shoes, image_path_1, image_path_2)

def generate_sentence_dresses(image_path_1, image_path_2):
    return generate_sentence(captioner_relative_dresses, image_path_1, image_path_2)

def generate_sentence_shirts(image_path_1, image_path_2):
    return generate_sentence(captioner_relative_shirts, image_path_1, image_path_2)

def generate_sentence_topstees(image_path_1, image_path_2):
    return generate_sentence(captioner_relative_topstees, image_path_1, image_path_2)
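# A minimal sketch of calling a captioner directly, using two of the shoes
# example images bundled with the demo (paths assumed relative to the app
# root); uncomment to smoke-test the checkpoints outside the UI.
# print(generate_sentence_shoes(
#     "images/shoes/img_womens_athletic_shoes_1223.jpg",
#     "images/shoes/img_womens_athletic_shoes_830.jpg",
# ))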
examples_shoes = [["images/shoes/img_womens_athletic_shoes_1223.jpg", "images/shoes/img_womens_athletic_shoes_830.jpg"],
["images/shoes/img_womens_athletic_shoes_830.jpg", "images/shoes/img_womens_athletic_shoes_1223.jpg"],
["images/shoes/img_womens_high_heels_559.jpg", "images/shoes/img_womens_high_heels_690.jpg"],
["images/shoes/img_womens_high_heels_690.jpg", "images/shoes/img_womens_high_heels_559.jpg"]]
examples_dresses = [["images/dresses/B007UZSPC8.jpg", "images/dresses/B006MPVW4U.jpg"],
["images/dresses/B005KMQQFQ.jpg", "images/dresses/B005QYY5W4.jpg"],
["images/dresses/B005OBAGD6.jpg", "images/dresses/B006U07GW4.jpg"],
["images/dresses/B0047Y0K0U.jpg", "images/dresses/B006TAM4CW.jpg"]]
examples_shirts = [["images/shirts/B00305G9I4.jpg", "images/shirts/B005BLUUJY.jpg"],
["images/shirts/B004WSVYX8.jpg", "images/shirts/B008TP27PY.jpg"],
["images/shirts/B003INE0Q6.jpg", "images/shirts/B0051D0X2Q.jpg"],
["images/shirts/B00EZUKCCM.jpg", "images/shirts/B00B88ZKXA.jpg"]]
examples_topstees = [["images/topstees/B0082993AO.jpg", "images/topstees/B008293HO2.jpg"],
["images/topstees/B006YN4J2C.jpg", "images/topstees/B0035EPUBW.jpg"],
["images/topstees/B00B5SKOMU.jpg", "images/topstees/B004H3XMYM.jpg"],
["images/topstees/B008DVXGO0.jpg", "images/topstees/B008JYNN30.jpg"]
]
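# Each example row above is a [target, candidate] pair; the shoes list includes
# both orderings of the same two images, since the generated caption describes
# the target relative to the candidate and so depends on the order.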
with gr.Blocks() as demo:
    gr.Markdown("Relative Captioning for Fashion: upload a target and a candidate image, and the model generates a sentence describing the target relative to the candidate.")
with gr.Tab("Shoes"):
with gr.Row():
target_shoes = gr.Image(source="upload", type="filepath", label="Target Image")
candidate_shoes = gr.Image(source="upload", type="filepath", label="Candidate Image")
output_text_shoes = gr.Textbox(label="Generated Sentence")
shoes_btn = gr.Button("Generate")
gr.Examples(examples_shoes, inputs=[target_shoes, candidate_shoes])
with gr.Tab("Dresses"):
with gr.Row():
target_dresses = gr.Image(source="upload", type="filepath", label="Target Image")
candidate_dresses = gr.Image(source="upload", type="filepath", label="Candidate Image")
output_text_dresses = gr.Textbox(label="Generated Sentence")
dresses_btn = gr.Button("Generate")
gr.Examples(examples_dresses, inputs=[target_dresses, candidate_dresses])
with gr.Tab("Shirts"):
with gr.Row():
target_shirts = gr.Image(source="upload", type="filepath", label="Target Image")
candidate_shirts = gr.Image(source="upload", type="filepath", label="Candidate Image")
output_text_shirts = gr.Textbox(label="Generated Sentence")
shirts_btn = gr.Button("Generate")
gr.Examples(examples_shirts, inputs=[target_shirts, candidate_shirts])
with gr.Tab("Tops&Tees"):
with gr.Row():
target_topstees = gr.Image(source="upload", type="filepath", label="Target Image")
candidate_topstees = gr.Image(source="upload", type="filepath", label="Candidate Image")
output_text_topstees = gr.Textbox(label="Generated Sentence")
topstees_btn = gr.Button("Generate")
gr.Examples(examples_topstees, inputs=[target_topstees, candidate_topstees])
shoes_btn.click(generate_sentence_shoes, inputs=[target_shoes, candidate_shoes], outputs=output_text_shoes)
dresses_btn.click(generate_sentence_dresses, inputs=[target_dresses, candidate_dresses], outputs=output_text_dresses)
shirts_btn.click(generate_sentence_shirts, inputs=[target_shirts, candidate_shirts], outputs=output_text_shirts)
topstees_btn.click(generate_sentence_topstees, inputs=[target_topstees, candidate_topstees], outputs=output_text_topstees)
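# Note: gr.Image(source=...) above and queue(concurrency_count=...) below
# follow the Gradio 3.x API; Gradio 4 renamed these to sources= and
# queue(default_concurrency_limit=...).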
demo.queue(concurrency_count=3)
demo.launch()