import argparse
import collections
import gzip
import html
import json
import os
import random
import re

import numpy as np
import requests
import torch
from PIL import Image
from tqdm import tqdm
from transformers import (
    AutoProcessor,
    MllamaForConditionalGeneration,
    MllamaForCausalLM,
    MllamaTextModel,
    MllamaVisionModel,
)

from utils import *


def load_data(args):
    # Musical_Instruments.features.json
    # item2feature_path = args.data_path
    item2feature = load_json(args.data_path)
    # item2order_path = args.order_path
    # item2order = load_json(args.order_path)
    return item2feature


def generate_feature(item2feature, features):
    """Collect the requested metadata fields for every item.

    Returns a list of [item_id, list_of_cleaned_text_fields, list_of_image_urls].
    """
    item_feature_list = []
    for item in item2feature:
        data = item2feature[item]
        text = []
        image = []
        for meta_key in features:
            if meta_key in data:
                if 'image' in meta_key:
                    # keep only the first image URL for this field
                    image.append(data[meta_key][0])
                else:
                    meta_value = clean_text(data[meta_key])
                    text.append(meta_value.strip())
        item_feature_list.append([item, text, image])
    return item_feature_list


def preprocess_feature(args):
    print('Process feature data ...')
    # print('Dataset: ', args.dataset)
    item2feature = load_data(args)
    # load item text/images and clean the text fields
    item_feature_list = generate_feature(item2feature, ['title', 'description', 'imageH'])
    # item_text_list = generate_text(item2feature, ['title'])
    # return: list of (item_ID, cleaned_item_text, item_image_URLs)
    return item_feature_list


def generate_item_embedding(args, item_text_list, processor, model, word_drop_ratio=-1, save_path=''):
    print('Generate feature embedding ...')
    # print(' Dataset: ', args.dataset)
    items, texts, images = zip(*item_text_list)

    # Re-order items by the item-to-index mapping so that row i of each embedding
    # matrix corresponds to the item whose order index is i.
    order_texts, order_images = [[0]] * len(items), [[0]] * len(items)
    # item_order_mapping = {}
    # item_order = 0
    # for item in items:
    #     item_order_mapping[item] = item_order
    #     item_order += 1
    item_order_mapping = load_json(args.order_path)
    for item, text, image in zip(items, texts, images):
        order_texts[int(item_order_mapping[item])] = text
        order_images[int(item_order_mapping[item])] = image
    for text in order_texts:
        assert text != [0]
    for image in order_images:
        assert image != [0]

    text_emb_result = []
    image_emb_result = []
    multi_modal_emb_result = []
    start, batch_size = 0, 1  # items are processed one at a time
    with torch.no_grad():
        while start < len(order_texts):
            if (start + 1) % 100 == 0:
                print("==>", start + 1)
            item_text = ' '.join(order_texts[start])
            item_image = order_images[start][0]

            # Text embedding: mean-pool the language model's last hidden states
            # over non-padding tokens.
            processed_text = processor(text=item_text, return_tensors='pt').to(args.device)
            text_output = model.language_model.model(**processed_text)
            text_masked_output = text_output.last_hidden_state * processed_text['attention_mask'].unsqueeze(-1)
            text_mean_output = text_masked_output.sum(dim=1) / processed_text['attention_mask'].sum(dim=-1, keepdim=True)
            text_emb_result.append(text_mean_output.detach().cpu())

            # Image embedding: mean-pool the vision encoder's tile/patch features.
            # Floating-point inputs are cast to float16 to match the model weights.
            open_image = Image.open(requests.get(item_image, stream=True).raw)
            processed_image = processor(images=open_image, return_tensors="pt").to(args.device, torch.float16)
            image_output = model.vision_model(**processed_image)
            image_mean_output = image_output.last_hidden_state.squeeze().mean(dim=0)
            image_mean_output = image_mean_output.mean(dim=0, keepdim=True)
            image_emb_result.append(image_mean_output.detach().cpu())

            # Multimodal embedding: run image + text through the full model and
            # mean-pool the last hidden layer.
            prompt = '<|image|>' + item_text
            inputs = processor(text=prompt, images=open_image, return_tensors="pt").to(args.device, torch.float16)
            multi_modal_output = model(**inputs, output_hidden_states=True)
            multi_modal_mean_output = multi_modal_output.hidden_states[-1].mean(dim=1)
            multi_modal_emb_result.append(multi_modal_mean_output.detach().cpu())

            # advance to the next item
            start += batch_size

    text_embeddings = torch.cat(text_emb_result, dim=0).numpy()
    print('Text-Embeddings shape: ', text_embeddings.shape)
    image_embeddings = torch.cat(image_emb_result, dim=0).numpy()
    print('Image-Embeddings shape: ', image_embeddings.shape)
    multi_modal_embeddings = torch.cat(multi_modal_emb_result, dim=0).numpy()
    print('Multimodal-Embeddings shape: ', multi_modal_embeddings.shape)

    file = os.path.join(save_path, "Musical_Instruments.emb.text.npy")
    np.save(file, text_embeddings)
    file = os.path.join(save_path, "Musical_Instruments.emb.image.npy")
    np.save(file, image_embeddings)
    file = os.path.join(save_path, "Musical_Instruments.emb.multimodal.npy")
    np.save(file, multi_modal_embeddings)

    # tokenized_text = tokenizer(
    #     item_text,
    #     max_length=args.max_sent_len,
    #     truncation=True,
    #     return_tensors='pt',
    #     padding="longest").to(args.device)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_id', type=int, default=0, help='ID of running GPU')
    parser.add_argument('--plm_name', type=str, default='llama')
    parser.add_argument('--plm_checkpoint', type=str, default='')
    parser.add_argument('--max_sent_len', type=int, default=2048)
    parser.add_argument('--word_drop_ratio', type=float, default=-1,
                        help='word drop ratio, do not drop by default')
    parser.add_argument('--data_path', type=str, default='')
    parser.add_argument('--order_path', type=str, default='')
    parser.add_argument('--save_path', type=str, default='')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    device = set_device(args.gpu_id)
    args.device = device
    item_feature_list = preprocess_feature(args)

    model = MllamaForConditionalGeneration.from_pretrained(args.plm_checkpoint, torch_dtype=torch.float16)
    processor = AutoProcessor.from_pretrained(args.plm_checkpoint)
    # plm_tokenizer, plm_model = load_plm(args.plm_checkpoint)
    # if processor.pad_token_id is None:
    #     processor.pad_token_id = 0
    model = model.to(device)

    generate_item_embedding(
        args, item_feature_list, processor, model,
        word_drop_ratio=args.word_drop_ratio, save_path=args.save_path
    )
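
# Example invocation (a sketch only: the script filename, the paths, and the
# order-mapping filename below are placeholders; --plm_checkpoint is assumed to
# point to a local Llama-3.2-Vision checkpoint compatible with
# MllamaForConditionalGeneration):
#
#   python generate_emb.py \
#       --gpu_id 0 \
#       --plm_checkpoint /path/to/Llama-3.2-11B-Vision-Instruct \
#       --data_path ./data/Musical_Instruments.features.json \
#       --order_path ./data/Musical_Instruments.order.json \
#       --save_path ./data/
#
# The script writes Musical_Instruments.emb.text.npy, .emb.image.npy and
# .emb.multimodal.npy into --save_path.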