File size: 15,254 Bytes
d60b1f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334

import logging
import warnings
import os
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import transformers
import torch
import gc
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pack_padded_sequence


from calc_metrics import calculate_log_sum,calculate_log_last
import torch.nn.functional as F
import logging
import time
import traceback

import datetime
# Today's date as YYYY-MM-DD ("doday" = today); only referenced by the
# commented-out logging configuration below.
doday=datetime.datetime.now().strftime("%Y-%m-%d")
# Logging configuration (kept for reference; the active logging setup is done
# per model inside run_model_on_dic via logging.basicConfig(..., force=True)).
# Suffix that only appears in the commented-out log filename below.
extra_info='fill'

# logging.basicConfig(level=logging.INFO,filename='/wangbenyou/chenghao/fersh_bench/log/app.log', filemode='a', format='%(name)s - %(levelname)s - %(message)s')
# logging.basicConfig(level=logging.INFO,filename=f'../log/app_jieduan_{extra_info}{doday}_year.log', filemode='a', format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

import torch
import pdb
import json


# Candidate model checkpoint directories.
# NOTE(review): the __main__ entry point sets config['model_path'] directly
# (its `#paths[:1]` comment suggests this list was used there before).
paths=[
'/mntcephfs/data/med/fanyaxin/Qwen-7B-Chat',

] 



# Data sub-folder passed to get_text.get_text_from; alternatives are kept
# commented out so runs can be switched by (un)commenting a line.
# file_in_data_folder='2024-01-04_18'
# file_in_data_folder='2023-12-31'
file_in_data_folder='2023-12-27'
# file_in_data_folder='2020_100'
# file_in_data_folder='2020'
# file_in_data_folder='2014'
# file_in_data_folder='2017'
# file_in_data_folder='2019'
# file_in_data_folder='2019'
# file_in_data_folder='rephrase_MMLU'
# file_in_data_folder='mock_MMLU'

# mmlu_mock_concat

# Not arXiv and not per-year data, but rephrased MMLU (see the commented
# get_mmlu_rephrase_text calls below).
# Your corpus list
import get_text
# file_dic_list_strings=get_text.file_dic_list_strings
# Passed as limit= to get_text.get_text_from; presumably caps the number of
# lines read per data file (confirm in get_text).
limit_lines_per_file=10
# Loaded at import time. run_model_on_dic consumes this shape as
# {file_name: [text, ...]} (it iterates .items() and slices each text).
file_dic_list_strings=get_text.get_text_from(file_in_data_folder,limit=limit_lines_per_file)
# file_dic_list_strings=get_text.get_mmlu_rephrase_text(directory='/mntnfs/med_data5/chenghao/fresh_eval/data/mmlu_rephrase_concat/gpt-4-1106-preview/')
# file_dic_list_strings=get_text.get_mmlu_rephrase_text(directory='/mntnfs/med_data5/chenghao/fresh_eval/data/mmlu_mock_concat/gpt-4-1106-preview/')
        


# file_in_data_folder='2024-01-03'

def get_rwkv_model_tokenizer(model_name):
    """Load an RWKV checkpoint (cuda fp16) and the world-vocab tokenizer.

    RWKV_JIT_ON and RWKV_CUDA_ON are exported *before* ``rwkv.model`` is
    imported; the rwkv package presumably reads them at import time, so the
    import stays inside the function, after the assignment.

    Args:
        model_name: path to the RWKV weights, forwarded as ``model=``.

    Returns:
        ``(model, tokenizer)``; the tokenizer is the one owned by an
        ``rwkv.utils.PIPELINE`` built around the model.
    """
    os.environ.update({'RWKV_JIT_ON': '1', "RWKV_CUDA_ON": '1'})
    from rwkv.model import RWKV
    from rwkv.utils import PIPELINE

    rwkv_model = RWKV(model=model_name, strategy='cuda fp16')
    rwkv_pipeline = PIPELINE(rwkv_model, r"rwkv_vocab_v20230424")
    return rwkv_model, rwkv_pipeline.tokenizer

def get_mamba_model_tokenizer(model_name,
                              tokenizer_path="/mntcephfs/data/med/chenghao/models/gpt-neox-20b_tokenizer",
                              device="cuda"):
    """Load a Mamba LM-head model in float16 together with a tokenizer.

    The tokenizer is loaded separately from ``tokenizer_path``. Judging by
    the default path, it is a GPT-NeoX-20B tokenizer.

    Args:
        model_name: checkpoint passed to ``MambaLMHeadModel.from_pretrained``.
        tokenizer_path: tokenizer directory for ``AutoTokenizer``; the default
            keeps the previously hard-coded cluster path.
        device: device the model weights are placed on (default ``"cuda"``).
            NOTE(review): run_model_on_dic sends MAMBA inputs to CUDA whenever
            it is available, so a non-default device may need matching changes
            there.

    Returns:
        ``(model, tokenizer)``.
    """
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    model = MambaLMHeadModel.from_pretrained(model_name, device=device, dtype=torch.float16)
    return model, tokenizer


def get_HF_model_tokenizer(model_name):
    """Load a Hugging Face causal LM (in eval mode) and its tokenizer.

    Tokenizer: paths containing 'llama_hf_13b' use ``transformers.LlamaTokenizer``
    with ``unk_token="<unk>"``. Everything else uses ``AutoTokenizer`` with
    ``trust_remote_code=True``.

    Model: always loaded with ``device_map="auto"``. ``trust_remote_code=True``
    is passed for every checkpoint except those whose name contains 'zephyr'
    (case-insensitive).

    Args:
        model_name: local path or hub id of the checkpoint.

    Returns:
        ``(model, tokenizer)``.
    """
    # Import at the top of the function so the names are bound on every path.
    # A function-scope import makes them locals for the WHOLE function.
    # Importing them only in the non-llama branch left AutoModelForCausalLM
    # unbound (UnboundLocalError) whenever the 'llama_hf_13b' branch ran.
    from transformers import AutoTokenizer, AutoModelForCausalLM

    if 'llama_hf_13b' in model_name:
        tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, unk_token="<unk>")
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    if 'zephyr' in model_name.lower():
        # Zephyr checkpoints are loaded without trust_remote_code.
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto").eval()
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True).eval()
    return model, tokenizer

# Per-file line limit read by the __main__ entry point at the bottom of the
# file. This re-binds the same value (10) already assigned at module level
# next to the import-time corpus load.
limit_lines_per_file=10

def _load_model_and_tokenizer(model_type, model_name):
    """Load ``(model, tokenizer)`` for ``model_type`` in {'RWKV', 'MAMBA', 'HF'}."""
    if model_type == 'RWKV':
        return get_rwkv_model_tokenizer(model_name)
    if model_type == 'MAMBA':
        return get_mamba_model_tokenizer(model_name)
    if model_type == 'HF':
        model, tokenizer = get_HF_model_tokenizer(model_name)
        print(f'model device:{model.device}')
        print('[tokenizer.cls_token]', [tokenizer.cls_token])
        print('[tokenizer.sep_token]', [tokenizer.sep_token])
        return model, tokenizer
    raise Exception('model type not found')


def _is_hf_rwkv(model_type, model_name):
    """True for an HF-format RWKV checkpoint (model_type 'HF', name has 'HF' and 'RWKV')."""
    return model_type == 'HF' and 'HF' in model_name and 'RWKV' in model_name


def _tokenize_for_model(text, tokenizer, model_type, model_name):
    """Return the token ids of ``text`` as a flat list of ints.

    Dispatch is keyed on ``model_type``, the same key ``_sequence_loss`` uses.
    Tokenization and scoring therefore always agree on the backend. Returning
    a flat list for every backend lets truncation work uniformly.
    """
    if model_type == 'RWKV':
        return list(tokenizer.encode(text))
    if model_type == 'MAMBA' or _is_hf_rwkv(model_type, model_name):
        return list(tokenizer(text)['input_ids'])
    tokens = tokenizer.tokenize(text)
    # cls/sep are None for many causal-LM tokenizers; add them only when defined.
    if tokenizer.cls_token:
        tokens = [tokenizer.cls_token] + tokens
    if tokenizer.sep_token:
        tokens = tokens + [tokenizer.sep_token]
    return tokenizer.convert_tokens_to_ids(tokens)


def _sequence_loss(model, input_ids, model_type, model_name):
    """Unreduced next-token cross-entropy for one sequence.

    Args:
        input_ids: int64 tensor of shape (1, L) with L >= 2.

    Returns:
        1-D tensor with L - 1 losses; entry i scores token i + 1.
    """
    if model_type == 'RWKV':
        seq = input_ids[0]
        # NOTE(review): assumes full_output=True yields (L, vocab) logits,
        # matching the indexing used by the original code.
        logits = model.forward(seq.long(), None, full_output=True)[0]
        return torch.nn.functional.cross_entropy(
            logits[:-1, :].reshape(-1, logits.shape[-1]).to(torch.float32),
            seq[1:].to(logits.device),
            reduction='none')

    if model_type == 'MAMBA':
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        input_ids = input_ids.to(device)
        logits = model.forward(input_ids).logits  # (1, L, vocab)
    elif _is_hf_rwkv(model_type, model_name):
        input_ids = input_ids.to(model.device)
        logits = model.forward(input_ids).logits
    else:
        logits = model(input_ids).logits
        if 'chatglm3-6b' in model_name:
            logits = logits.float()

    # Targets must be on the logits' device. With device_map="auto" the output
    # head can be on a GPU while input_ids is still on the CPU.
    targets = input_ids[:, 1:].to(logits.device).reshape(-1)
    # reshape (not view): logits may be non-contiguous for some remote-code models.
    return torch.nn.functional.cross_entropy(
        logits[:, :-1, :].reshape(-1, logits.shape[-1]), targets, reduction='none')


def _score_file(model, tokenizer, model_type, model_name, file_name, corpus,
                max_str_len, max_sequence_length, detail_log_path):
    """Score every text of one file; append one JSON record per text.

    Returns:
        Summary dict (model/file name, number of sequences, average losses).
        The averages are None when nothing could be scored.
    """
    # Cut each text to max_str_len characters, tokenize it, then keep at most
    # max_sequence_length tokens.
    processed_sequences = [
        _tokenize_for_model(text[:max_str_len], tokenizer, model_type, model_name)[:max_sequence_length]
        for text in corpus
    ]

    total_loss = 0.0
    total_tokens = 0
    for enu, token_ids in tqdm(enumerate(processed_sequences)):
        if len(token_ids) < 2:
            # There is no next token to predict. An empty list would also
            # become a float tensor and crash the forward pass.
            logging.warning(f"{model_name} {file_name}: sequence {enu} has {len(token_ids)} token(s), skipped")
            continue
        input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
        with torch.no_grad():
            loss = _sequence_loss(model, input_ids, model_type, model_name)

        loss_sum = loss.sum().item()
        tmp_dic = {
            'model_name': model_name,
            'file_name': file_name,
            'lengths': len(token_ids),
            'length_str': len(corpus[enu][:max_str_len]),
            'loss_sum': loss_sum,
            'loss_mean': loss.mean().item(),
            'losses_list': loss.tolist(),
        }
        # Open/close per record: the logging FileHandler appends to this same
        # file, so every record must be fully on disk before logging writes.
        with open(detail_log_path, 'a') as f:
            json.dump(tmp_dic, f)
            f.write("\n")

        total_loss += loss_sum
        # NOTE(review): counts all L input tokens while only L - 1 positions
        # are scored. Kept as-is so averages stay comparable with earlier runs.
        total_tokens += input_ids.numel()

    return {
        "model_name": model_name,
        "file_name": file_name,
        "processed_sequences": len(processed_sequences),
        "average_loss": total_loss / total_tokens if total_tokens else None,
        "avg_str_loss": total_loss / len(processed_sequences) if processed_sequences else None,
    }


def run_model_on_dic(config):
    """Evaluate causal LMs on a corpus and write per-text and per-file loss logs.

    Config keys read:
        model_path: one checkpoint path (str) or an iterable of paths.
        model_type: 'RWKV', 'MAMBA' or 'HF'.
        file_dic_list_strings: {file_name: [text, ...]}.
        detail_log_path / extract_log_path: log files. A literal 'TOFILL' in
            them is replaced by the model's last path component.
        max_sequence_length: token cap per text.
        max_str_len: character cap per text, applied before tokenization.
        clear_log_first: optional, default True. Truncate each log file once
            per run before the first write to it.

    Side effects: sets ``config['clear_log_first']`` when missing. Overwrites
    ``config['detail_log_path']`` / ``config['extract_log_path']`` with the
    resolved paths of the model being processed. Re-configures root logging
    to the detail log for each model.

    Failures while loading or scoring a model are logged and recorded in the
    extract log as "<model> failed"; the remaining models still run.
    """
    config.setdefault('clear_log_first', True)
    logging.info("start up")
    paths = config['model_path']
    if isinstance(paths, str):
        # A bare string would otherwise be iterated character by character.
        paths = [paths]
    file_dic_list_strings = config['file_dic_list_strings']
    detail_log_base = config['detail_log_path']
    extract_log_base = config['extract_log_path']
    max_sequence_length = config['max_sequence_length']
    max_str_len = config['max_str_len']
    # Logs already truncated in this run. If the paths lack 'TOFILL', several
    # models share one file and must not wipe each other's results.
    cleared_logs = set()

    for model_name in tqdm(paths):
        model_name = model_name.strip()
        if not model_name:
            continue
        short_model_name = model_name.rstrip('/').split('/')[-1]
        detail_log_path = detail_log_base.replace('TOFILL', short_model_name)
        extract_log_path = extract_log_base.replace('TOFILL', short_model_name)
        config['detail_log_path'] = detail_log_path
        config['extract_log_path'] = extract_log_path

        if config['clear_log_first'] is True:
            to_clear = [p for p in (extract_log_path, detail_log_path) if p not in cleared_logs]
            for log_path in to_clear:
                with open(log_path, 'w'):
                    pass
                cleared_logs.add(log_path)
            if to_clear:
                print(f'\n log cleared! ')

        logging.basicConfig(level=logging.INFO, filename=detail_log_path, filemode='a',
                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', force=True)

        print()
        print('model_path', model_name)
        print(f'extract_log_path:{extract_log_path}\ndetail_log_path:{detail_log_path}')
        print()

        model = tokenizer = None
        try:
            model_type = config['model_type']
            model, tokenizer = _load_model_and_tokenizer(model_type, model_name)

            for file_name, corpus in file_dic_list_strings.items():
                logs = _score_file(model, tokenizer, model_type, model_name, file_name, corpus,
                                   max_str_len, max_sequence_length, detail_log_path)
                print(f"{file_name} total loss:", logs["average_loss"])
                with open(extract_log_path, 'a') as f:
                    json.dump(logs, f)
                    f.write("\n")
                logging.info(logs)

        except Exception as e:
            logging.error(f"{model_name}, error:{e} ,detail:{traceback.format_exc()}")
            with open(extract_log_path, 'a') as f:
                f.write(f"{model_name} failed \n")
            print(f"{model_name} failed for {e} detail:{traceback.format_exc()}\n")
        finally:
            # Drop this checkpoint before the next one is loaded. Otherwise
            # both models are referenced (and hold GPU memory) during the load.
            model = tokenizer = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

if __name__=='__main__':
    # Development entry point. Scores a single HF checkpoint on the corpus
    # folder selected by file_in_data_folder and writes logs under
    # .../log/test_fun_dev/.
    config={}
    print(file_in_data_folder)
    file_dic_list_strings=get_text.get_text_from(file_in_data_folder,limit=limit_lines_per_file)
    # Token cap per sequence, character cap per text, and the per-file line
    # limit. The limit is kept equal to the value used by the get_text call above.
    config['max_sequence_length'],config['max_str_len'],config['limit_lines_per_file']=2048,5000,limit_lines_per_file
    config['extract_log_path']=f'/mntnfs/med_data5/chenghao/fresh_eval/log/test_fun_dev/extract.log'
    config['detail_log_path']=f'/mntnfs/med_data5/chenghao/fresh_eval/log/test_fun_dev/detail.log'

    # run_model_on_dic iterates over model_path, so pass a list. A bare string
    # would be iterated character by character, one "model" per character.
    config['model_path']=['/mntnfs/med_data5/liangjuhao/models/TinyLlama-1.1B-Chat-v0.6']#paths[:1]
    config['batch']=16  # NOTE(review): not read by run_model_on_dic
    config['model_type']='HF'

    print('start',config['model_path'])
    config['file_dic_list_strings']=file_dic_list_strings
    run_model_on_dic(config)