from utils import *
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava import conversation as conversation_lib
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images
from llava.model import *
from transformers import AutoProcessor, CLIPModel
from accelerate.utils import gather_object
import torch
from transformers import AutoTokenizer, LlavaForConditionalGeneration
import json
from accelerate import Accelerator
from PIL import Image
import torch.nn.functional as F
import argparse
import os
import pickle
def get_done_ids(file_path):
# initialize an empty list to store the id values
id_list = []
# open the file and read it line by line
with open(file_path, 'r') as file:
for line in file:
# parse each line from a JSON string into a dict
# print(line)
data = json.loads(line)
# append the value of the 'id' key to the list
if 'id' in data:
id_list.append(data['id'])
return id_list
def save_object(obj, file_path):
"""保存对象到指定的Pickle文件."""
with open(file_path, 'wb') as file: # open the file in binary write mode
pickle.dump(obj, file) # serialize the object to the file with pickle.dump
def get_prompts(inputs):
input_questions = [DEFAULT_IMAGE_TOKEN + '\n' + input_question for input_question in inputs]
prompts = []
for input_q in input_questions:
conv = conv_templates['v1'].copy()
conv.append_message(conv.roles[0], input_q)
conv.append_message(conv.roles[1], None)
prompts.append(conv.get_prompt())
return prompts
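# get_prompts prefixes each question with the <image> placeholder and wraps it in the
# LLaVA 'v1' conversation template, which renders roughly as
# "SYSTEM ... USER: <image>\n{question} ASSISTANT:" (exact wording depends on the llava version).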
def get_file_names(directory):
"""返回指定目录下所有文件的名称列表(不包括子目录)。"""
file_names = [] # list to collect the file names
for item in os.listdir(directory): # iterate over every entry in the directory
full_path = os.path.join(directory, item) # build the entry's full path
if os.path.isfile(full_path): # check whether the path is a regular file
file_names.append(item) # if it is a file, add its name to the list
return file_names
class Node:
def __init__(self, text, score, depth, parent=None, is_final=False):
self.text = text
self.score = score
self.depth = depth
self.parent = parent
self.children = []
self.is_final = is_final
def add_child(self, child):
self.children.append(child)
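# A Node is one sentence-level hypothesis in the search tree: `text` holds the prompt plus
# all sentences generated so far, `score` the accumulated sequence score, `depth` the number
# of expansion steps, and `children` the alternative next sentences expanded from this node.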
def print_paths(node, path=None):
"""
Recursively traverse the tree and print the path from the root node to every leaf node.
Args:
node: the current node.
path: the list describing the path from the root to the current node.
"""
# avoid a mutable default argument; start a fresh path on the first call
if path is None:
path = []
# append the current node to the path
path.append(f"{node.text} (Score: {node.score}, Final: {node.is_final})")
# if the current node is a leaf, print the path
if not node.children: # a leaf node has no children
print(" -> ".join(path))
else:
# otherwise, keep traversing the children
for child in node.children:
print_paths(child, path.copy()) # use path.copy() to avoid mutating the shared list
def sentence_level_beam_search_tree(qid, model, accelerator, processor, tokenizer, after_tokenizer, initial_text, images, sentence_end_id, max_length, max_new_tokens, num_beams, num_beam_group, token_level_beams, temperature, diversity_penalty):
"""
Args:
model: HF模型,包含一个generate方法。
tokenizer: 模型的分词器。
initial_text: 开始生成的初始文本。
images: 与文本一起使用的图像。
sentence_end_id: 句子结束标记的ID。
max_length: 生成文本的最大长度。
max_new_tokens: 每次生成的新token的最大数量。
num_beams: 在每一步使用的beam数量。
temperature: 生成温度。
"""
# 初始化
root = Node(initial_text, 0, 0)
active_nodes = [root] # list of active nodes; initially only the root
with torch.no_grad():
while active_nodes:
new_nodes = []
for node in active_nodes:
print(node.text)
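# the processor tokenizes the running text and converts the image into pixel_values;
# both tensors are moved to the model's device before generation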
inputs = processor(text=node.text, images=images, return_tensors="pt").to(model.device)
with torch.inference_mode():
# outputs = model.module.generate(
outputs = model.generate(
**inputs,
num_beams=token_level_beams,
eos_token_id=sentence_end_id,
num_beam_groups=num_beam_group,
diversity_penalty=diversity_penalty,
# stopping_criteria=[stopping_criteria],
# temperature=temperature,
pad_token_id=tokenizer.pad_token_id, # different models may have different pad_token_id
num_return_sequences=token_level_beams,
max_new_tokens=max_new_tokens,
output_scores=True, # must be True
return_dict_in_generate=True, # must be True, because we need the text scores
)
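# eos_token_id is set to the sentence-end token, so each generate call extends every beam
# by (at most) one more sentence; diverse beam search (num_beam_groups + diversity_penalty)
# pushes the returned continuations to differ from one another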
# decode the generated sequences
gen_sequences = outputs.sequences[:, inputs.input_ids.shape[-1]:]
gen_texts = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
for j, (text, score) in enumerate(zip(gen_texts, outputs.sequences_scores)):
new_score = node.score + score.item()
is_final = (tokenizer.eos_token_id in gen_sequences[j].tolist()) or (after_tokenizer.eos_token_id in gen_sequences[j].tolist() or len(tokenizer.decode(outputs.sequences[j]))>=max_length)
new_node = Node(text, new_score, node.depth + 1, node, is_final)
node.add_child(new_node)
if is_final: # check whether the end marker has been reached
pass
else:
new_nodes.append(new_node)
new_nodes.sort(key=lambda x: x.score, reverse=True)
if len(new_nodes)<num_beams:
active_nodes = new_nodes
else:
active_nodes = new_nodes[:int(num_beams/2)-1]+new_nodes[-int(num_beams/2):]
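# keep a mix of the highest- and lowest-scoring continuations so the tree records both
# strong and weak candidates (note: this retains int(num_beams/2)-1 top nodes plus
# int(num_beams/2) bottom nodes, i.e. num_beams-1 active nodes in total)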
if not active_nodes:
break
return [{'id': qid, 'tree': root}]
def load_and_merge_models(model_folder_path):
# initialize an empty dict to hold the merged model parameters
merged_model_state_dict = {}
# iterate over every model file in the folder
for model_file in os.listdir(model_folder_path):
if model_file.endswith('.bin'): # only process files ending in .bin
file_path = os.path.join(model_folder_path, model_file)
# load the checkpoint shard with torch.load
model_state_dict = torch.load(file_path, map_location='cpu')
# print(model_state_dict.keys())
# merge this shard's state dict into the combined one
for key, value in model_state_dict.items():
if key not in merged_model_state_dict:
merged_model_state_dict[key] = value
else:
# additional merge logic (e.g. summing or averaging the values) could be implemented here
pass
return merged_model_state_dict
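# Typical use: a fine-tuned checkpoint saved as sharded .bin files (e.g. the usual
# pytorch_model-0000x-of-0000y.bin naming) is merged into a single state dict whose keys
# can then be remapped onto the llava-hf model loaded in eval_model below.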
def eval_model(args):
disable_torch_init()
accelerator = Accelerator()
# output_file = args.output_file
model_path = args.model_path
mapping_path=args.weight_mapping_path
with open(mapping_path, 'r', encoding='utf-8') as f1:
mapping_keys = json.load(f1)
# model = LlavaForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16, device_map=4)
tokenizer=AutoTokenizer.from_pretrained("llava-hf/llava-1.5-13b-hf", use_fast=False, padding_side='left')
after_tokenizer=AutoTokenizer.from_pretrained(model_path)
# tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
# tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-13b-hf")
# processor.tokenizer=tokenizer
# clip_model = CLIPModel.from_pretrained(eval_model_path, torch_dtype=torch.float16)
# clip_processor = AutoProcessor.from_pretrained(eval_model_path)
with open(args.dataset_path, 'r', encoding='utf8') as fp:
my_dataset = json.load(fp) #detail+reasoning
llava_loader=get_llava_dataloader(my_dataset, 1)
# lava_loader, processor = accelerator.prepare(
# llava_loader, processor
# )
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-13b-hf", device_map='cpu', torch_dtype=torch.float16)
state_dicts = load_and_merge_models(model_path)
modified_weights = {}
for old_key, value in state_dicts.items():
new_key = mapping_keys.get(old_key, old_key) # keep the original key if it is not found in the mapping
modified_weights[new_key] = value
modified_weights['language_model.model.embed_tokens.weight'] = model.state_dict()['language_model.model.embed_tokens.weight']
modified_weights['language_model.lm_head.weight'] = model.state_dict()['language_model.lm_head.weight']
# state_dicts['model'] = modified_weights
model.load_state_dict(modified_weights, strict=True)
# torch.cuda.empty_cache()
# print(model)
model.to(accelerator.device)
llava_loader, processor= accelerator.prepare(
llava_loader, processor
)
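# accelerator.prepare shards the dataloader across processes, so each GPU handles a
# disjoint subset of questions; per-question results are merged back via gather_object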
output_dir=args.output_dir
havedone_list=get_file_names(output_dir)
# TODO: please add check here
is_ref=args.is_ref
if is_ref:
ref_json=args.ref_path
with open(ref_json, 'r') as file:
data = json.load(file)
id_list = [item['id'] for item in data]
with torch.no_grad():
for data in llava_loader:
input_questions = data['input']
input_questions = [q.replace("<image>\n", "").replace("\n<image>", "").replace("<image>", "") for q in input_questions]
image_paths=data['image']
qid=data['question_ids']
# print(qid)
images=[]
save_name=str(qid[0])+'.pkl'
# if save_name in havedone_list:
# continue
if is_ref and (str(qid[0]) not in id_list):
print('pass:', str(qid[0]))
continue
save_path = os.path.join(output_dir, save_name)
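# the dataset stores bare COCO file names, so the COCO train2014 prefix is prepended
# to build the actual image path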
for image_path in image_paths:
images.append(Image.open(os.path.join(args.images_dir,'COCO_train2014_'+image_path)))
prompts=get_prompts(input_questions)
sentence_end_id=29889
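# 29889 is assumed to be the token id of '.' in the LLaVA-1.5 (LLaMA) tokenizer, so each
# generate call stops at the end of a sentence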
max_length = args.max_length
token_level_beams = args.num_token_beams
temperature = args.temperature
max_new_tokens = args.max_new_tokens
diversity_penalty = args.diversity_penalty
num_beams=args.num_beams
num_beam_group=args.num_beam_group
result=gather_object(sentence_level_beam_search_tree(
qid[0],
model,
accelerator,
processor,
tokenizer,
after_tokenizer,
# clip_model,
# clip_processor,
prompts[0],
images[0],
sentence_end_id,
max_length,
max_new_tokens,
num_beams,
num_beam_group,
token_level_beams,
temperature,
diversity_penalty
))
# print(result)
# print_paths(result[0]['tree'])
# print(qid)
# print(len(result))
if accelerator.is_main_process:
for obj in result:
# print(obj['id'])
r_save_path = os.path.join(output_dir, str(obj['id'])+'.pkl')
print(r_save_path)
save_object(obj, r_save_path)
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="/home/yiyangai/Projects/dongjie/StepbyStep/llava_13b_dpoed/llava_merged_dpo_13b_1epoch_1iteration")
parser.add_argument("--dataset_path", type=str, default='/home/yiyangai/Projects/dongjie/LlaVa-Instruct-150k/LLaVA-Instruct-150K/my_dataset12k.json')
parser.add_argument("--images_dir", type=str, default="../LlaVa-Instruct-150k/data/train2014")
parser.add_argument("--output_dir", type=str, default="/home/yiyangai/Projects/dongjie/StepbyStep/Save_Folder/2024-5-9-after1dpo-13b")
parser.add_argument("--temperature", type=float, default=0.3)
parser.add_argument("--diversity_penalty", type=float, default=3.0)
parser.add_argument("--num_beams", type=int, default=5)
parser.add_argument("--num_beam_group", type=int, default=5)
parser.add_argument("--num_token_beams", type=int, default=5)
parser.add_argument("--max_length", type=int, default=1024)
parser.add_argument("--max_new_tokens", type=int, default=70)
parser.add_argument("--weight_mapping_path", type=str, default='/home/yiyangai/Projects/dongjie/5de42962e78a4485afa7a05120d78d88/key_mapping_13b.json')
parser.add_argument("--is_ref", type=bool, default=False)
parser.add_argument("--ref_path", type=str, default='/home/yiyangai/Projects/dongjie/StepbyStep/Save_Folder/4-26-dataset.json')
args = parser.parse_args()
eval_model(args)
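# Example invocation (a sketch; the script name and all paths below are illustrative
# placeholders, not values taken from this repository):
#   accelerate launch sentence_beam_search.py \
#       --model-path /path/to/merged_dpo_checkpoint \
#       --dataset_path /path/to/my_dataset12k.json \
#       --images_dir /path/to/coco/train2014 \
#       --output_dir /path/to/save_folder \
#       --weight_mapping_path /path/to/key_mapping_13b.json \
#       --num_beams 5 --num_token_beams 5 --num_beam_group 5 --max_new_tokens 70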