File size: 2,364 Bytes
50aa037
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
from classes.Perturber import Perturber
from classes.Renderer import Renderer
from classes.LegibilityPlot import LegibilityPlot

from transformers import TrOCRProcessor, AutoModel

# The TrOCR processor takes care of image normalization and resizing.
preprocessor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

# Fetch the model schema together with its pretrained weights; the first run
# may spend a while downloading them.
# NOTE(review): trust_remote_code=True lets transformers run Python code
# shipped in the model repository -- confirm the pinned revision is trusted.
model = AutoModel.from_pretrained(
    "dvsth/LEGIT-TrOCR-MT",
    revision='main',
    trust_remote_code=True,
)

# Shared helper objects used by demo() on every request.
perturber = Perturber('trocr', 50)
renderer = Renderer('unifont.ttf')
plotter = LegibilityPlot()

def demo(word_to_perturb, k, n):
    """Perturb a word ten times, score every variant, and report the results.

    Each perturbation returned by ``perturber.perturb_word`` is rendered to an
    image, run through ``preprocessor`` and ``model``, and the resulting scalar
    is recorded as its score. Results are reported in ascending score order.

    Args:
        word_to_perturb: The word to perturb; it must not contain a space.
        k: Forwarded unchanged as the second argument of
            ``perturber.perturb_word``.
        n: Forwarded unchanged as the third argument of
            ``perturber.perturb_word``.

    Returns:
        A ``(text, figure)`` pair, one value per output declared on the
        Gradio interface. ``text`` holds one line per perturbation in the
        form ``"<perturbation> (<score rounded to 2 places>) -- legible"``
        (or ``"not legible"`` when the score is not above 0); ``figure`` is
        whatever ``plotter.plot`` returns. If the input contains a space,
        ``text`` is an error message and ``figure`` is ``None``.
    """
    if ' ' in word_to_perturb:
        # Two outputs are wired to this function, so the error path must also
        # return two values; None leaves the plot empty.
        return 'Please enter a single word.', None

    samples = []
    for _ in range(10):
        perturbation, metadata = perturber.perturb_word(word_to_perturb, k, n)
        inputimg = renderer.render_image(perturbation, word_to_perturb)
        score = model(preprocessor(inputimg, return_tensors='pt').pixel_values).item()

        metadata['score'] = score
        outputimg = renderer.render_image(perturbation, '')

        # Keep everything belonging to one perturbation in a single record so
        # that sorting can never separate a score from its image or metadata.
        samples.append((score, perturbation, outputimg, metadata))

    # Stable sort on the score alone: ties keep generation order, and the
    # images/metadata (which may not be comparable) are never compared.
    samples.sort(key=lambda sample: sample[0])
    scores = [sample[0] for sample in samples]
    perturbations = [sample[1] for sample in samples]

    # One line per perturbation: "<perturbation> (<score>) -- <verdict>".
    ret_str = ''.join(
        f'{perturbation} ({round(score, 2)}) -- '
        + ("legible" if score > 0 else "not legible")
        + '\n'
        for score, perturbation in zip(scores, perturbations)
    )

    # plot the perturbations and scores
    fig = plotter.plot(scores, perturbations)
    return ret_str, fig

# Wire demo() into a Gradio UI. The inputs are handed to demo() in order:
# the word, then the two slider values (its k and n parameters).
interface = gr.Interface(
    fn=demo,
    inputs=[
        "text",
        gr.Slider(1, 50, 20, step=1),
        gr.Slider(0., 1., 0.5),
    ],
    outputs=["text", "plot"],
    allow_flagging='never',
)

# Start the app as soon as the script runs, opening it in a browser tab.
interface.launch(inbrowser=True)