Spaces:
Runtime error
Runtime error
File size: 2,373 Bytes
bfea304 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import os
import torch
from torchvision import transforms as T
from PIL import Image, ImageEnhance
from io import BytesIO
# Directory holding the local torch.hub checkout of the PARSeq repository.
# The absolute /var/task location is preferred when present (presumably the
# packaged serverless deployment -- confirm against the deploy setup);
# otherwise fall back to a path relative to the current working directory.
#
# torch.hub.load(..., source='local') expects a *directory*, so the probe must
# succeed for directories. os.path.isfile() is False for a directory, which
# meant the absolute path could never be selected; os.path.exists() accepts
# both files and directories.
_PACKAGED_MODEL_PATH = '/var/task/torch/hub/baudm_parseq_main'
MODEL_PATH = _PACKAGED_MODEL_PATH if os.path.exists(_PACKAGED_MODEL_PATH) else 'torch/hub/baudm_parseq_main'
class OCRModel:
    """Text recogniser built on the 'parseq' entry point of a local torch.hub repo.

    The model is loaded once at construction time from ``MODEL_PATH`` and put
    in eval mode. ``predict`` then turns a single image into a decoded string.
    """

    def __init__(self):
        """Load the model from ``MODEL_PATH`` and build the preprocessing pipeline."""
        # Echo the resolved hub directory so it is visible in the process output.
        print(MODEL_PATH)
        self.model = torch.hub.load(
            MODEL_PATH, 'parseq', source='local', pretrained=True, trust_repo=True
        ).eval()

        # Resize to 32x128 (bicubic), convert to a tensor, then normalise each
        # channel with mean 0.5 / std 0.5.
        self._preprocess = T.Compose([
            T.Resize((32, 128), T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])

    def adjust_image(self, image, brightness=1.0, contrast=1.0, sharpness=1.0):
        """Return *image* with brightness, contrast and sharpness enhanced.

        Args:
            image: A PIL image.
            brightness: Factor passed to ``ImageEnhance.Brightness``.
            contrast: Factor passed to ``ImageEnhance.Contrast``.
            sharpness: Factor passed to ``ImageEnhance.Sharpness``.

        Returns:
            The enhanced PIL image. The three enhancements are applied in the
            order brightness -> contrast -> sharpness, each on the result of
            the previous one.
        """
        enhancement_steps = (
            (ImageEnhance.Brightness, brightness),
            (ImageEnhance.Contrast, contrast),
            (ImageEnhance.Sharpness, sharpness),
        )
        for enhancer_cls, factor in enhancement_steps:
            image = enhancer_cls(image).enhance(factor)
        return image

    def predict(self, image_input, brightness=1.0, contrast=1.0, sharpness=1.0,
                *, adjusted_image_path='adjusted_image.jpg'):
        """Predict the text shown in an image.

        Args:
            image_input: Either an in-memory encoded image (``bytes``,
                ``bytearray`` or ``memoryview``) or anything ``Image.open``
                accepts directly, such as a file path.
            brightness: Brightness factor forwarded to ``adjust_image``.
            contrast: Contrast factor forwarded to ``adjust_image``.
            sharpness: Sharpness factor forwarded to ``adjust_image``.
            adjusted_image_path: Where to write a copy of the adjusted image.
                Defaults to ``'adjusted_image.jpg'`` in the current working
                directory; pass ``None`` to skip writing it.

        Returns:
            The first label returned by the model tokenizer's ``decode``.
        """
        # Any bytes-like buffer is wrapped so PIL can read it like a file.
        if isinstance(image_input, (bytes, bytearray, memoryview)):
            source = BytesIO(image_input)
        else:
            source = image_input

        # convert() yields an independent image, so the opened file can be
        # released as soon as the conversion is done.
        with Image.open(source) as opened:
            image = opened.convert('RGB')

        # Adjust the image according to user-defined values.
        image = self.adjust_image(image, brightness, contrast, sharpness)

        # The saved copy is a side artefact only: an unwritable location
        # (e.g. a read-only working directory) must not abort the prediction.
        if adjusted_image_path is not None:
            try:
                image.save(adjusted_image_path)
            except OSError as exc:
                print(f"Could not save adjusted image to {adjusted_image_path!r}: {exc}")

        # Preprocess and add a leading batch dimension of size 1.
        batch = self._preprocess(image).unsqueeze(0)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            pred = self.model(batch).softmax(-1)

        # decode() returns a pair; only the labels are used here.
        label, _ = self.model.tokenizer.decode(pred)
        return label[0]
# Example usage
# if __name__ == '__main__':
#     ocr_model = OCRModel()  # Instantiate the class
#     with open('../../../Desktop/DFaqQf.png', 'rb') as image_file:
#         image_buffer = image_file.read()
#     result = ocr_model.predict(image_buffer, brightness=1.2, contrast=5.3, sharpness=2.1)  # Example with adjusted values
#     print("Detected Text:", result)
|