import os

# Point Hugging Face downloads at the mirror before any HF-backed import.
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

import argparse
import json
import shutil

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import deepspeed

from qwenva import tokenizer, processor, qwenva

images_file_path = './data/download/llava-v1.5-instruct/coco/train2017'

# Load the LLaVA-Instruct-150K conversation data.
with open('/root/autodl-tmp/LLaVA-Instruct-150K/llava_instruct_150k.json', 'r', encoding='utf-8') as f:
    chat_data = json.load(f)

# Token ids used when building training samples.
image_token = tokenizer.encode('<image>')[0]
pad_token = tokenizer.pad_token_id
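
# Each record pairs a COCO image with a multi-turn conversation. Roughly (per
# the LLaVA-Instruct release; shown here only for orientation):
#   {
#     "image": "000000033471.jpg",
#     "conversations": [
#       {"from": "human", "value": "<image>\nWhat is in the photo?"},
#       {"from": "gpt", "value": "..."}
#     ]
#   }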


def process_data(sample, max_len=8012):
    """Tokenize one conversation into fixed-length input_ids / labels / attention_mask."""
    conversations = sample['conversations']
    labels = []
    input_ids = []
    messages = []
    flag = 0         # index of the turn expected to contain the <image> token
    image_index = 0  # fallback position if no <image> token is found
    try:
        for index, item in enumerate(conversations):
            if item['from'] == 'human':
                old_input_ids = input_ids
                messages.append({'role': 'user', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(
                    messages,
                    add_generation_prompt=True
                )
                # Mask user tokens: only assistant tokens are supervised.
                labels += [-100] * (len(input_ids) - len(old_input_ids))
                if index == flag:
                    try:
                        image_index = input_ids.index(image_token)
                        labels[image_index] = image_token
                    except ValueError:
                        print("image token not found")
                    flag = index + 1
                continue
            elif item['from'] == 'gpt':
                old_input_ids = input_ids
                messages.append({'role': 'assistant', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(messages)
                labels += input_ids[len(old_input_ids):]
    except Exception as e:
        print(f"error in process_data_1: {e}")
        exit()

    try:
        # Truncate or right-pad everything to max_len.
        if len(input_ids) > max_len:
            input_ids = input_ids[:max_len]
            labels = labels[:max_len]
            attention_mask = [1] * len(input_ids)
        else:
            attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids))
            input_ids += [pad_token] * (max_len - len(input_ids))
            labels += [-100] * (max_len - len(labels))
    except Exception as e:
        print(f"error in process_data_2: {e}")
        exit()

    try:
        input_ids = torch.tensor(input_ids)
        attention_mask = torch.tensor(attention_mask)
        labels = torch.tensor(labels)
        image_index = torch.tensor(image_index)
    except Exception as e:
        print(f"error in tensor: {e}")
        exit()
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'image_idx': image_index
    }
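
# Quick sanity check (illustrative; exact values depend on the tokenizer):
#   out = process_data(chat_data[0], max_len=2048)
#   out['input_ids'].shape == out['labels'].shape == torch.Size([2048])
#   out['image_idx'] is the position of the <image> token in input_ids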


class MyDataset(Dataset):
    """Pairs each processed conversation with the pixel values of its image."""

    def __init__(self, images_file_path, data, max_len=1024):
        self.images_file_path = images_file_path
        self.data = data
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        output_ = process_data(self.data[index], max_len=self.max_len)
        img_path = os.path.join(self.images_file_path, self.data[index]['image'])
        try:
            img = Image.open(img_path).convert('RGB')
        except Exception:
            print(f"image {img_path} not found")
            # Mask the whole sample and substitute a blank placeholder so the
            # batch keeps a consistent shape (placeholder size is arbitrary;
            # the processor resizes it anyway).
            output_['labels'] = torch.tensor([-100] * self.max_len)
            img = Image.new('RGB', (336, 336))
        input_pixel = processor(images=img, return_tensors="pt")
        output_['input_pixel'] = input_pixel['pixel_values'].squeeze()
        return output_


dataset = MyDataset(images_file_path, chat_data, max_len=2048)
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)
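
# Each batch is a dict of stacked tensors; with these settings the shapes are
# input_ids / attention_mask / labels: [8, 2048], image_idx: [8], and
# input_pixel: [8, C, H, W] (exact size depends on the vision preprocessor).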

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
qwenva = qwenva.to(device)

# DeepSpeed builds the engine and optimizer from the JSON config file.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=qwenva,
    args=argparse.Namespace(),
    model_parameters=qwenva.parameters(),
    config_params="./deepspeed_config.json"
)
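
# The script expects ./deepspeed_config.json on disk. A minimal sketch of such
# a config (values are illustrative assumptions, not the actual settings):
#   {
#     "train_micro_batch_size_per_gpu": 8,
#     "gradient_accumulation_steps": 1,
#     "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
#     "fp16": {"enabled": true},
#     "zero_optimization": {"stage": 2}
#   }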

# _orig_mod indicates the model was wrapped with torch.compile; unfreeze the
# text embedding, LM head, and transformer blocks for fine-tuning.
for name, param in model_engine.module._orig_mod.text_embedding.named_parameters():
    param.requires_grad = True

for name, param in model_engine.module._orig_mod.lm_head.named_parameters():
    param.requires_grad = True

for name, param in model_engine.module._orig_mod.transformer.named_parameters():
    param.requires_grad = True

# Log which parameters will actually be trained.
for name, param in model_engine.module._orig_mod.named_parameters():
    if param.requires_grad:
        print(f"Layer: {name}, Requires Grad: {param.requires_grad}")

# CrossEntropyLoss ignores positions labeled -100 (the masked prompt and
# padding tokens) by default.
loss_fn = nn.CrossEntropyLoss()

accumulation_steps = 1


def train(model_engine, train_dataloader, loss_fn, device, epochs):
    model_engine.train()
    for epoch in range(epochs):
        with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as pbar:
            try:
                for batch_idx, batch in enumerate(train_dataloader):
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    input_pixel = batch['input_pixel'].to(device)
                    labels = batch['labels'].to(device)
                    image_idx = batch['image_idx'].to(device)

                    logits = model_engine(input_ids, attention_mask, input_pixel, image_idx)

                    # Subtracting the per-position max leaves softmax (and hence
                    # the loss) unchanged; kept as an explicit stabilization step.
                    max_logits = logits.max(dim=-1, keepdim=True)[0]
                    stable_logits = logits - max_logits
                    # Shift by one: position t predicts token t+1.
                    loss = loss_fn(
                        stable_logits[:, :-1, :].reshape(-1, stable_logits.shape[-1]),
                        labels[:, 1:].reshape(-1).clone()
                    )
                    model_engine.backward(loss)
                    if (batch_idx + 1) % accumulation_steps == 0:
                        model_engine.step()
                    pbar.update(1)
                    pbar.set_postfix(loss=loss.item())

                    # Periodically checkpoint, replacing the previous save.
                    if (batch_idx + 1) % 6000 == 0:
                        if os.path.exists("./best_model_2"):
                            shutil.rmtree("./best_model_2")
                        os.makedirs("./best_model_2")
                        model_engine.save_checkpoint("./best_model_2")
                        torch.save(model_engine.module.state_dict(), "./compiled_model_2.pth")
                        print(f"model saved at batch {batch_idx + 1}")
            except Exception as e:
                print(f"error in train: {e}")


train(model_engine, train_loader, loss_fn, device, epochs=2)