
CHIP2022 Medical Invoice OCR Key Information Extraction Task: Solution

Inference code for CogVLM2 fine-tuned with LoRA on the CHIP2022 medical invoice OCR key information extraction task.

Weight Download

Hugging Face: https://huggingface.co/starAIpark/CHIP2022_MedTable-MedInvoice_CogVLM/tree/main

Inference Code

Set LORA_PATH in the code below to the path of the downloaded weight file.
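If you prefer to fetch the weights programmatically, a minimal sketch using huggingface_hub is shown below; it assumes the LoRA file in the repository is named cogvlm_chip2022_lora.bin (the same name used as LORA_PATH in the script).

from huggingface_hub import hf_hub_download

# Download the LoRA checkpoint from the Hugging Face Hub; the returned local path
# can be used directly as LORA_PATH in the inference script below.
lora_path = hf_hub_download(
    repo_id = 'starAIpark/CHIP2022_MedTable-MedInvoice_CogVLM',
    filename = 'cogvlm_chip2022_lora.bin'
)
print(lora_path)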

To run inference on your own image, change img_path and prompt in the code below.

import os
import random
from PIL import Image

import torch
import torch.nn as nn
from torchvision import transforms

# easymodel is the author's training framework; build_loss is only needed for the
# training-time forward() of CogVLMSFT, not for generation.
from easymodel.loss import build_loss

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model


# Image-side preprocessing: load an image, resize/normalize it, and build the text
# prefix that reserves one placeholder token per vision position.
class LoadCogVLM(object):
    def __init__(
        self, imgpath_key, img_key,
        image_size, patch_size, q_key,
        vision_text
    ):
        self.img_key = img_key
        self.imgpath_key = imgpath_key
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )

        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_text = vision_text
        self.q_key = q_key


    def __call__(self, data):
        img_path = data[self.imgpath_key]
        image = Image.open(img_path).convert('RGB')
        image = self.transform(image)

        # CogVLM2 downsamples the patch grid by a factor of 2 in each direction and
        # adds 2 extra positions for the image boundary tokens, hence the "+ 2".
        vision_token_num = (self.image_size // self.patch_size // 2) * (self.image_size // self.patch_size // 2) + 2
        question = '<|begin_of_text|>' + self.vision_text * vision_token_num
        data[self.q_key] = question
        data[self.img_key] = image
        return data
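# With the configuration used below (image_size=1344, patch_size=14), the placeholder
# count works out to (1344 // 14 // 2) ** 2 + 2 = 48 * 48 + 2 = 2306, so the question
# prefix is '<|begin_of_text|>' followed by 2306 copies of '<|reserved_special_token_0|>'.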

# Text-side preprocessing: tokenize the conversation with the chosen chat template and
# build input_ids, loss_mask and (for CogVLM) token_type_ids.
class VQA(object):
    def __init__(
        self, tokenizer_type,
        conv_key, max_input_len, eos_text, pad_text,
        inputids_key, lossmask_key, only_question,template_type = None, **kwargs
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_type, trust_remote_code = True)
        self.conv_key = conv_key
        self.max_input_len = max_input_len
        self.eos_text = eos_text
        self.pad_text = pad_text
        self.inputids_key = inputids_key
        self.lossmask_key = lossmask_key
        self.only_question = only_question
        self.template_type = template_type
        if self.template_type == 'CogVLM':
            # the Llama-3 tokenizer prepends '<|begin_of_text|>', so the pad token id sits at index 1
            self.pad_id = self.tokenizer(self.pad_text, return_tensors='pt')['input_ids'][0][1]
        else:
            self.pad_id = self.tokenizer(self.pad_text, return_tensors='pt')['input_ids'][0][0]
        self.kwargs = kwargs

    def __call__(self, data):
        conv_list = data[self.conv_key]

        # Multi-turn conversations: a random number of leading turns is kept (relevant
        # for training); at inference the conversation holds a single question/answer pair.
        if self.template_type is None:
            question_text = '<|im_start|>system\nYou are a helpful assistant.<|im_end|>'
            answer_text = ''
        elif self.template_type == 'CogVLM':
            question_text = data['question']
            answer_text = ''

            
        conv_cnts = len(conv_list) // 2
        selected_cnts = random.randint(1, conv_cnts)
        for k in range(selected_cnts):
            if self.template_type is None:
                q_str = conv_list[k*2]['value']
                a_str = conv_list[k*2 + 1]['value']
                question_text += f'<|im_start|>user\n{q_str}<|im_end|>'
                if k == selected_cnts-1:
                    question_text += '<|im_start|>assistant\n'
                    answer_text = f'{a_str}<|im_end|>'
                else:
                    question_text += f'<|im_start|>assistant\n{a_str}<|im_end|>'
            elif self.template_type == 'CogVLM':
                q_str = conv_list[k*2]['value']
                a_str = conv_list[k*2 + 1]['value']
                question_text += f'Question: {q_str}'
                if k == selected_cnts-1:
                    question_text += 'Answer: '
                    answer_text = f'{a_str}' + self.eos_text
                else:
                    question_text += f'Answer: {a_str}'

        prompt_ids = self.tokenizer(question_text, return_tensors='pt')['input_ids'][0]
        if self.template_type == 'CogVLM':
            # drop the tokenizer's automatic BOS: question_text already starts with '<|begin_of_text|>'
            prompt_ids = prompt_ids[1:]
        
        if self.only_question:
            data[self.inputids_key] = prompt_ids
            if self.template_type == 'CogVLM':
                # token_type_ids: 1 at vision placeholder positions, 0 at text positions
                token_type_ids = torch.zeros(len(prompt_ids))
                ind = prompt_ids == self.tokenizer(self.kwargs['vision_text'], return_tensors = 'pt')['input_ids'][0][1]
                token_type_ids[ind] = 1
                data[self.kwargs['token_type_key']] = token_type_ids
            return data
        if self.template_type == 'CogVLM':
            answer_ids = self.tokenizer(answer_text, return_tensors = 'pt')['input_ids'][0][1:]
        else:
            answer_ids = self.tokenizer(answer_text, return_tensors = 'pt')['input_ids'][0]
        input_ids = torch.cat((prompt_ids, answer_ids))
        loss_mask = torch.cat((torch.zeros(len(prompt_ids)), torch.ones(len(answer_ids))))
        
        if self.template_type == 'CogVLM':
            token_type_ids = torch.zeros(len(input_ids))
            ind = input_ids == self.tokenizer(self.kwargs['vision_text'], return_tensors = 'pt')['input_ids'][0][1]
            token_type_ids[ind] = 1

        if len(input_ids) >= self.max_input_len:
            input_ids = input_ids[:self.max_input_len]
            loss_mask = loss_mask[:self.max_input_len]
            if self.template_type == 'CogVLM':
                token_type_ids = token_type_ids[:self.max_input_len]
        else:
            padding_len = self.max_input_len - len(input_ids)
            input_pad_id = torch.ones(padding_len, dtype = input_ids.dtype) * self.pad_id
            input_ids = torch.cat((input_ids, input_pad_id))

            lossmask_pad_id = torch.zeros(padding_len)
            loss_mask = torch.cat((loss_mask, lossmask_pad_id), axis = 0)

            if self.template_type == 'CogVLM':
                token_type_pad_ids = torch.zeros(padding_len)
                token_type_ids = torch.cat((token_type_ids, token_type_pad_ids), axis = 0)
        
        data[self.inputids_key] = input_ids
        data[self.lossmask_key] = loss_mask
        if self.template_type == 'CogVLM':
            data[self.kwargs['token_type_key']] = token_type_ids
        return data        
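# For the single-turn inference call at the bottom of this script (only_question=True),
# the text passed to the tokenizer is therefore:
#   '<|begin_of_text|>' + '<|reserved_special_token_0|>' * 2306 + 'Question: {prompt}Answer: '
# and the model is expected to continue with the answer followed by '<|end_of_text|>'.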

def load_lora(model, weight_path):
    # The checkpoint contains only the LoRA parameters, hence strict=False.
    state_dict = torch.load(weight_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    return model

# Thin wrapper around the CogVLM2 checkpoint that injects LoRA adapters and exposes a
# training-time forward() and an inference-time gen().
class CogVLMSFT(nn.Module):
    def __init__(
        self, model_path, lora_cfg, logger = None
    ):
        super().__init__()
        self.GPU_ID = int(os.environ.get('LOCAL_RANK') or 0)
        self.device = f"cuda:{self.GPU_ID}"

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype = torch.bfloat16,
            low_cpu_mem_usage = True, trust_remote_code = True,
            device_map = self.device
        )

        if lora_cfg is not None:
            lora_config = LoraConfig(**lora_cfg)
            self.model = get_peft_model(self.model, lora_config)
            self.model.print_trainable_parameters()

        self.loss = build_loss(dict(type = 'CrossEntropyLoss', loss_weight = 1.0))
        self.logger = logger


    # Training forward pass: compute the next-token cross-entropy loss.
    def forward(self, data):
        input_ids = data['input_ids']
        loss_mask = data['loss_mask']
        labels = input_ids[:,1:]
        loss_mask = loss_mask[:,1:]

        pred_logits = self.model(
            input_ids = data['input_ids'][:,:-1],
            images = data['imgs'][:, None, ...],
            token_type_ids = data['token_type_ids'][:,:-1]
        ).logits

        loss = self.loss(pred = pred_logits, target = labels, mask = loss_mask)
        output = dict()
        output['loss'] = loss
        output['pred'] = pred_logits
        output['label'] = labels
        output['loss_mask'] = loss_mask

        return output

    def gen(self, data):
        # Autoregressive generation for a single sample; only the newly generated
        # token ids are returned (the prompt is stripped off below).
        output = self.model.generate(
            input_ids = data['input_ids'][None, ...].to('cuda'),
            token_type_ids = data['token_type_ids'][None, ...].to('cuda'),
            images = data['imgs'][None, None, ...].to('cuda').to(torch.bfloat16),
            max_new_tokens = 2048,
            do_sample = True
        )
        return output[0][len(data['input_ids']):]
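# Note: generation uses sampling (do_sample=True); for more deterministic field
# extraction, greedy decoding (do_sample=False) is a possible alternative.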

# Load the base model and apply the LoRA weights
MODEL_PATH = 'THUDM/cogvlm2-llama3-chinese-chat-19B'
LORA_PATH = 'cogvlm_chip2022_lora.bin'
model_cfg = dict(
    model_path = MODEL_PATH,
    lora_cfg = dict(
        r = 128, target_modules = [
            'attention.query_key_value',
            'self_attn.language_expert_query_key_value',
            'self_attn.vision_expert_query_key_value',
        ],
        lora_alpha = 256, lora_dropout = 0.05,
        modules_to_save = None
    )
)
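# The LoRA adapters target the vision tower's attention QKV ('attention.query_key_value')
# and both the language-expert and vision-expert QKV projections in the CogVLM2 layers
# (module names as they appear in the cogvlm2-llama3-chinese-chat-19B checkpoint).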
model = CogVLMSFT(**model_cfg)
model = load_lora(model, LORA_PATH)

# Build the preprocessing pipelines
load_cfg = dict(imgpath_key = 'img_path', 
         image_size = 1344, patch_size = 14, img_key = 'imgs',
         vision_text = '<|reserved_special_token_0|>', q_key = 'question')
vqa_cfg = dict(type = 'VQA', tokenizer_type = MODEL_PATH,
         conv_key = 'conversations', max_input_len = 2560 + 512,
         eos_text = '<|end_of_text|>', pad_text = '<|reserved_special_token_0|>',
         inputids_key = 'input_ids', lossmask_key = 'loss_mask',
         only_question = True, template_type = 'CogVLM',
         vision_text = '<|reserved_special_token_0|>', token_type_key = 'token_type_ids')
pipeline_load = LoadCogVLM(**load_cfg)
pipeline_vqa = VQA(**vqa_cfg)



# Demo sample: ask for the value of the "票据代码" (invoice code) field; the model is
# instructed to answer '无' (none) if the field is not present in the image.
img_path = 'info_extra_demo.jpg'
prompt = '请给出图中票据代码的值,若图中不存在票据代码请输出\'无\''

conv = [
    dict(value = prompt),  # user question
    dict(value = '')       # answer left empty at inference time
]
data = dict(
    img_path = img_path,
    conversations = conv
)
data = pipeline_load(data)
data = pipeline_vqa(data)

print('img_path:\t', img_path)
print('prompt:\t', prompt)
with torch.no_grad():
    answer = model.gen(data)
    answer = pipeline_vqa.tokenizer.decode(answer)
    answer = answer.replace('<|end_of_text|>', '')


print('model output:\t', answer)
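
To query other images or other invoice fields, the steps above can be wrapped into a small helper. This is only a convenience sketch, not part of the original solution; it reuses the model and pipelines already built above, and the example prompt simply follows the demo wording.

def extract_field(img_path, prompt):
    # Reuses pipeline_load, pipeline_vqa and model defined above.
    data = dict(
        img_path = img_path,
        conversations = [dict(value = prompt), dict(value = '')]
    )
    data = pipeline_load(data)
    data = pipeline_vqa(data)
    with torch.no_grad():
        answer_ids = model.gen(data)
    answer = pipeline_vqa.tokenizer.decode(answer_ids)
    return answer.replace('<|end_of_text|>', '')

# Example usage (replace the path and prompt with your own):
# print(extract_field('info_extra_demo.jpg', prompt))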
