import argparse

from utils import *
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava import conversation as conversation_lib
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images
from llava.model import *
from transformers import AutoProcessor, AutoTokenizer, CLIPModel, LlavaForConditionalGeneration

from accelerate import Accelerator
from accelerate.utils import gather_object
import torch
import torch.nn.functional as F
from PIL import Image
import json
import os
import pickle
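
# Sentence-level beam search with a LLaVA model: for every question/image pair in the
# dataset, grow a tree of candidate responses one sentence at a time (diverse beam
# search per step) and pickle one tree per question id into the output directory.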

def get_done_ids(file_path):
    """Collect the 'id' field from every JSON line in a results file."""
    id_list = []

    # Read the file line by line; each line holds one JSON object.
    with open(file_path, 'r') as file:
        for line in file:
            data = json.loads(line)
            # Keep the value of the 'id' key if it is present.
            if 'id' in data:
                id_list.append(data['id'])
    return id_list


def save_object(obj, file_path):
    """Serialize an object to the given pickle file."""
    with open(file_path, 'wb') as file:  # open the file in binary write mode
        pickle.dump(obj, file)  # serialize the object with pickle


def get_prompts(inputs):
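    """Prefix each question with the image token and wrap it in the LLaVA v1 conversation template."""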
    input_questions = [DEFAULT_IMAGE_TOKEN + '\n' + input_question for input_question in inputs]

    prompts = []
    for input_q in input_questions:
        conv = conv_templates['v1'].copy()
        conv.append_message(conv.roles[0], input_q)
        conv.append_message(conv.roles[1], None)
        prompts.append(conv.get_prompt())
    return prompts

def get_file_names(directory):
    """Return the names of all files directly under `directory` (subdirectories excluded)."""
    file_names = []
    for item in os.listdir(directory):             # walk every entry in the directory
        full_path = os.path.join(directory, item)  # build the entry's full path
        if os.path.isfile(full_path):              # keep regular files only
            file_names.append(item)
    return file_names

class Node:
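    """A node in the sentence-level search tree: one generated segment, its cumulative score, its depth, and links to its parent and children."""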
    def __init__(self, text, score, depth, parent=None, is_final=False):
        self.text = text
        self.score = score
        self.depth = depth
        self.parent = parent
        self.children = []
        self.is_final = is_final

    def add_child(self, child):
        self.children.append(child)

def print_paths(node, path=None):
    """
    Recursively traverse the tree and print the path from the root to every leaf.

    Args:
        node: the current node.
        path: the list of formatted nodes on the path from the root to the current node.
    """
    if path is None:
        path = []
    # Add the current node to the path.
    path.append(f"{node.text} (Score: {node.score}, Final: {node.is_final})")

    # A leaf has no children, so print the accumulated path.
    if not node.children:
        print(" -> ".join(path))
    else:
        # Otherwise keep descending into the children.
        for child in node.children:
            print_paths(child, path.copy())  # copy so sibling branches do not share the same list
            
def sentence_level_beam_search_tree(qid, model, accelerator, processor, tokenizer, after_tokenizer, initial_text, images, sentence_end_id, max_length, max_new_tokens, num_beams, num_beam_group, token_level_beams, temperature, diversity_penalty):
    """
    Grow a tree of responses sentence by sentence using diverse beam search.

    Args:
        model: HF model exposing a generate() method.
        tokenizer: the base model's tokenizer.
        after_tokenizer: tokenizer of the fine-tuned checkpoint (checked for its EOS token).
        initial_text: the prompt that generation starts from.
        images: image(s) passed alongside the text.
        sentence_end_id: token id that marks the end of a sentence.
        max_length: maximum length of the generated text.
        max_new_tokens: maximum number of new tokens per generation step.
        num_beams: number of sentence-level beams kept after each step.
        temperature: sampling temperature.
    """
    # Initialization: the root node holds the prompt with score 0 at depth 0.
    root = Node(initial_text, 0, 0)
    active_nodes = [root]  # frontier of nodes still to be expanded; starts with the root only
    with torch.no_grad():
        while active_nodes:
            new_nodes = []

            for node in active_nodes:
                print(node.text)
                inputs = processor(text=node.text, images=images, return_tensors="pt").to(model.device)
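                # Diverse (group) beam search over the next sentence: generation stops at
                # sentence_end_id, so each candidate extends the path by roughly one sentence.
                # Note: HF generate requires token_level_beams to be divisible by num_beam_group.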

                with torch.inference_mode():
                    # outputs = model.module.generate(
                    outputs = model.generate(
                        **inputs,
                        num_beams=token_level_beams,
                        eos_token_id=sentence_end_id,
                        num_beam_groups=num_beam_group,
                        diversity_penalty=diversity_penalty,
                        # stopping_criteria=[stopping_criteria],
                        # temperature=temperature,
                        pad_token_id=tokenizer.pad_token_id, # different models may have different pad_token_id
                        num_return_sequences=token_level_beams,
                        max_new_tokens=max_new_tokens,
                        output_scores=True, # must be True
                        return_dict_in_generate=True, # must be True, because we need the text scores
                    )

                # Decode the generated continuations (outputs.sequences also contain the prompt tokens).
                gen_sequences = outputs.sequences[:, inputs.input_ids.shape[-1]:]
                gen_texts = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
                for j, (text, score) in enumerate(zip(gen_texts, outputs.sequences_scores)):
                    new_score = node.score + score.item()
                    # A branch is final if either tokenizer's EOS token appears in the new tokens
                    # or the decoded sequence has reached max_length.
                    is_final = (
                        tokenizer.eos_token_id in gen_sequences[j].tolist()
                        or after_tokenizer.eos_token_id in gen_sequences[j].tolist()
                        or len(tokenizer.decode(outputs.sequences[j])) >= max_length
                    )
                    new_node = Node(text, new_score, node.depth + 1, node, is_final)
                    node.add_child(new_node)

                    # Final branches stay in the tree but are not expanded further.
                    if not is_final:
                        new_nodes.append(new_node)
            # Prune the frontier: keep at most num_beams candidates, sorted by cumulative score.
            new_nodes.sort(key=lambda x: x.score, reverse=True)

            if len(new_nodes) < num_beams:
                active_nodes = new_nodes
            else:
                # Keep a mix of the best and the worst candidates
                # (num_beams/2 - 1 from the top, num_beams/2 from the bottom).
                active_nodes = new_nodes[:int(num_beams / 2) - 1] + new_nodes[-int(num_beams / 2):]

            if not active_nodes:
                break

    return [{'id': qid, 'tree': root}]


def load_and_merge_models(model_folder_path):
    # Accumulate parameters from every checkpoint shard into a single state dict.
    merged_model_state_dict = {}

    # Iterate over every shard file in the folder.
    for model_file in os.listdir(model_folder_path):
        if model_file.endswith('.bin'):  # only process .bin shards
            file_path = os.path.join(model_folder_path, model_file)

            # Load the shard onto CPU.
            model_state_dict = torch.load(file_path, map_location='cpu')
            # Merge the shard into the combined state dict.
            for key, value in model_state_dict.items():
                if key not in merged_model_state_dict:
                    merged_model_state_dict[key] = value
                else:
                    # Duplicate keys could be merged here (e.g. summed or averaged);
                    # shards are assumed to be disjoint, so duplicates are skipped.
                    pass
    return merged_model_state_dict

                
def eval_model(args):
    disable_torch_init()
    accelerator = Accelerator()

    # output_file = args.output_file
    model_path = args.model_path
    mapping_path = args.weight_mapping_path
    
    with open(mapping_path, 'r', encoding='utf-8') as f1:
        mapping_keys = json.load(f1)
    # model = LlavaForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16, device_map=4)
    tokenizer=AutoTokenizer.from_pretrained("llava-hf/llava-1.5-13b-hf", use_fast=False, padding_side='left')
    after_tokenizer=AutoTokenizer.from_pretrained(model_path)
    # tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    # tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    
    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-13b-hf")
    
    

    # processor.tokenizer=tokenizer
    # clip_model = CLIPModel.from_pretrained(eval_model_path, torch_dtype=torch.float16)
    # clip_processor = AutoProcessor.from_pretrained(eval_model_path)

    with open(args.dataset_path, 'r', encoding='utf8') as fp:
        my_dataset = json.load(fp)  # detail + reasoning

    llava_loader = get_llava_dataloader(my_dataset, 1)


    # lava_loader, processor = accelerator.prepare(
    #     llava_loader, processor
    # )
    
    model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-13b-hf", device_map='cpu', torch_dtype=torch.float16)
    state_dicts = load_and_merge_models(model_path)
    modified_weights = {}
    for old_key, value in state_dicts.items():
        new_key = mapping_keys.get(old_key, old_key)  # keep the original key if it has no entry in the mapping
        modified_weights[new_key] = value

    # Override the embedding and LM head with the base model's weights.
    modified_weights['language_model.model.embed_tokens.weight'] = model.state_dict()['language_model.model.embed_tokens.weight']
    modified_weights['language_model.lm_head.weight'] = model.state_dict()['language_model.lm_head.weight']

    # state_dicts['model'] = modified_weights
    model.load_state_dict(modified_weights, strict=True)
    # torch.cuda.empty_cache()
    # print(model)
    model.to(accelerator.device)
    
    llava_loader, processor= accelerator.prepare(
        llava_loader, processor
    )
    output_dir = args.output_dir
    havedone_list = get_file_names(output_dir)
    # TODO: add a check against havedone_list so already-finished question ids are skipped
    is_ref = args.is_ref
    if is_ref:
        ref_json = args.ref_path
        with open(ref_json, 'r') as file:
            data = json.load(file)
            id_list = [item['id'] for item in data]

    with torch.no_grad():
        for data in llava_loader:
            input_questions = data['input']
            # Strip any literal <image> placeholders; the image token is re-inserted by get_prompts.
            input_questions = [q.replace("<image>\n", "").replace("\n<image>", "").replace("<image>", "") for q in input_questions]
            image_paths = data['image']
            qid = data['question_ids']
            images = []

            save_name = str(qid[0]) + '.pkl'
            # if save_name in havedone_list:
            #     continue

            if is_ref and (str(qid[0]) not in id_list):
                print('pass:', str(qid[0]))
                continue

            save_path = os.path.join(output_dir, save_name)

            for image_path in image_paths:
                images.append(Image.open(os.path.join(args.images_dir, 'COCO_train2014_' + image_path)))

            prompts = get_prompts(input_questions)

            sentence_end_id = 29889  # sentence-end token id ("." in the LLaMA tokenizer)
            max_length = args.max_length
            token_level_beams = args.num_token_beams
            temperature = args.temperature
            max_new_tokens = args.max_new_tokens
            diversity_penalty = args.diversity_penalty
            num_beams = args.num_beams
            num_beam_group = args.num_beam_group
            
            result=gather_object(sentence_level_beam_search_tree(
                qid[0],
                model, 
                accelerator,
                processor,
                tokenizer,
                after_tokenizer,
                # clip_model, 
                # clip_processor, 
                prompts[0], 
                images[0],  
                sentence_end_id,  
                max_length, 
                max_new_tokens,
                num_beams,
                num_beam_group,
                token_level_beams,
                temperature,
                diversity_penalty
                ))
            # print(result)
            # print_paths(result[0]['tree'])
            # print(qid)
            # print(len(result))
            if accelerator.is_main_process:
                for obj in result:
                    # print(obj['id'])
                    r_save_path = os.path.join(output_dir, str(obj['id'])+'.pkl')
                    print(r_save_path)
                    save_object(obj, r_save_path)
                    
            torch.cuda.empty_cache()
            accelerator.wait_for_everyone()
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="/home/yiyangai/Projects/dongjie/StepbyStep/llava_13b_dpoed/llava_merged_dpo_13b_1epoch_1iteration")
    parser.add_argument("--dataset_path", type=str, default='/home/yiyangai/Projects/dongjie/LlaVa-Instruct-150k/LLaVA-Instruct-150K/my_dataset12k.json')
    parser.add_argument("--images_dir", type=str, default="../LlaVa-Instruct-150k/data/train2014")
    parser.add_argument("--output_dir", type=str, default="/home/yiyangai/Projects/dongjie/StepbyStep/Save_Folder/2024-5-9-after1dpo-13b")
    parser.add_argument("--temperature", type=float, default=0.3)
    parser.add_argument("--diversity_penalty", type=float, default=3.0)
    parser.add_argument("--num_beams", type=int, default=5)
    parser.add_argument("--num_beam_group", type=int, default=5)
    parser.add_argument("--num_token_beams", type=int, default=5)
    parser.add_argument("--max_length", type=int, default=1024)
    parser.add_argument("--max_new_tokens", type=int, default=70)
    parser.add_argument("--weight_mapping_path", type=str, default='/home/yiyangai/Projects/dongjie/5de42962e78a4485afa7a05120d78d88/key_mapping_13b.json')
    parser.add_argument("--is_ref", type=bool, default=False)
    parser.add_argument("--ref_path", type=str, default='/home/yiyangai/Projects/dongjie/StepbyStep/Save_Folder/4-26-dataset.json')
    args = parser.parse_args()
    
    eval_model(args)
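
# Example launch (a sketch only: the script filename and every path below are illustrative
# placeholders, and launching through `accelerate` for multi-GPU runs is an assumption):
#
#   accelerate launch build_sentence_tree.py \
#       --model-path /path/to/dpo_checkpoint_dir \
#       --dataset_path /path/to/my_dataset12k.json \
#       --images_dir /path/to/train2014 \
#       --output_dir /path/to/tree_output \
#       --weight_mapping_path /path/to/key_mapping_13b.json \
#       --num_beams 5 --num_token_beams 5 --max_new_tokens 70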