Benchmark-Multimodal / data_process /multimodal_emb.py
Junyin's picture
Add files using upload-large-folder tool
8a506a6 verified
import argparse
import collections
import gzip
import html
import json
import os
import random
import re
import torch
from tqdm import tqdm
import numpy as np
from utils import *
from PIL import Image
import requests
from transformers import AutoProcessor, MllamaForConditionalGeneration, MllamaForCausalLM, MllamaTextModel, MllamaVisionModel
def load_data(args):
    """Read the item-to-feature mapping stored at ``args.data_path``.

    Args:
        args: Namespace-like object exposing a ``data_path`` attribute
            (e.g. the path to ``Musical_Instruments.features.json``).

    Returns:
        Whatever ``load_json`` yields for that path; callers in this file
        treat it as a mapping from item id to that item's metadata record.
    """
    return load_json(args.data_path)
def generate_feature(item2feature, features):
    """Collect the requested text and image fields of every item.

    Args:
        item2feature: Mapping from item id to that item's metadata record.
        features: Metadata keys to extract, in the order they should appear.
            A key containing the substring ``'image'`` is treated as an image
            field; every other key is treated as a text field.

    Returns:
        A list with one ``[item_id, texts, images]`` entry per item. ``texts``
        holds the cleaned, stripped value of each requested text field that is
        present in the record; ``images`` holds the first element of each
        requested image field that is present.
    """
    rows = []
    for item_id, record in item2feature.items():
        texts, images = [], []
        for key in features:
            # Fields missing from this record are skipped silently.
            if key not in record:
                continue
            value = record[key]
            if 'image' in key:
                # Only the first entry of an image field is kept.
                images.append(value[0])
                continue
            texts.append(clean_text(value).strip())
        rows.append([item_id, texts, images])
    return rows
def preprocess_feature(args):
    """Load the raw item metadata and reduce it to the fields used for embedding.

    Args:
        args: Namespace-like object forwarded to ``load_data``.

    Returns:
        The list produced by ``generate_feature`` for the ``title``,
        ``description`` and ``imageH`` fields, i.e. one
        ``[item_id, texts, images]`` entry per item.
    """
    print('Process feature data ...')
    feature_keys = ['title', 'description', 'imageH']
    return generate_feature(load_data(args), feature_keys)
def generate_item_embedding(args, item_text_list, tokenizer, model, word_drop_ratio=-1, save_path = ''):
    """Embed every item as text, image and text+image, then save three ``.npy`` files.

    Items are re-ordered with the index mapping loaded from ``args.order_path``
    so that row ``i`` of each saved matrix belongs to the item mapped to ``i``.

    Args:
        args: Namespace-like object providing ``order_path``, ``device`` and
            ``save_path``.
        item_text_list: Iterable of ``[item_id, text_fields, image_urls]``
            entries, as built by ``generate_feature``.
        tokenizer: The multimodal processor. It is called with ``text=`` and/or
            ``images=`` and must return an object supporting ``.to(device)``.
        model: Model exposing ``language_model.model`` and ``vision_model`` and
            accepting ``output_hidden_states=True`` in its forward call.
        word_drop_ratio: Unused; kept so existing callers keep working.
        save_path: Output directory. Falls back to ``args.save_path`` when empty.

    Raises:
        ValueError: If an order index has no item assigned, or an item has no
            image URL.
        requests.HTTPError: If an image download returns an error status.
    """
    print('Generate feature embedding ...')

    items, texts, images = zip(*item_text_list)
    item_order_mapping = load_json(args.order_path)

    # Place every item at the row given by the order mapping.
    order_texts = [None] * len(items)
    order_images = [None] * len(items)
    for item, text, image in zip(items, texts, images):
        position = int(item_order_mapping[item])
        order_texts[position] = text
        order_images[position] = image

    # Validate up front (with real exceptions, not asserts that vanish under
    # ``python -O``) so a bad mapping fails before any model work is done.
    for position, (text, image) in enumerate(zip(order_texts, order_images)):
        if text is None or image is None:
            raise ValueError(f'No item is mapped to order index {position}')
        if not image:
            raise ValueError(f'Item at order index {position} has no image URL')

    request_timeout = 30  # seconds; keeps one stalled download from hanging the run

    text_emb_result = []
    image_emb_result = []
    multi_modal_emb_result = []
    with torch.no_grad():
        # A for-loop replaces the original ``while`` whose counter was never
        # advanced and therefore never terminated.
        for count, (text_fields, image_urls) in enumerate(zip(order_texts, order_images), start=1):
            if count % 100 == 0:
                print("==>", count)
            item_text = ' '.join(text_fields)
            item_image = image_urls[0]

            # Text-only embedding: attention-mask-weighted mean over tokens.
            processed_text = tokenizer(text=item_text, return_tensors='pt').to(args.device)
            text_output = model.language_model.model(**processed_text)
            attention_mask = processed_text['attention_mask']
            text_masked_output = text_output.last_hidden_state * attention_mask.unsqueeze(-1)
            text_mean_output = text_masked_output.sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
            # Keep tensors (not nested lists) so torch.cat below can stack them.
            text_emb_result.append(text_mean_output.detach().cpu())

            # Image-only embedding.
            response = requests.get(item_image, stream=True, timeout=request_timeout)
            response.raise_for_status()
            open_image = Image.open(response.raw)
            processed_image = tokenizer(images=open_image, return_tensors="pt").to(args.device)
            image_output = model.vision_model(**processed_image)
            # NOTE(review): the two means assume squeeze() leaves a 3-D tensor
            # whose last axis is the hidden size -- confirm for single-tile images.
            image_mean_output = image_output.last_hidden_state.squeeze().mean(dim=0)
            image_mean_output = image_mean_output.mean(dim=0, keepdim=True)
            image_emb_result.append(image_mean_output.detach().cpu())

            # Joint embedding: mean of the last hidden layer over the sequence.
            prompt = '<|image|>' + item_text
            inputs = tokenizer(text=prompt, images=open_image, return_tensors="pt").to(args.device)
            multi_modal_output = model(**inputs, output_hidden_states=True)
            multi_modal_mean_output = multi_modal_output.hidden_states[-1].mean(dim=1)
            multi_modal_emb_result.append(multi_modal_mean_output.detach().cpu())

    text_embeddings = torch.cat(text_emb_result, dim=0).numpy()
    print('Text-Embeddings shape: ', text_embeddings.shape)
    image_embeddings = torch.cat(image_emb_result, dim=0).numpy()
    print('Image-Embeddings shape: ', image_embeddings.shape)
    multi_modal_embeddings = torch.cat(multi_modal_emb_result, dim=0).numpy()
    print('Multimodal-Embeddings shape: ', multi_modal_embeddings.shape)

    # Join directory and file name properly; plain string concatenation mangled
    # the path whenever the directory lacked a trailing separator.
    output_dir = save_path or args.save_path
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, "Musical_Instruments.emb.text.npy"), text_embeddings)
    # NOTE(review): 'imgae' looks like a typo for 'image'; the name is kept
    # unchanged because other code may load this exact file.
    np.save(os.path.join(output_dir, "Musical_Instruments.emb.imgae.npy"), image_embeddings)
    np.save(os.path.join(output_dir, "Musical_Instruments.emb.multimodal.npy"), multi_modal_embeddings)
def parse_args():
    """Parse the command-line options of the embedding script.

    Returns:
        argparse.Namespace with ``gpu_id``, ``plm_name``, ``plm_checkpoint``,
        ``max_sent_len``, ``word_drop_ratio``, ``data_path``, ``order_path``
        and ``save_path``.
    """
    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = (
        ('--gpu_id', dict(type=int, default=0, help='ID of running GPU')),
        ('--plm_name', dict(type=str, default='llama')),
        ('--plm_checkpoint', dict(type=str, default='')),
        ('--max_sent_len', dict(type=int, default=2048)),
        ('--word_drop_ratio', dict(type=float, default=-1,
                                   help='word drop ratio, do not drop by default')),
        ('--data_path', dict(type=str, default='')),
        ('--order_path', dict(type=str, default='')),
        ('--save_path', dict(type=str, default='')),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
if __name__ == '__main__':
    # Resolve the CLI options and attach the compute device to them, since the
    # embedding routine reads the device from ``args``.
    args = parse_args()
    device = set_device(args.gpu_id)
    args.device = device

    # Load and clean the item metadata before loading the checkpoint.
    item_feature_list = preprocess_feature(args)

    # Half-precision weights; ``processor`` is kept as a module-level name
    # because other code in this file may look it up as a global.
    model = MllamaForConditionalGeneration.from_pretrained(
        args.plm_checkpoint,
        torch_dtype=torch.float16,
    )
    processor = AutoProcessor.from_pretrained(args.plm_checkpoint)
    model = model.to(device)

    generate_item_embedding(
        args,
        item_feature_list,
        tokenizer=processor,
        model=model,
        word_drop_ratio=args.word_drop_ratio,
        save_path=args.save_path,
    )