import cv2
import numpy as np
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
from config.configu import *
from models.model import *
from models.similarity import *
from sklearn.cluster import KMeans
from utils.utils import *
import warnings
from typing import Any, List, Optional, Tuple, Union
import torch
import random
import torch.utils.checkpoint
import transformers
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
                          LlamaTokenizer)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .configuration_internvl_chat import InternVLChatConfig
from .conversation import get_conv_template
from .modeling_intern_vit import InternVisionModel
from .modeling_internlm2 import InternLM2ForCausalLM

logger = logging.get_logger(__name__)

def coord_transform(box, return_4=True):
    """Convert a [[x1, y1], [x2, y2]] box to flat [x1, y1, x2, y2] form, or back."""
    if return_4:
        return [box[0][0], box[0][1], box[1][0], box[1][1]]
    else:
        return [[box[0], box[1]], [box[2], box[3]]]

def insert_zeros(input_ids, attention_mask, num_zeros=5):
    # Insert `num_zeros` zero tokens at random positions of a batch-size-1
    # sequence (the attention mask gets a matching 1), as a light perturbation.
    device = input_ids.device
    input_ids = input_ids.cpu().clone()
    attention_mask = attention_mask.cpu().clone()

    for _ in range(num_zeros):
        insert_pos = random.randint(0, input_ids.size(1))
        input_ids = torch.cat((input_ids[:, :insert_pos], torch.tensor([[0]]), input_ids[:, insert_pos:]), dim=1)
        attention_mask = torch.cat((attention_mask[:, :insert_pos], torch.tensor([[1]]), attention_mask[:, insert_pos:]), dim=1)

    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    return input_ids, attention_mask
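
# e.g. insert_zeros(torch.tensor([[5, 6, 7]]), torch.ones(1, 3, dtype=torch.long))
# returns a (1, 8) id tensor with five 0s spliced in, plus the matching mask.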

def add_Gaussian_noise(input_embeds, rate=1e-1):
    # Add noise drawn from N(mean, std) of the embeddings themselves, scaled by `rate`.
    device = input_embeds.device
    input_embeds = input_embeds.cpu().clone()

    mean = input_embeds.mean()
    std = input_embeds.std()
    noise = torch.randn(input_embeds.size()) * std + mean
    noisy_input_embeds = input_embeds + rate * noise

    noisy_input_embeds = noisy_input_embeds.to(device)
    noisy_input_embeds = noisy_input_embeds.to(torch.bfloat16)

    return noisy_input_embeds

def version_cmp(v1, v2, op='eq'):
    import operator

    from packaging import version
    op_func = getattr(operator, op)
    return op_func(version.parse(v1), version.parse(v2))
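
# e.g. version_cmp('4.37.2', '4.36.2', 'ge') -> True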

def most_frequent_rgb(image_array):
    """Find the most frequent RGB value in an image; used to fill masked regions."""
    pixels = image_array.reshape(-1, image_array.shape[-1])

    unique_pixels, counts = np.unique(pixels, axis=0, return_counts=True)

    most_frequent_index = np.argmax(counts)

    most_frequent_pixel = unique_pixels[most_frequent_index]
    frequency = counts[most_frequent_index]
    return most_frequent_pixel, frequency

def most_frequent_rgb_fast(image_array):
    """Quickly find the most frequent RGB value in an image (frequency not returned)."""
    # Widen from uint8 before packing; otherwise the arithmetic below overflows.
    flattened = image_array.reshape(-1, 3).astype(np.int64)
    rgb_ints = flattened[:, 0] * 256**2 + flattened[:, 1] * 256 + flattened[:, 2]

    counts = np.bincount(rgb_ints)

    most_frequent_index = np.argmax(counts)

    r = (most_frequent_index // 256**2) % 256
    g = (most_frequent_index // 256) % 256
    b = most_frequent_index % 256

    return (r, g, b)
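
# Each pixel is packed into one base-256 integer (r*65536 + g*256 + b) so that
# np.bincount can count every 24-bit color in a single pass; e.g. (255, 0, 0)
# packs to 16711680 and unpacks back to (255, 0, 0) via the div/mod above.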

def mask_area(image_array, coords, color):
    """Mask a series of boxes in an image by filling each with `color`."""
    for coord in coords:
        x1, y1, x2, y2 = coord
        image_array[y1:y2, x1:x2] = color

    return image_array


class InternVLChatModel(PreTrainedModel):
    config_class = InternVLChatConfig
    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer']

    def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
        super().__init__(config)

        assert version_cmp(transformers.__version__, '4.36.2', 'ge')
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        # Auxiliary modules for calligraphy alignment: per-token (mu, sigma)
        # statistics for de-normalizing quantized embeddings, the normalized
        # token-embedding table, a perceiver resampler, and an orderformer
        # that sorts detected character boxes into reading order.
        self.mu_sigma = torch.load(NORM_PARAMS_PATH)['weight']
        self.mu = self.mu_sigma[:, 0].reshape((-1, 1))
        self.sigma = self.mu_sigma[:, 1].reshape((-1, 1))
        self.normed_emb, self.mu_sigma = self.load_normed_tok_embeddings(load_checkboard=True)
        self.resampler = load_perceiver_resampler_2(PERCEIVER_CHECKPOINT, num_layers=4)
        self.sorter = load_orderformer(ORDERFORMER_CHECKPOINT)

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')

        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = InternVisionModel(config.vision_config)
        if language_model is not None:
            self.language_model = language_model
        else:
            if config.llm_config.architectures[0] == 'LlamaForCausalLM':
                self.language_model = LlamaForCausalLM(config.llm_config)
            elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM':
                self.language_model = InternLM2ForCausalLM(config.llm_config)
            else:
                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message

    def load_normed_tok_embeddings(self, vocab_size=92553, llm_hidden_size=4096, load_checkboard=False):
        # Each row of this table is a token embedding standardized by its own
        # (mu, sigma); the statistics in NORM_PARAMS_PATH undo that normalization.
        tok_embeddings = nn.Embedding(vocab_size, llm_hidden_size, padding_idx=2).to_empty(device=torch.device('cuda')).to(torch.bfloat16)
        tok_embeddings.load_state_dict(torch.load(NORM_TOK_EMBEDDING_PATH, weights_only=True, map_location="cpu"))
        if load_checkboard:
            checkboard_norm = torch.load(NORM_PARAMS_PATH)
            return tok_embeddings, checkboard_norm['weight']
        return tok_embeddings

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        image_flags = image_flags.squeeze(-1)
        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        vit_embeds = self.extract_feature(pixel_values)
        vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        if torch.distributed.get_rank() == 0:
            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

        # Overwrite every <IMG_CONTEXT> position with a ViT embedding.
        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        try:
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
        except Exception as e:
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
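
    # Note: `input_embeds[selected] * 0.0 + vit_embeds` above (rather than a
    # plain assignment) keeps the text-embedding activations in the autograd
    # graph, so gradients still reach the embedding table during training.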

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                          'which results in a transposed image.')
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x
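
    # With the default scale_factor=0.5, a (n, 32, 32, c) feature grid becomes
    # (n, 16, 16, 4c): spatial resolution halves while channels quadruple.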

    def extract_feature(self, pixel_values):
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        # Drop the CLS token; keep only patch tokens.
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])

        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds
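
    # extract_feature: ViT patch features -> pixel_shuffle downsampling -> mlp1
    # projection to the LLM hidden size, i.e. self.num_image_token vectors per tile.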

    @torch.no_grad()
    def calli_align(self, img_path, detect_model, drop_zero=False, use_hard_vector_quant=False, save_path=None, verbose=False):
        def dynamic_read(img_path, mode='c'):
            # Load from a path (cv2 first, PIL fallback) or accept a PIL image directly.
            if isinstance(img_path, str):
                img = cv2.imread(img_path)
                if img is None:
                    try:
                        img = Image.open(img_path).convert("RGB")
                        img = np.array(img)
                    except Exception:
                        raise ValueError(f"Image at path {img_path} could not be loaded.")
            elif isinstance(img_path, Image.Image):
                img = np.array(img_path)
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            else:
                raise TypeError(f"Unsupported image type: {type(img_path)}")
            if mode == 'i':
                img = Image.fromarray(img).convert("RGB")
            return img

        import time

        def iterative_only_boxes(model, jpg_path):
            # Detect character boxes repeatedly: each round masks what was found
            # (with the most frequent color) so occluded characters surface later.
            image = dynamic_read(jpg_path)
            image_array = np.array(image)
            h, w, channels = image.shape
            boxes = []

            color = most_frequent_rgb_fast(image_array)
            while True:
                res = model(image_array, verbose=False)[0]
                to_be_masked = []
                for box in res.boxes:
                    xyxy = box.xyxy.squeeze().tolist()
                    x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
                    to_be_masked.append([x1, y1, x2, y2])
                boxes.extend(to_be_masked)
                if len(to_be_masked) > 250:
                    image_array = mask_area(image_array, to_be_masked, color)
                else:
                    break

            # Clip boxes to the image and convert to [[x1, y1], [x2, y2]] form.
            boxes = [[[max(item[0], 0), max(item[1], 0)], [min(item[2], w), min(item[3], h)]] for item in boxes]

            # Deduplicate near-identical detections (IoU > 0.8).
            i = 0
            length = len(boxes)
            while i < length:
                j = 0
                main_box = boxes[i]
                while j < length:
                    if i == j:
                        j += 1
                        continue
                    iou = calculate_iou(coord_transform(main_box), coord_transform(boxes[j]))
                    if iou > 0.8:
                        rm = boxes[j]
                        boxes.remove(rm)
                        if j < i:
                            i -= 1
                        length -= 1
                        j -= 1
                    j += 1
                i += 1

            return boxes
        def char2col_with_kmeans(jpg_path, boxes, verbose=False):

            def kmeans_boxes(bounding_boxes):
                # Split boxes into two groups by area with 2-means: large main
                # characters vs. small annotation (luokuan) characters.
                areas = [(box[1][0] - box[0][0]) * (box[1][1] - box[0][1]) for box in bounding_boxes]
                areas = np.array(areas).reshape(-1, 1)

                kmeans = KMeans(n_clusters=2, random_state=0).fit(areas)
                labels = kmeans.labels_

                group_0 = []
                group_1 = []
                for i, label in enumerate(labels):
                    if label == 0:
                        group_0.append(bounding_boxes[i])
                    else:
                        group_1.append(bounding_boxes[i])

                # Sort each group by box width, widest first.
                group_0 = sorted(group_0, key=lambda x: (x[1][0] - x[0][0]), reverse=True)
                group_1 = sorted(group_1, key=lambda x: (x[1][0] - x[0][0]), reverse=True)

                if (group_1[0][1][0] - group_1[0][0][0]) > (group_0[0][1][0] - group_0[0][0][0]):
                    # group_1 holds the larger boxes; promote group_0 members that
                    # are wide enough, tall enough, or whose area ratios suggest
                    # they are main characters.
                    g1_hs = np.array([x[1][1] - x[0][1] for x in group_1]).mean()
                    thr1 = 1 * (group_1[-1][1][0] - group_1[-1][0][0])
                    thr2 = 0.8 * g1_hs

                    new_0 = []
                    for ele in group_0:
                        box_area = (ele[1][0] - ele[0][0]) * (ele[1][1] - ele[0][1])
                        if (ele[1][0] - ele[0][0]) >= thr1 or (ele[1][1] - ele[0][1]) >= thr2 or (areas.min() / box_area <= 1 / 5 and areas.mean() / box_area <= 1.3):
                            group_1.append(ele)
                        else:
                            new_0.append(ele)

                    grouped_luokuan = merge_boxes(new_0.copy())

                    # Boxes that merge_boxes left unmerged are reassigned to the other group.
                    final_ = []
                    for ele in new_0:
                        if ele in grouped_luokuan:
                            group_1.append(ele)
                        else:
                            final_.append(ele)
                    group_0 = final_

                elif (group_0[0][1][0] - group_0[0][0][0]) > (group_1[0][1][0] - group_1[0][0][0]):
                    # Symmetric case: group_0 holds the larger boxes.
                    g0_hs = np.array([x[1][1] - x[0][1] for x in group_0]).mean()
                    thr1 = 1 * (group_0[-1][1][0] - group_0[-1][0][0])
                    thr2 = 0.8 * g0_hs

                    new_1 = []
                    for ele in group_1:
                        box_area = (ele[1][0] - ele[0][0]) * (ele[1][1] - ele[0][1])
                        if (ele[1][0] - ele[0][0]) >= thr1 or (ele[1][1] - ele[0][1]) >= thr2 or (areas.min() / box_area <= 1 / 5 and areas.mean() / box_area <= 1.3):
                            group_0.append(ele)
                        else:
                            new_1.append(ele)

                    grouped_luokuan = merge_boxes(new_1.copy())

                    final_ = []
                    for ele in new_1:
                        if ele in grouped_luokuan:
                            group_0.append(ele)
                        else:
                            final_.append(ele)
                    group_1 = final_

                return group_0, group_1
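
            # kmeans_boxes returns (group_0, group_1); the group holding the
            # wider boxes absorbs any candidate from the other group that is
            # wide or tall enough, or whose area is close to the group mean.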

            def toint(lst):
                if len(lst) == 2:
                    return [[int(lst[0][0]), int(lst[0][1])], [int(lst[1][0]), int(lst[1][1])]]
                else:
                    return [int(lst[0]), int(lst[1]), int(lst[2]), int(lst[3])]

            img = dynamic_read(jpg_path)
            h, w, channels = img.shape

            normalized_boxes = [[[item[0][0] / w, item[0][1] / h], [item[1][0] / w, item[1][1] / h]] for item in boxes]
            S = np.array([(item[0][0] - item[1][0]) * (item[0][1] - item[1][1]) for item in normalized_boxes])

            # If box areas vary a lot (high coefficient of variation and a tiny
            # minimum), assume two size classes and split with k-means before
            # merging boxes into columns; otherwise merge directly.
            coef_var = np.std(S) / np.mean(S)
            boxes2class = None
            col2class = None

            if coef_var > 0.66 and S.min() / S.mean() <= 1 / 8:
                boxes1, boxes2 = kmeans_boxes(normalized_boxes)

                boxes1 = [[[item[0][0] * w, item[0][1] * h], [item[1][0] * w, item[1][1] * h]] for item in boxes1]
                boxes2 = [[[item[0][0] * w, item[0][1] * h], [item[1][0] * w, item[1][1] * h]] for item in boxes2]
                columns1 = merge_boxes(boxes1.copy())
                columns2 = merge_boxes(boxes2.copy())

                columns = columns1 + columns2
                boxes2class = {1: [toint(item) for item in boxes1], 2: [toint(item) for item in boxes2]}
                col2class = {1: [toint(item) for item in columns1], 2: [toint(item) for item in columns2]}
            else:
                columns = merge_boxes(boxes.copy())

            results = {"imageHeight": h, "imageWidth": w, "shapes": [{"points": toint(col)} for col in columns],
                       "boxes2class": boxes2class, "col2class": col2class}

            return results

        def sort_boxes(jpg, detector, model, thres=0.8):
            # Detect boxes, group them into columns, order the columns with the
            # orderformer, then sort characters within each column top-to-bottom.
            boxes = iterative_only_boxes(detector, jpg)
            data = char2col_with_kmeans(jpg, boxes, verbose=False)
            res = model.predict(data, jpg)
            final_results = []
            for idx, col in res.items():
                lst = []
                for item in boxes:
                    ratio = calculate_iou(col, [item[0][0], item[0][1], item[1][0], item[1][1]], mini=True)
                    if ratio >= thres:
                        lst.append([item[0][0], item[0][1], item[1][0], item[1][1]])
                lst = sorted(lst, key=lambda item: (item[1] + item[3]) / 2)
                final_results.extend(lst)

            return final_results

        if img_path is None:
            return None, None

        st = time.time()
        boxes = sort_boxes(img_path, detect_model, self.sorter)
        ed = time.time()
        if verbose:
            print(f"YOLO+Orderformer {ed-st:.2f}s")
        if save_path is not None:
            frame = dynamic_read(img_path)
            name = img_path.split("/")[-1]
            for i, box in enumerate(boxes):
                xyxy = box
                x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
                colo = (255, 0, 0)
                cv2.rectangle(frame, (x1, y1), (x2, y2), thickness=2, color=colo, lineType=cv2.LINE_AA)
                # Label each box with its reading order.
                cv2.putText(frame, str(i + 1), ((x1 + x2) // 2, (y1 + y2) // 2), cv2.FONT_HERSHEY_SIMPLEX, 1.5, colo, thickness=2, lineType=cv2.LINE_AA)
            print(save_path + "ordered_result_" + name)
            cv2.imwrite(save_path + "ordered_result_" + name, frame)

        st = time.time()
        pixel_values = []
        img = np.array(dynamic_read(img_path, mode='i').convert("RGB"))

        for xyxy in boxes:
            x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
            sub_img = Image.fromarray(img[y1:y2, x1:x2])
            pixel_values.append(load_image_2(sub_img).to(torch.bfloat16).cuda())
        ed1 = time.time()
        results = torch.cat(pixel_values)

        image_embeddings = self.extract_feature(results)
        ed2 = time.time()
        output = self.resampler(image_embeddings)
        ed3 = time.time()

        # Vector-quantize the resampled embeddings against the normalized
        # token-embedding table by cosine similarity.
        outs = vq_cos_sim(self.normed_emb, output, use_hard_vector_quant)
        ed4 = time.time()
        if verbose:
            print(f"Get pixel values {ed1-st:.2f}s")
            print(f"extract feat {ed2-ed1:.2f}s")
            print(f"Resampler forward {ed3-ed2:.2f}s")
            print(f"vq cos sim {ed4-ed3:.2f}s")
        if use_hard_vector_quant:
            indices, cos_sim_values = outs
            thresh = 0.5
        else:
            indices = outs

        if use_hard_vector_quant:
            print("Dynamic vector quantization...")
            # Replace low-similarity embeddings (cosine <= thresh) with their
            # nearest normalized token embedding; keep the rest as-is.
            below_mask = (cos_sim_values <= thresh).to(torch.bfloat16).unsqueeze(-1)
            output = output * (1 - below_mask) + self.normed_emb.weight[indices] * below_mask

        flattened_output = output.view(-1, output.shape[-1])
        flattened_indices = indices.view(-1)

        if drop_zero:
            filtered_indices = flattened_indices[flattened_indices != 0]
            filtered_output = flattened_output[flattened_indices != 0]

            # De-normalize back to the original embedding space (x * sigma + mu),
            # with (mu, sigma) looked up per quantized token index.
            sigma_flat = self.sigma[filtered_indices]
            mu_flat = self.mu[filtered_indices]
            sigma_flat = sigma_flat.expand(-1, filtered_output.shape[-1])
            mu_flat = mu_flat.expand(-1, filtered_output.shape[-1])
            back_to_origin_flat = filtered_output * sigma_flat + mu_flat
        else:
            sigma_flat = self.sigma[flattened_indices]
            mu_flat = self.mu[flattened_indices]
            sigma_flat = sigma_flat.expand(-1, flattened_output.shape[-1])
            mu_flat = mu_flat.expand(-1, flattened_output.shape[-1])
            back_to_origin_flat = flattened_output * sigma_flat + mu_flat

        return back_to_origin_flat, indices
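
    # calli_align returns (embeddings, indices): one de-normalized embedding per
    # detected character, in reading order, plus its quantized token index.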

    def find_coordinates(self, text):
        import re
        # Extract all integers from the text, e.g. a "x1,x2,y1,y2" region spec.
        numbers = re.findall(r'\d+', text)
        numbers = [int(num) for num in numbers]
        return numbers

    def chat_ocr(self, tokenizer, detect_model, img_path, questions, generation_config, num_patches_list=None,
                 history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                 IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', ALIGNED_TOKEN="[UNUSED_TOKEN_140]", verbose=False, image_counts=None, batch=False,
                 use_p=True, drop_zero=False, hard_vq=False, repetition_penalty=1.5, region_wise=False):

        pixel_values = None
        if img_path is not None:
            try:
                if region_wise:
                    # Crop the region named in the question and OCR only that region.
                    img = np.array(Image.open(img_path).convert("RGB"))
                    coord = self.find_coordinates(questions)
                    x1, x2, y1, y2 = coord
                    sub_img = Image.fromarray(img[y1:y2, x1:x2])
                    questions = "输出图片中所有文字:"  # prompt: "Output all text in the image:"
                    pixel_values = load_image(sub_img).to(torch.bfloat16).to(torch.device("cuda"))
                else:
                    pixel_values = load_image(img_path).to(torch.bfloat16).to(torch.device("cuda"))
            except Exception:
                raise FileNotFoundError(f"Could not load image: {img_path}")
        if use_p:
            import time
            st = time.time()
            if region_wise:
                try:
                    out_tokens, indices = self.calli_align(sub_img, detect_model, drop_zero=drop_zero, use_hard_vector_quant=hard_vq, verbose=verbose)
                except Exception:
                    return "检测失败"  # "detection failed"
            else:
                out_tokens, indices = self.calli_align(img_path, detect_model, drop_zero=drop_zero, use_hard_vector_quant=hard_vq, verbose=verbose)
            if verbose:
                print(f"Calli Align: {time.time()-st:.2f}s")

        question = questions
        if pixel_values is not None and '<image>' not in questions:
            question = '<image>\n' + questions

        # Append one aligned-token placeholder per calligraphy embedding.
        if history is None and use_p and ALIGNED_TOKEN not in question:
            question = question + ALIGNED_TOKEN * out_tokens.shape[0]
        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

        history = [] if history is None else history
        for (old_question, old_answer) in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)

        model_inputs = tokenizer(query, return_tensors='pt')

        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()

        generation_config['eos_token_id'] = eos_token_id

        if use_p:
            generation_output = self.generate_ocr(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                reference_embeds=out_tokens,
                repetition_penalty=repetition_penalty,
                **generation_config
            )
        else:
            generation_output = self.generate_ocr(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                repetition_penalty=repetition_penalty,
                **generation_config
            )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        response = response.split(template.sep)[0].strip()
        history.append((question, response))
        if return_history:
            return response, history
        else:
            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')

            return response
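
    # Sketch of a typical call (detector, tokenizer and generation config are
    # set up elsewhere; `yolo_detector` and the paths here are placeholders):
    #   response = model.chat_ocr(tokenizer, yolo_detector, 'page.jpg',
    #                             '输出图片中所有文字:', dict(max_new_tokens=512))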

    def dynamic_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                     history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                     IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None, batch=False, use_p=True):
        if use_p:
            # The resampler path represents each image with 3 tokens.
            self.num_image_token = 3
        if batch:
            assert isinstance(questions, list) and len(questions) > 0 and isinstance(questions[0], str)
            if history is not None or return_history:
                print('Now multi-turn chat is not supported in batch_chat.')
                raise NotImplementedError

            if image_counts is not None:
                num_patches_list = image_counts
                print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

            img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
            self.img_context_token_id = img_context_token_id

            if verbose and pixel_values is not None:
                image_bs = pixel_values.shape[0]
                print(f'dynamic ViT batch size: {image_bs}')

            queries = []
            for idx, num_patches in enumerate(num_patches_list):
                question = questions[idx]
                if pixel_values is not None and '<image>' not in question:
                    question = '<image>\n' + question
                template = get_conv_template(self.template)
                template.append_message(template.roles[0], question)
                template.append_message(template.roles[1], None)
                query = template.get_prompt()

                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
                query = query.replace('<image>', image_tokens, 1)
                queries.append(query)

            tokenizer.padding_side = 'left'
            model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
            input_ids = model_inputs['input_ids'].cuda()
            attention_mask = model_inputs['attention_mask'].cuda()
            eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
            generation_config['eos_token_id'] = eos_token_id
            if use_p:
                generation_output = self.generate(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **generation_config
                )
            else:
                generation_output = self.generate_origin(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **generation_config
                )
            responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
            responses = [response.split(template.sep)[0].strip() for response in responses]
            return responses
        else:
            assert isinstance(questions, str)
            if num_patches_list is None:
                num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
            assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

            img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
            self.img_context_token_id = img_context_token_id

            template = get_conv_template(self.template)
            template.system_message = self.system_message
            eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

            history = [] if history is None else history
            for (old_question, old_answer) in history:
                template.append_message(template.roles[0], old_question)
                template.append_message(template.roles[1], old_answer)
            template.append_message(template.roles[0], questions)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()

            if verbose and pixel_values is not None:
                image_bs = pixel_values.shape[0]
                print(f'dynamic ViT batch size: {image_bs}')

            # The template-built query above is replaced by a hand-built InternLM2
            # prompt; the hard-coded system message (in Chinese) identifies the
            # assistant as InternVL, developed by Shanghai AI Lab and SenseTime.
            query = f"""<|im_start|>system你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。<|im_end|>\n<|im_start|>user{questions}"""
            query = query + '<image>'
            for num_patches in num_patches_list:
                image_tokens = IMG_CONTEXT_TOKEN * self.num_image_token
                query = query.replace('<image>', image_tokens, 1)

            query += "<|im_end|>\n<|im_start|>assistant"

            model_inputs = tokenizer(query, return_tensors='pt')

            input_ids = model_inputs['input_ids'].cuda()
            attention_mask = model_inputs['attention_mask'].cuda()

            generation_config['eos_token_id'] = eos_token_id
            if use_p:
                generation_output = self.generate(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **generation_config
                )
            else:
                generation_output = self.generate_origin(
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **generation_config
                )
            response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
            response = response.split(template.sep)[0].strip()
            history.append((questions, response))
            if return_history:
                return response, history
            else:
                query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
                query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
                if verbose:
                    print(query_to_print, response)

                return response

    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
        if history is not None or return_history:
            print('Now multi-turn chat is not supported in batch_chat.')
            raise NotImplementedError

        if image_counts is not None:
            num_patches_list = image_counts
            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        queries = []
        for idx, num_patches in enumerate(num_patches_list):
            question = questions[idx]
            if pixel_values is not None and '<image>' not in question:
                question = '<image>\n' + question
            template = get_conv_template(self.template)
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()

            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)
            queries.append(query)

        tokenizer.padding_side = 'left'
        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
        generation_config['eos_token_id'] = eos_token_id
        generation_output = self.generate_origin(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        responses = [response.split(template.sep)[0].strip() for response in responses]
        return responses

    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
             verbose=False):

        if history is None and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

        history = [] if history is None else history
        for (old_question, old_answer) in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]

        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)
            if verbose:
                print(num_patches, self.num_image_token)
                print(pixel_values.shape[0])

        model_inputs = tokenizer(query, return_tensors='pt')

        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()

        generation_config['eos_token_id'] = eos_token_id
        generation_output = self.generate_origin(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        response = response.split(template.sep)[0].strip()
        history.append((question, response))
        if return_history:
            return response, history
        else:
            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
            if verbose:
                print(query_to_print, response)

            return response

    @torch.no_grad()
    def generate_origin(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:

        assert self.img_context_token_id is not None
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values)
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            # Splice the visual embeddings into the <IMG_CONTEXT> placeholder slots.
            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs

    @torch.no_grad()
    def generate_ocr(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            reference_embeds=None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            repetition_penalty=1.5,
            **generate_kwargs,
    ) -> torch.LongTensor:

        assert self.img_context_token_id is not None
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values)
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

            if reference_embeds is not None:
                # Token id 92537 corresponds to the ALIGNED_TOKEN placeholder
                # ([UNUSED_TOKEN_140]) appended in chat_ocr; its slots receive
                # the calligraphy embeddings from calli_align.
                selected = (input_ids == 92537)
                assert selected.sum() != 0
                input_embeds[selected] = reference_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            repetition_penalty=repetition_penalty,
            **generate_kwargs,
        )

        return outputs

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:

        assert self.img_context_token_id is not None
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values)

            input_embeds = self.language_model.get_input_embeddings()(input_ids)

            # Resample the ViT features, quantize them to their nearest normalized
            # token embeddings by cosine similarity, then de-normalize with the
            # per-token (mu, sigma) statistics.
            vit_embeds = self.resampler(vit_embeds)

            mu = self.mu_sigma[:, 0].reshape((-1, 1))
            sigma = self.mu_sigma[:, 1].reshape((-1, 1))

            indices = vq_cos_sim(self.normed_emb, vit_embeds).reshape((-1,))

            vit_embeds = vit_embeds.reshape((-1, vit_embeds.shape[-1])) * sigma[indices][:] + mu[indices][:]

            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            assert selected.sum() != 0

            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs
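

# Minimal usage sketch (hypothetical checkpoint path; `load_image` comes from
# the repo's utils, and a chat-style generation config is passed as a dict):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('path/to/checkpoint', trust_remote_code=True)
#   model = InternVLChatModel.from_pretrained('path/to/checkpoint').cuda().eval()
#   pixel_values = load_image('sample.jpg').to(torch.bfloat16).cuda()
#   print(model.chat(tokenizer, pixel_values, '<image>\nDescribe this image.',
#                    dict(max_new_tokens=256)))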