import os
import torch
from torchvision import transforms as T
from PIL import Image, ImageEnhance
from io import BytesIO

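# Prefer the bundled copy under /var/task when it exists; otherwise use the relative path.
# os.path.isdir is used because a local torch.hub source is a directory, not a file.
_BUNDLED_MODEL_PATH = '/var/task/torch/hub/baudm_parseq_main'
MODEL_PATH = _BUNDLED_MODEL_PATH if os.path.isdir(_BUNDLED_MODEL_PATH) else 'torch/hub/baudm_parseq_main'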
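
# A minimal sketch of a more general lookup, assuming the model directory may be
# vendored in more than one location. `first_existing_dir` is a hypothetical helper,
# not part of torch or of this module's original code.
def first_existing_dir(candidates, default):
    """Return the first candidate that is an existing directory, else `default`."""
    for path in candidates:
        if os.path.isdir(path):
            return path
    return default
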
class OCRModel:
    def __init__(self):
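        # Load the pretrained 'parseq' entry point from the local torch.hub checkout
        # and switch it to evaluation mode for inference.
        print(MODEL_PATH)
        self.model = torch.hub.load(MODEL_PATH, 'parseq', source='local', pretrained=True, trust_repo=True).eval()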

        # Preprocessing: resize to 32x128 (height x width), convert to a tensor in [0, 1],
        # then normalize every channel to [-1, 1].
        self._preprocess = T.Compose([
            T.Resize((32, 128), T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5)
        ])

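    def adjust_image(self, image, brightness=1.0, contrast=1.0, sharpness=1.0):
        """
        Adjust the brightness, contrast, and sharpness of a PIL image.

        Each factor is passed to the matching PIL ImageEnhance class; a factor
        of 1.0 leaves that property unchanged. Returns the adjusted image.
        """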
        enhancer = ImageEnhance.Brightness(image)
        image = enhancer.enhance(brightness)

        enhancer = ImageEnhance.Contrast(image)
        image = enhancer.enhance(contrast)

        enhancer = ImageEnhance.Sharpness(image)
        image = enhancer.enhance(sharpness)

        return image

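    def predict(self, image_input, brightness=1.0, contrast=1.0, sharpness=1.0):
        """
        Predict text from an image.

        `image_input` is either the raw encoded image as bytes, or anything
        Image.open accepts (a file path or a file-like object). The enhancement
        factors are forwarded to adjust_image. Returns the decoded string.
        """
        if isinstance(image_input, bytes):
            # Wrap raw bytes in a buffer so PIL can read them like a file
            image = Image.open(BytesIO(image_input)).convert('RGB')
        else:
            image = Image.open(image_input).convert('RGB')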

        # Adjust the image according to user-defined values
        image = self.adjust_image(image, brightness, contrast, sharpness)
        # Debugging aid: write the adjusted image to the current working directory
        image.save('adjusted_image.jpg')
        # Preprocess the image
        image = self._preprocess(image).unsqueeze(0)

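        # Perform inference without tracking gradients
        with torch.no_grad():
            # Convert the logits to per-position probabilities over the character set
            pred = self.model(image).softmax(-1)
            # decode returns (labels, confidences); only the labels are used here
            label, _ = self.model.tokenizer.decode(pred)

        # The batch holds a single image, so return its label
        return label[0]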

# Example usage
# if __name__ == '__main__':
#     ocr_model = OCRModel()  # Instantiate the class
#     with open('../../../Desktop/DFaqQf.png', 'rb') as image_file:
#         image_buffer = image_file.read()
#     result = ocr_model.predict(image_buffer, brightness=1.2, contrast=5.3, sharpness=2.1)  # Example with adjusted values
#     print("Detected Text:", result)
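
# A minimal sketch of sweeping enhancement settings, assuming good values are
# found by trial. `enhancement_grid` is a hypothetical helper, not part of PARSeq
# or PIL; each yielded dict can be passed to OCRModel.predict as keyword arguments.
from itertools import product


def enhancement_grid(brightness_values, contrast_values, sharpness_values):
    """Yield one keyword-argument dict per combination of the given factors."""
    for brightness, contrast, sharpness in product(
        brightness_values, contrast_values, sharpness_values
    ):
        yield {'brightness': brightness, 'contrast': contrast, 'sharpness': sharpness}


# Usage (commented out, like the example above, because it needs the model files):
# for params in enhancement_grid([1.0, 1.2], [1.0, 1.5], [1.0]):
#     print(params, ocr_model.predict(image_buffer, **params))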