import argparse
import os

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

from utils import *


def load_data(args):
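    """Load the item-to-metadata mapping (a JSON dict keyed by item id) from args.data_path."""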
    item2feature = load_json(args.data_path)
    return item2feature


def generate_feature(item2feature, features):
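    """Collect, per item, the cleaned text values and image URLs for the requested metadata keys.

    Returns a list of [item, text_list, image_list] entries: keys containing
    'image' contribute their first URL, all other keys contribute cleaned text.
    """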
    item_feature_list = []

    for item in item2feature:
        data = item2feature[item]
        text = []
        image = []
        for meta_key in features:
            if meta_key in data:
                if 'image' in meta_key:
                    image.append(data[meta_key][0])
                else:
                    meta_value = clean_text(data[meta_key])
                    text.append(meta_value.strip())

        item_feature_list.append([item, text, image])

    return item_feature_list


def preprocess_feature(args):
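    """Build the per-item feature list from the 'title', 'description', and 'imageH' fields."""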
    print('Process feature data ...')

    item2feature = load_data(args)
    item_feature_list = generate_feature(item2feature, ['title', 'description', 'imageH'])

    return item_feature_list


def generate_item_embedding(args, item_text_list, processor, model, word_drop_ratio=-1, save_path=''):
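    """Encode every item with the Mllama model and save three embedding matrices.

    For each item this computes a text-only, an image-only, and a fused
    multimodal embedding, then writes the three matrices to `save_path` as
    .npy files, with row i corresponding to the item mapped to index i by
    args.order_path. `word_drop_ratio` is accepted but not used here.
    """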
    print('Generate feature embedding ...')

    items, texts, images = zip(*item_text_list)
    order_texts, order_images = [[0]] * len(items), [[0]] * len(items)

    # Re-order items so that row i of every embedding matrix corresponds to
    # the item mapped to index i in args.order_path.
    item_order_mapping = load_json(args.order_path)
    for item, text, image in zip(items, texts, images):
        order_texts[int(item_order_mapping[item])] = text
        order_images[int(item_order_mapping[item])] = image
    for text in order_texts:
        assert text != [0]
    for image in order_images:
        assert image != [0]

    text_emb_result = []
    image_emb_result = []
    multi_modal_emb_result = []
    start, batch_size = 0, 1
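    # Items are encoded one at a time (batch_size is fixed to 1); each pass
    # produces a text-only, an image-only, and a fused multimodal vector.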
    with torch.no_grad():
        while start < len(order_texts):
            if (start + 1) % 100 == 0:
                print("==>", start + 1)

            item_text = ' '.join(order_texts[start])
            item_image = order_images[start][0]

            # Text-only embedding: mean-pool the language tower's last hidden
            # states over non-padding tokens.
            processed_text = processor(text=item_text, return_tensors='pt').to(args.device)
            text_output = model.language_model.model(**processed_text)
            text_masked_output = text_output.last_hidden_state * processed_text['attention_mask'].unsqueeze(-1)
            text_mean_output = text_masked_output.sum(dim=1) / processed_text['attention_mask'].sum(dim=-1, keepdim=True)
            text_emb_result.append(text_mean_output.detach().cpu())

            # Image-only embedding: mean-pool the vision tower output down to
            # a single vector per item.
            open_image = Image.open(requests.get(item_image, stream=True).raw)
            processed_image = processor(images=open_image, return_tensors='pt').to(args.device)
            image_output = model.vision_model(**processed_image)
            image_mean_output = image_output.last_hidden_state.squeeze().mean(dim=0)
            image_mean_output = image_mean_output.mean(dim=0, keepdim=True)
            image_emb_result.append(image_mean_output.detach().cpu())

            # Fused multimodal embedding: feed the image together with the text
            # prompt and mean-pool the last hidden layer over all tokens.
            prompt = '<|image|>' + item_text
            inputs = processor(text=prompt, images=open_image, return_tensors='pt').to(args.device)
            multi_modal_output = model(**inputs, output_hidden_states=True)
            multi_modal_mean_output = multi_modal_output.hidden_states[-1].mean(dim=1)
            multi_modal_emb_result.append(multi_modal_mean_output.detach().cpu())

            # Move on to the next item.
            start += batch_size
    text_embeddings = torch.cat(text_emb_result, dim=0).numpy()
    print('Text-Embeddings shape: ', text_embeddings.shape)
    image_embeddings = torch.cat(image_emb_result, dim=0).numpy()
    print('Image-Embeddings shape: ', image_embeddings.shape)
    multi_modal_embeddings = torch.cat(multi_modal_emb_result, dim=0).numpy()
    print('Multimodal-Embeddings shape: ', multi_modal_embeddings.shape)

    file = os.path.join(save_path, "Musical_Instruments.emb.text.npy")
    np.save(file, text_embeddings)
    file = os.path.join(save_path, "Musical_Instruments.emb.image.npy")
    np.save(file, image_embeddings)
    file = os.path.join(save_path, "Musical_Instruments.emb.multimodal.npy")
    np.save(file, multi_modal_embeddings)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_id', type=int, default=0, help='ID of running GPU')
    parser.add_argument('--plm_name', type=str, default='llama')
    parser.add_argument('--plm_checkpoint', type=str, default='')
    parser.add_argument('--max_sent_len', type=int, default=2048)
    parser.add_argument('--word_drop_ratio', type=float, default=-1, help='word drop ratio, do not drop by default')
    parser.add_argument('--data_path', type=str, default='')
    parser.add_argument('--order_path', type=str, default='')
    parser.add_argument('--save_path', type=str, default='')
    return parser.parse_args()

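# Example invocation (all paths and the checkpoint name are placeholders; point
# them at your own preprocessed item metadata / order files and a local or Hub
# copy of an Mllama checkpoint):
#   python this_script.py \
#       --plm_checkpoint meta-llama/Llama-3.2-11B-Vision-Instruct \
#       --data_path ./data/Musical_Instruments/item2feature.json \
#       --order_path ./data/Musical_Instruments/item2order.json \
#       --save_path ./data/Musical_Instruments/ \
#       --gpu_id 0
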
if __name__ == '__main__':
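    # End-to-end flow: read item metadata, load the Mllama checkpoint and its
    # processor, then dump text / image / multimodal item embeddings to .npy files.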
    args = parse_args()

    device = set_device(args.gpu_id)
    args.device = device

    item_feature_list = preprocess_feature(args)

    model = MllamaForConditionalGeneration.from_pretrained(args.plm_checkpoint, torch_dtype=torch.float16)
    processor = AutoProcessor.from_pretrained(args.plm_checkpoint)
    model = model.to(device)

    generate_item_embedding(
        args,
        item_feature_list,
        processor,
        model,
        word_drop_ratio=args.word_drop_ratio,
        save_path=args.save_path,
    )