File size: 3,667 Bytes
c426a27
9a84ec8
c426a27
 
9a84ec8
ff883a7
f1a2810
c426a27
108f2df
c426a27
 
 
 
 
 
 
 
ff883a7
 
 
 
 
c426a27
ccb14a3
 
 
 
 
 
 
 
f1a2810
ccb14a3
c426a27
ccb14a3
 
108f2df
c426a27
ccb14a3
 
 
 
 
 
 
 
 
c426a27
 
 
 
 
 
 
 
 
 
 
 
ccb14a3
 
c426a27
 
 
 
 
9a84ec8
c426a27
 
 
af88c78
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from PIL import Image
import numpy as np
from typing import Union
from transformers import AutoProcessor, Blip2ForConditionalGeneration

from caption_anything.utils.utils import is_platform_win, load_image
from .base_captioner import BaseCaptioner
import time

class BLIP2Captioner(BaseCaptioner):
    """Image captioner backed by the Hugging Face BLIP-2 OPT-2.7b checkpoint.

    In the default mode, ``inference`` produces a single caption/answer for a
    text prompt. With ``dialogue=True``, it instead runs an interactive
    question/answer loop that reads questions from stdin.
    """

    # Hub id shared by the processor and the model weights.
    _MODEL_ID = "Salesforce/blip2-opt-2.7b"

    def __init__(self, device, dialogue: bool = False, enable_filter: bool = False,
                 cache_dir: Union[str, None] = None):
        """Load the BLIP-2 processor and model.

        Args:
            device: torch device string that inputs are moved to. Any string
                containing ``'cuda'`` selects float16 inputs; otherwise float32.
            dialogue: if True, ``inference`` runs the interactive stdin loop.
            enable_filter: forwarded to ``BaseCaptioner``. When set (and
                ``inference(..., filter=True)``), captions are scored with
                ``filter_caption``.
            cache_dir: optional directory passed to ``from_pretrained``.
                ``None`` keeps the library's default cache location.
        """
        super().__init__(device, enable_filter)
        self.device = device
        self.dialogue = dialogue
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.processor = AutoProcessor.from_pretrained(self._MODEL_ID, cache_dir=cache_dir)
        if is_platform_win():
            # NOTE(review): 8-bit loading is skipped on Windows, presumably due to
            # limited bitsandbytes support there — confirm.
            self.model = Blip2ForConditionalGeneration.from_pretrained(
                self._MODEL_ID, device_map="sequential", torch_dtype=self.torch_dtype,
                cache_dir=cache_dir)
        else:
            self.model = Blip2ForConditionalGeneration.from_pretrained(
                self._MODEL_ID, device_map='sequential', load_in_8bit=True,
                cache_dir=cache_dir)

    @torch.no_grad()
    def inference(self,
                  image: Union[np.ndarray, Image.Image, str],
                  filter=False,
                  args=None):
        """Caption ``image``, or run an interactive dialogue about it.

        Args:
            image: ndarray, PIL image or path. It is normalised to PIL via
                ``load_image``.
            filter: when True and the captioner was built with
                ``enable_filter``, the caption is scored with
                ``filter_caption`` and returned as ``clip_score``.
            args: optional dict of extra options. Missing keys are filled in
                place with defaults:
                ``return_ppl`` (default False),
                ``text_prompt`` (default a generic "what does the image show" question),
                ``reference_caption`` (default ``[]``).

        Returns:
            A dict with ``caption`` and, when requested, ``clip_score`` and
            ``ppl_score``. In dialogue mode only the last answer is returned;
            the dict is empty if the first input is ``'end'``.
        """
        # Fresh dict per call: a literal ``{}`` default would be one shared
        # object that every call mutates.
        args = {} if args is None else args
        args.setdefault('return_ppl', False)
        args.setdefault('text_prompt', 'Question: what does the image show? Answer:')
        args.setdefault('reference_caption', [])

        image = load_image(image, return_type="pil")
        if self.dialogue:
            return self._run_dialogue(image)
        return self._caption_once(image, filter, args)

    def _caption_once(self, image, filter, args):
        """Generate one caption for a PIL ``image`` from ``args['text_prompt']``."""
        result = {}
        inputs = self.processor(image, text=args['text_prompt'], return_tensors="pt").to(self.device, self.torch_dtype)
        out = self.model.generate(**inputs, return_dict_in_generate=True, output_scores=True, max_new_tokens=50)
        caption = self.processor.decode(out.sequences[0], skip_special_tokens=True).strip()
        if self.enable_filter and filter:
            print('reference caption: {}, caption: {}'.format(args['reference_caption'], caption))
            result['clip_score'] = self.filter_caption(image, caption, args['reference_caption'])
        if args['return_ppl']:
            # out.scores is a tuple of per-step score tensors; stack them along
            # dim 1, then sum the max per-step log-probability of the first
            # sequence. log_softmax is the numerically stable form of
            # softmax().log().
            step_scores = torch.stack(out.scores, dim=1)
            ppl_score = step_scores.log_softmax(dim=2).max(dim=2).values.sum(dim=1)[0]
            result['ppl_score'] = ppl_score.item()
        print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {caption}")
        result['caption'] = caption
        return result

    def _run_dialogue(self, image):
        """Answer stdin questions about ``image`` until the user types 'end'."""
        result = {}
        context = []  # (question, answer) pairs from earlier turns
        template = "Question: {} Answer: {}."
        while True:
            question = input()
            if question == 'end':
                break
            # Earlier turns are replayed in the prompt so the model sees the history.
            history = " ".join(template.format(q, a) for q, a in context)
            prompt = history + " Question: " + question + " Answer:"
            inputs = self.processor(image, text=prompt, return_tensors="pt").to(self.device, self.torch_dtype)
            out = self.model.generate(**inputs, max_new_tokens=50)
            answer = self.processor.decode(out[0], skip_special_tokens=True).strip()
            context.append((question, answer))
            result['caption'] = answer
        return result

if __name__ == '__main__':
    # Ad-hoc smoke test: caption a fixed square region of a sample image.
    dialogue = False
    # NOTE(review): the `cache_dir` keyword only works if
    # BLIP2Captioner.__init__ declares a `cache_dir` parameter.
    model = BLIP2Captioner(
        device='cuda:4',
        dialogue=dialogue,
        cache_dir='/nvme-ssd/fjj/Caption-Anything/model_cache',
    )
    image_path = 'test_images/img2.jpg'
    mask_shape = (224, 224)
    seg_mask = np.zeros(mask_shape)
    seg_mask[50:200, 50:200] = 1  # square region of interest
    print(f'process image {image_path}')
    seg_result = model.inference_seg(image_path, seg_mask)
    print(seg_result)