|
from utils import *

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava import conversation as conversation_lib
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images
from llava.model import *

from transformers import AutoProcessor, AutoTokenizer, CLIPModel, LlavaForConditionalGeneration
from accelerate import Accelerator
from accelerate.utils import gather_object

from PIL import Image
import torch
import torch.nn.functional as F

import argparse
import json
import os
import pickle
|
|
def get_done_ids(file_path):
    """Collect the 'id' field from every JSON line in the given file."""
    id_list = []
    with open(file_path, 'r') as file:
        for line in file:
            data = json.loads(line)
            if 'id' in data:
                id_list.append(data['id'])
    return id_list
|
|
def save_object(obj, file_path):
    """Save an object to the given pickle file."""
    with open(file_path, 'wb') as file:
        pickle.dump(obj, file)
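

# A minimal counterpart to save_object for reading the saved .pkl results back;
# this helper is an illustrative sketch and is not called by the script itself.
def load_object(file_path):
    """Load an object from the given pickle file."""
    with open(file_path, 'rb') as file:
        return pickle.load(file)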
|
|
|
|
|
def get_prompts(inputs):
    """Wrap each question in the LLaVA 'v1' conversation template, prefixed with the image token."""
    input_questions = [DEFAULT_IMAGE_TOKEN + '\n' + input_question for input_question in inputs]
    prompts = []
    for input_q in input_questions:
        conv = conv_templates['v1'].copy()
        conv.append_message(conv.roles[0], input_q)
        conv.append_message(conv.roles[1], None)
        prompts.append(conv.get_prompt())
    return prompts
|
|
def get_file_names(directory):
    """Return the names of all files directly under the given directory (subdirectories excluded)."""
    file_names = []
    for item in os.listdir(directory):
        full_path = os.path.join(directory, item)
        if os.path.isfile(full_path):
            file_names.append(item)
    return file_names
|
|
class Node:
    """A node in the sentence-level beam search tree."""

    def __init__(self, text, score, depth, parent=None, is_final=False):
        self.text = text          # prompt plus all text generated so far
        self.score = score        # cumulative sequence score along this path
        self.depth = depth        # number of sentence-level steps from the root
        self.parent = parent
        self.children = []
        self.is_final = is_final  # True once EOS appeared or max_length was hit

    def add_child(self, child):
        self.children.append(child)
|
|
def print_paths(node, path=None):
    """
    Recursively traverse the tree and print the path from the root to every leaf.

    Args:
        node: The current node.
        path: The formatted entries from the root to the current node.
    """
    # Avoid a mutable default argument, which would leak state across calls.
    if path is None:
        path = []
    path.append(f"{node.text} (Score: {node.score}, Final: {node.is_final})")
    if not node.children:
        print(" -> ".join(path))
    else:
        for child in node.children:
            print_paths(child, path.copy())
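

# A small usage sketch for Node and print_paths; the texts and scores below are
# made up purely for illustration.
#
#     root = Node("prompt", 0.0, 0)
#     a = Node("prompt It is a cat.", -0.4, 1, parent=root)
#     b = Node("prompt It is a dog.", -1.1, 1, parent=root, is_final=True)
#     root.add_child(a)
#     root.add_child(b)
#     print_paths(root)  # one "text (Score: ..., Final: ...)" chain per leaf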
|
|
|
def sentence_level_beam_search_tree(qid, model, accelerator, processor, tokenizer, after_tokenizer, initial_text, images, sentence_end_id, max_length, max_new_tokens, num_beams, num_beam_group, token_level_beams, temperature, diversity_penalty):
    """
    Grow a tree of candidate responses one sentence at a time.

    Args:
        model: HF model exposing a generate method.
        tokenizer: The model's tokenizer.
        after_tokenizer: Tokenizer of the fine-tuned checkpoint; only its EOS id is used.
        initial_text: The prompt to start generation from.
        images: The image(s) used alongside the text.
        sentence_end_id: Token id treated as the sentence terminator.
        max_length: Maximum length of the generated text.
        max_new_tokens: Maximum number of new tokens generated per step.
        num_beams: Number of sentence-level beams kept after each step.
        num_beam_group: Number of groups for diverse beam search.
        token_level_beams: Number of token-level beams per generate call.
        temperature: Generation temperature (kept in the signature but not passed
            to generate, since diverse beam search does not sample).
        diversity_penalty: Penalty encouraging diversity across beam groups.

    Returns:
        A single-element list holding the question id and the root of the tree.
    """
    root = Node(initial_text, 0, 0)
    active_nodes = [root]
    with torch.no_grad():
        while active_nodes:
            new_nodes = []
            for node in active_nodes:
                print(node.text)
                inputs = processor(text=node.text, images=images, return_tensors="pt").to(model.device)

                with torch.inference_mode():
                    # One diverse beam search step that stops at the sentence
                    # terminator and returns every token-level beam.
                    outputs = model.generate(
                        **inputs,
                        num_beams=token_level_beams,
                        eos_token_id=sentence_end_id,
                        num_beam_groups=num_beam_group,
                        diversity_penalty=diversity_penalty,
                        pad_token_id=tokenizer.pad_token_id,
                        num_return_sequences=token_level_beams,
                        max_new_tokens=max_new_tokens,
                        output_scores=True,
                        return_dict_in_generate=True,
                    )

                # Only the newly generated tokens decide whether a path is done.
                gen_sequences = outputs.sequences[:, inputs.input_ids.shape[-1]:]
                gen_texts = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
                for j, (text, score) in enumerate(zip(gen_texts, outputs.sequences_scores)):
                    new_score = node.score + score.item()
                    is_final = (
                        tokenizer.eos_token_id in gen_sequences[j].tolist()
                        or after_tokenizer.eos_token_id in gen_sequences[j].tolist()
                        or len(tokenizer.decode(outputs.sequences[j])) >= max_length
                    )
                    new_node = Node(text, new_score, node.depth + 1, node, is_final)
                    node.add_child(new_node)
                    # Finished paths stay in the tree but are not expanded again.
                    if not is_final:
                        new_nodes.append(new_node)

            new_nodes.sort(key=lambda x: x.score, reverse=True)

            # Keep the top num_beams // 2 - 1 and the bottom num_beams // 2
            # candidates, so both strong and weak continuations survive.
            if len(new_nodes) < num_beams:
                active_nodes = new_nodes
            else:
                half = num_beams // 2
                active_nodes = new_nodes[:half - 1] + new_nodes[-half:]

    return [{'id': qid, 'tree': root}]
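

# A sketch of how the returned tree might be consumed downstream: walk to every
# leaf and return the texts along the highest-scoring path. This helper is
# illustrative only and is not called anywhere in this script.
def best_leaf_path(root):
    """Return the list of node texts from the root to the best-scoring leaf."""
    best = None

    def walk(node):
        nonlocal best
        if not node.children and (best is None or node.score > best.score):
            best = node
        for child in node.children:
            walk(child)

    walk(root)
    path = []
    node = best
    while node is not None:
        path.append(node.text)
        node = node.parent
    return list(reversed(path))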
|
|
|
|
|
def load_and_merge_models(model_folder_path):
    """Merge all .bin shards in a folder into one state dict; the first occurrence of each key wins."""
    merged_model_state_dict = {}
    for model_file in os.listdir(model_folder_path):
        if model_file.endswith('.bin'):
            file_path = os.path.join(model_folder_path, model_file)
            model_state_dict = torch.load(file_path, map_location='cpu')
            for key, value in model_state_dict.items():
                if key not in merged_model_state_dict:
                    merged_model_state_dict[key] = value
    return merged_model_state_dict
|
|
def eval_model(args):
    disable_torch_init()
    accelerator = Accelerator()

    model_path = args.model_path
    mapping_path = args.weight_mapping_path

    with open(mapping_path, 'r', encoding='utf-8') as f1:
        mapping_keys = json.load(f1)

    tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-13b-hf", use_fast=False, padding_side='left')
    after_tokenizer = AutoTokenizer.from_pretrained(model_path)
    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-13b-hf")

    with open(args.dataset_path, 'r', encoding='utf8') as fp:
        my_dataset = json.load(fp)

    llava_loader = get_llava_dataloader(my_dataset, 1)

    # Load the base HF checkpoint, then overwrite its weights with the merged
    # fine-tuned state dict, renaming keys via the provided mapping.
    model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-13b-hf", device_map='cpu', torch_dtype=torch.float16)
    state_dicts = load_and_merge_models(model_path)
    modified_weights = {}
    for old_key, value in state_dicts.items():
        new_key = mapping_keys.get(old_key, old_key)
        modified_weights[new_key] = value

    # Keep the base model's embedding and LM-head weights unchanged.
    modified_weights['language_model.model.embed_tokens.weight'] = model.state_dict()['language_model.model.embed_tokens.weight']
    modified_weights['language_model.lm_head.weight'] = model.state_dict()['language_model.lm_head.weight']

    model.load_state_dict(modified_weights, strict=True)
    model.to(accelerator.device)

    llava_loader, processor = accelerator.prepare(
        llava_loader, processor
    )
    output_dir = args.output_dir
    havedone_list = get_file_names(output_dir)  # names of already-saved results (not otherwise used below)

    # Optionally restrict generation to the question ids listed in a reference file.
    is_ref = args.is_ref
    if is_ref:
        ref_json = args.ref_path
        with open(ref_json, 'r') as file:
            data = json.load(file)
        id_list = [item['id'] for item in data]

    with torch.no_grad():
        for data in llava_loader:
            input_questions = data['input']
            input_questions = [q.replace("<image>\n", "").replace("\n<image>", "").replace("<image>", "") for q in input_questions]
            image_paths = data['image']
            qid = data['question_ids']

            images = []
            save_name = str(qid[0]) + '.pkl'

            if is_ref and (str(qid[0]) not in id_list):
                print('pass:', str(qid[0]))
                continue

            save_path = os.path.join(output_dir, save_name)

            for image_path in image_paths:
                images.append(Image.open(os.path.join(args.images_dir, 'COCO_train2014_' + image_path)))

            prompts = get_prompts(input_questions)

            sentence_end_id = 29889  # id of the "." token in the LLaMA vocabulary
            max_length = args.max_length
            token_level_beams = args.num_token_beams
            temperature = args.temperature
            max_new_tokens = args.max_new_tokens
            diversity_penalty = args.diversity_penalty
            num_beams = args.num_beams
            num_beam_group = args.num_beam_group

            result = gather_object(sentence_level_beam_search_tree(
                qid[0],
                model,
                accelerator,
                processor,
                tokenizer,
                after_tokenizer,
                prompts[0],
                images[0],
                sentence_end_id,
                max_length,
                max_new_tokens,
                num_beams,
                num_beam_group,
                token_level_beams,
                temperature,
                diversity_penalty
            ))

            if accelerator.is_main_process:
                for obj in result:
                    r_save_path = os.path.join(output_dir, str(obj['id']) + '.pkl')
                    print(r_save_path)
                    save_object(obj, r_save_path)

            torch.cuda.empty_cache()
            accelerator.wait_for_everyone()
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="/home/yiyangai/Projects/dongjie/StepbyStep/llava_13b_dpoed/llava_merged_dpo_13b_1epoch_1iteration")
    parser.add_argument("--dataset_path", type=str, default='/home/yiyangai/Projects/dongjie/LlaVa-Instruct-150k/LLaVA-Instruct-150K/my_dataset12k.json')
    parser.add_argument("--images_dir", type=str, default="../LlaVa-Instruct-150k/data/train2014")
    parser.add_argument("--output_dir", type=str, default="/home/yiyangai/Projects/dongjie/StepbyStep/Save_Folder/2024-5-9-after1dpo-13b")
    parser.add_argument("--temperature", type=float, default=0.3)
    parser.add_argument("--diversity_penalty", type=float, default=3.0)
    parser.add_argument("--num_beams", type=int, default=5)
    parser.add_argument("--num_beam_group", type=int, default=5)
    parser.add_argument("--num_token_beams", type=int, default=5)
    parser.add_argument("--max_length", type=int, default=1024)
    parser.add_argument("--max_new_tokens", type=int, default=70)
    parser.add_argument("--weight_mapping_path", type=str, default='/home/yiyangai/Projects/dongjie/5de42962e78a4485afa7a05120d78d88/key_mapping_13b.json')
    # argparse's type=bool treats any non-empty string as True, so use a flag instead.
    parser.add_argument("--is_ref", action='store_true')
    parser.add_argument("--ref_path", type=str, default='/home/yiyangai/Projects/dongjie/StepbyStep/Save_Folder/4-26-dataset.json')
    args = parser.parse_args()

    eval_model(args)