File size: 3,678 Bytes
b585c7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Based upon ImageCaptionLoader in LangChain version: langchain/document_loaders/image_captions.py
But accepts preloaded model to avoid slowness in use and CUDA forking issues

Loader that uses Pix2Struct models to image caption

"""
from typing import List, Union, Any, Tuple

from langchain.docstore.document import Document
from langchain.document_loaders import ImageCaptionLoader
from utils import get_device, clear_torch_cache
from PIL import Image


class H2OPix2StructLoader(ImageCaptionLoader):
    """Loader that generates text captions for images with a Pix2Struct model.

    Unlike the upstream ``ImageCaptionLoader``, the model and processor are
    kept as instance attributes so they can be loaded once, moved between CPU
    and GPU (``load_model`` / ``unload_model``), and reused across calls.
    """

    def __init__(self, path_images: Union[str, List[str]] = None, model_type="google/pix2struct-textcaps-base",
                 max_new_tokens=50):
        """
        :param path_images: a single image path or a list of image paths to caption
        :param model_type: Hugging Face model id of the Pix2Struct checkpoint
        :param max_new_tokens: generation length limit per caption
        """
        super().__init__(path_images)
        self._pix2struct_model = None  # lazily created in load_model()
        self._pix2struct_processor = None  # set together with the model; avoids AttributeError before load_model()
        self._model_type = model_type
        self._max_new_tokens = max_new_tokens

    def set_context(self):
        """Choose the torch device ('cuda' or 'cpu') based on availability."""
        if get_device() == 'cuda':
            import torch
            n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
            if n_gpus > 0:
                self.context_class = torch.device
                self.device = 'cuda'
            else:
                # get_device() asked for cuda, but no GPU is actually visible
                self.device = 'cpu'
        else:
            self.device = 'cpu'

    def load_model(self):
        """Load (or re-attach) the Pix2Struct model and processor onto the target device.

        Returns ``self`` so calls can be chained.

        :raises ValueError: if the ``transformers`` package is not installed.
        """
        try:
            from transformers import AutoProcessor, Pix2StructForConditionalGeneration
        except ImportError as e:
            raise ValueError(
                "`transformers` package not found, please install with "
                "`pip install transformers`."
            ) from e
        # Always resolve the device first: the early-return branch below
        # depends on self.device being set even when the model already exists.
        self.set_context()
        if self._pix2struct_model:
            # Model was loaded before (possibly parked on CPU by unload_model);
            # just move it back to the target device.
            self._pix2struct_model = self._pix2struct_model.to(self.device)
            return self
        self._pix2struct_processor = AutoProcessor.from_pretrained(self._model_type)
        self._pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(self._model_type).to(self.device)
        return self

    def unload_model(self):
        """Move the model to CPU and free cached GPU memory (no-op if not loaded)."""
        if hasattr(self._pix2struct_model, 'cpu'):
            self._pix2struct_model.cpu()
            clear_torch_cache()

    def set_image_paths(self, path_images: Union[str, List[str]]):
        """
        Load from a list of image files
        """
        if isinstance(path_images, str):
            self.image_paths = [path_images]
        else:
            self.image_paths = path_images

    def load(self, prompt=None) -> List[Document]:
        """Caption every image in ``self.image_paths``.

        :param prompt: currently unused; accepted for interface compatibility
        :return: one Document per image, with the caption as page_content and
                 the image path in metadata
        """
        if self._pix2struct_model is None:
            self.load_model()
        results = []
        for path_image in self.image_paths:
            caption, metadata = self._get_captions_and_metadata(
                processor=self._pix2struct_processor, model=self._pix2struct_model, path_image=path_image
            )
            doc = Document(page_content=caption, metadata=metadata)
            results.append(doc)

        return results

    def _get_captions_and_metadata(
            self, processor: Any, model: Any, path_image: str) -> Tuple[str, dict]:
        """
        Helper function for getting the captions and metadata of an image.

        :param processor: the Pix2Struct processor to tokenize/decode with
        :param model: the Pix2Struct model to generate with
        :param path_image: filesystem path of the image to caption
        :raises ValueError: if the image cannot be opened
        """
        try:
            image = Image.open(path_image)
        except Exception as e:
            raise ValueError(f"Could not get image data for {path_image}") from e
        try:
            # Use the *passed-in* processor/model (the original ignored these
            # parameters and read the instance attributes instead).
            inputs = processor(images=image, return_tensors="pt")
            inputs = inputs.to(self.device)
            generated_ids = model.generate(**inputs, max_new_tokens=self._max_new_tokens)
        finally:
            # Release the file handle PIL keeps open (resource leak otherwise).
            image.close()
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        metadata: dict = {"image_path": path_image}
        return generated_text, metadata