| | import sys |
| |
|
| | sys.path.append("/repository/BLIP") |
| | sys.path.append("/repository") |
| | sys.path.append("BLIP") |
| |
|
| | from typing import Dict, List, Any |
| | from PIL import Image |
| | import requests |
| | import torch |
| | import base64 |
| | import os |
| | from io import BytesIO |
| | from torchvision import transforms |
| | from torchvision.transforms.functional import InterpolationMode |
| | from transformers.modeling_utils import PreTrainedModel |
| |
|
| | from torch import nn |
| | from torch.nn import CrossEntropyLoss |
| | import torch.nn.functional as F |
| | import torch |
| | import numpy as np |
| | from BLIP.models.blip import BLIP_Base,load_checkpoint |
| |
|
| | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| |
|
class EndpointHandler():
    """Inference Endpoints handler embedding (image, text) pairs with a BLIP
    base encoder.

    The handler returns the pooled multimodal [CLS] feature of the fused
    image+text sequence. When no image is supplied, a text-only encoding is
    substituted for the visual features so the cross-attention pass still runs.
    """

    def __init__(self, pretrain_path="/repository/blip_base.pth"):
        # BLIP base (ViT-B) was pretrained at 224x224 resolution.
        image_size = 224

        self.blip_encoder_path = pretrain_path
        self.blip_encoder = BLIP_Base(
            image_size=image_size,
            vit='base',
            med_config='/repository/configs/med_config.json'
        )
        # BUG FIX: the checkpoint path used to be hard-coded to
        # "/repository/blip_base.pth", silently ignoring `pretrain_path`.
        # The default keeps the previous behavior for existing callers.
        load_checkpoint(self.blip_encoder, self.blip_encoder_path)
        self.blip_encoder.eval()
        self.blip_encoder = self.blip_encoder.to(device)

        # CLIP-style normalization constants used by BLIP's image pipeline.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size),
                              interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __call__(self, data: Any) -> Dict[str, List[float]]:
        """Embed one (image, text) request.

        Args:
            data: dict with an "inputs" entry (or the inputs dict itself)
                containing:
                - "text" (str): the query / caption text.
                - "has_image" (bool): whether an "image" entry is present.
                - "image" (str, optional): base64-encoded image bytes.
        Returns:
            {"feature_vector": [[...]]} — the multimodal [CLS] embedding as a
            nested list of floats (one row per request, here always one).
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {"mode": "image"})  # currently unused

        text = [inputs["text"]]
        has_image = inputs["has_image"]

        # Decode the uploaded image, or fall back to a black 224x224
        # placeholder so the visual encoder still yields a tensor of the
        # expected sequence length.
        if has_image:
            image = Image.open(BytesIO(base64.b64decode(inputs["image"])))
        else:
            image = Image.new("RGB", (224, 224))
        image = self.transform(image).unsqueeze(0).to(device)

        # Inference only: skip autograd bookkeeping (previously every request
        # built a full computation graph for nothing).
        with torch.no_grad():
            image_embeds = self.blip_encoder.visual_encoder(image)

            if not has_image:
                # Append a second, [PAD]-stuffed text whose tokenized length
                # matches the visual sequence (minus the two special tokens),
                # so the text-only encoding below has the right shape to
                # overwrite the placeholder image embedding.
                target_len = image_embeds.shape[1]
                pad = " [PAD]" * (target_len - 2)
                text = text + [pad]

            text = self.blip_encoder.tokenizer(
                tuple(text), return_tensors="pt", padding="longest").to(device)

            # Only the first (real) text row is fed to the encoder; the padded
            # row exists solely to force `padding="longest"` to the right width.
            text_input_ids = text.input_ids[:1]
            text_attention_mask = text.attention_mask[:1]

            if not has_image:
                # Replace the placeholder image features with a pure text
                # encoding so downstream cross-attention attends to text.
                text_embeds = self.blip_encoder.text_encoder(
                    text_input_ids,
                    attention_mask=text_attention_mask,
                    return_dict=True,
                    mode='text')
                image_embeds[0] = text_embeds.last_hidden_state[0]

            image_atts = torch.ones(image_embeds.size()[:-1],
                                    dtype=torch.long).to(device)

            # Switch the leading token to the encoder token id so the text
            # encoder runs in multimodal (cross-attention) mode. NOTE:
            # `text_input_ids` is a view of `text.input_ids`, so this in-place
            # write is visible to the encoder call below.
            text.input_ids[:, 0] = self.blip_encoder.tokenizer.enc_token_id

            output = self.blip_encoder.text_encoder(
                text_input_ids,
                attention_mask=text_attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

        # Pooled [CLS] representation of the fused sequence.
        hidden_state = output.last_hidden_state[:, 0, :].detach()
        return {"feature_vector": hidden_state.tolist()}