| import string |
| import torch |
| import torchvision.transforms.functional as F |
| from transformers.processing_utils import ProcessorMixin |
|
|
class CaptchaProcessor(ProcessorMixin):
    """Processor pairing image preprocessing with greedy CTC decoding for a CRNN captcha model.

    Characters are mapped to indices 1..len(vocab); index 0 is reserved for
    the CTC blank token and decodes to the empty string.
    """

    # ProcessorMixin normally manages named sub-processors (tokenizer,
    # feature extractor); this processor is self-contained, so none declared.
    attributes = []

    def __init__(self, vocab=None, **kwargs):
        """
        Args:
            vocab: Ordered string of characters the model can emit.
                Defaults to [a-zA-Z0-9]. Index 0 is reserved for the CTC blank.
            **kwargs: Forwarded to ``ProcessorMixin``.
        """
        super().__init__(**kwargs)
        # string.ascii_letters is documented as ascii_lowercase + ascii_uppercase.
        self.vocab = vocab or (string.ascii_letters + string.digits)
        # Shift character indices by one so index 0 can serve as the CTC blank.
        self.idx_to_char = {i + 1: c for i, c in enumerate(self.vocab)}
        self.idx_to_char[0] = ""

    def __call__(self, images):
        """Convert one PIL image (or a list of them) to the tensor format the CRNN expects.

        Each image is converted to grayscale and resized to 150x40
        (width x height), then the batch is stacked into a single
        (batch, 1, 40, 150) float tensor with values in [0, 1].

        Returns:
            dict: {"pixel_values": stacked image tensor}.
        """
        if not isinstance(images, list):
            images = [images]

        processed_images = []
        for img in images:
            img = img.convert("L")       # grayscale, single channel
            img = img.resize((150, 40))  # fixed size the CRNN expects
            processed_images.append(F.to_tensor(img))

        return {"pixel_values": torch.stack(processed_images)}

    def batch_decode(self, logits):
        """Greedy CTC decode: argmax per step, collapse repeats, drop blanks.

        Args:
            logits: Tensor of shape (batch, seq, num_classes) or
                (seq, num_classes) for a single example.

        Returns:
            list[str]: one decoded string per batch item.
        """
        tokens = torch.argmax(logits, dim=-1)
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)

        decoded_strings = []
        # One .tolist() converts everything to Python ints up front, avoiding
        # a per-element .item() call (a device sync when logits live on GPU).
        for seq in tokens.tolist():
            chars = []
            prev = None
            for token in seq:
                # Emit a character only when the token is not the blank (0)
                # and differs from the immediately preceding token (CTC
                # repeat-collapse). Unknown indices decode to "".
                if token != 0 and token != prev:
                    chars.append(self.idx_to_char.get(token, ""))
                prev = token
            decoded_strings.append("".join(chars))
        return decoded_strings