File size: 988 Bytes
14ce5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from torchmetrics import Metric
from ocr import OCR
import Levenshtein


class CharacterErrorRate(Metric):
    """Character error rate (CER) between OCR transcriptions of two image sets.

    Each predicted image and its target image are transcribed with the same
    OCR model. The Levenshtein edit distance between the two transcriptions
    is accumulated, and :meth:`compute` divides it by the total length of the
    target transcriptions.

    Args:
        ocr: Object exposing ``predict(image) -> str``. Presumably an instance
            of the project's ``OCR`` class -- verify against callers.
        dist_sync_on_step: Forwarded to :class:`torchmetrics.Metric`; controls
            whether metric state is synchronised across processes on every
            ``forward`` call.
    """

    def __init__(self, ocr, dist_sync_on_step=False):
        # Forward the flag instead of silently ignoring it; passed as a keyword
        # so it works with both the legacy explicit signature and the newer
        # ``**kwargs`` signature of ``Metric.__init__``.
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # Running totals; "sum" reduction combines them correctly across
        # processes in distributed runs.
        self.add_state("total_errors", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_chars", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.ocr = ocr

    def update(self, pred_images, target_images):
        """Accumulate edit distances for one batch of image pairs.

        Args:
            pred_images: Iterable of predicted images, paired positionally with
                ``target_images``. Note that ``zip`` stops at the shorter of
                the two inputs.
            target_images: Iterable of reference images; their transcriptions
                supply the character count used for normalisation.
        """
        batch_errors = 0
        batch_chars = 0
        for pred_img, target_img in zip(pred_images, target_images):
            pred_text = self.ocr.predict(pred_img)
            target_text = self.ocr.predict(target_img)
            batch_errors += Levenshtein.distance(pred_text, target_text)
            batch_chars += len(target_text)
        # Touch the state tensors once per batch rather than once per sample.
        self.total_errors += batch_errors
        self.total_chars += batch_chars

    def compute(self):
        """Return total edit distance divided by total target characters.

        Returns a zero tensor when no target characters have been seen, which
        avoids a division by zero.
        """
        if self.total_chars == 0:
            # Match the state's device and dtype; a bare ``torch.tensor(0.0)``
            # would always live on the CPU.
            return torch.zeros_like(self.total_errors)
        return self.total_errors / self.total_chars