# Dev Seth
# init space
# 70b0359
# raw
# history blame contribute delete
# No virus
# 2.36 kB
import gradio as gr
from classes.Perturber import Perturber
from classes.Renderer import Renderer
from classes.LegibilityPlot import LegibilityPlot
from transformers import TrOCRProcessor, AutoModel
# Hub checkpoint identifiers for the image preprocessor and the scoring model.
_PREPROCESSOR_CHECKPOINT = "microsoft/trocr-base-handwritten"
_MODEL_CHECKPOINT = "dvsth/LEGIT-TrOCR-MT"

# The TrOCR processor takes care of image normalization and resizing.
preprocessor = TrOCRProcessor.from_pretrained(_PREPROCESSOR_CHECKPOINT)

# Fetch the model definition together with its pretrained weights; expect a
# delay on the first run while the checkpoint is downloaded.
# NOTE(review): trust_remote_code=True runs code shipped in the model repo and
# revision='main' is not pinned to a commit -- confirm this is acceptable.
model = AutoModel.from_pretrained(
    _MODEL_CHECKPOINT,
    revision='main',
    trust_remote_code=True,
)

# Project helpers used by demo(): word perturbation, image rendering, plotting.
perturber = Perturber('trocr', 50)
renderer = Renderer('unifont.ttf')
plotter = LegibilityPlot()
def demo(word_to_perturb, k, n):
    """Generate ten perturbations of a word and rank them by model score.

    Each perturbation is rendered next to the original word, scored by the
    model, and then listed in ascending score order. A score above zero is
    reported as "legible", anything else as "not legible".

    Args:
        word_to_perturb: The word to perturb; must not contain a space.
        k: Forwarded unchanged to ``perturber.perturb_word``
            (NOTE(review): exact meaning is defined by Perturber -- confirm).
        n: Forwarded unchanged to ``perturber.perturb_word``
            (NOTE(review): exact meaning is defined by Perturber -- confirm).

    Returns:
        A ``(text, figure)`` pair: one line per perturbation in the form
        ``"<perturbation> (<score>) -- legible|not legible"``, and the figure
        produced by ``plotter.plot``. When the input contains a space, the
        pair is ``('Please enter a single word.', None)``.
    """
    if ' ' in word_to_perturb:
        # The interface declares two outputs (text and plot), so the error
        # path must also return one value per output.
        return 'Please enter a single word.', None

    # Keep everything belonging to one sample in a single record so that the
    # perturbation, image and metadata can never get out of step with the
    # score they were produced with.
    samples = []
    for _ in range(10):
        perturbation, metadata = perturber.perturb_word(word_to_perturb, k, n)
        inputimg = renderer.render_image(perturbation, word_to_perturb)
        score = model(preprocessor(inputimg, return_tensors='pt').pixel_values).item()
        metadata['score'] = score
        # Rendered image and metadata are not part of the current outputs;
        # they are retained alongside the score as in the original code.
        outputimg = renderer.render_image(perturbation, '')
        samples.append((score, perturbation, outputimg, metadata))

    # Sort once by score only (stable sort keeps generation order for ties).
    samples.sort(key=lambda sample: sample[0])
    scores = [sample[0] for sample in samples]
    perturbations = [sample[1] for sample in samples]

    # One line per perturbation, scores rounded to two decimal places.
    ret_str = ''.join(
        f'{perturbation} ({round(score, 2)}) -- '
        + ("legible" if score > 0 else "not legible")
        + '\n'
        for perturbation, score in zip(perturbations, scores)
    )

    # plot the perturbations and scores
    fig = plotter.plot(scores, perturbations)
    return ret_str, fig
# Input widgets for demo(): a free-text word plus two sliders whose values are
# passed to demo() as `k` (integer steps, 1..50, default 20) and `n`
# (0.0..1.0, default 0.5).
_k_slider = gr.Slider(minimum=1, maximum=50, value=20, step=1)
_n_slider = gr.Slider(minimum=0., maximum=1., value=0.5)

# demo() feeds a text box and a plot; flagging is turned off.
interface = gr.Interface(
    fn=demo,
    inputs=["text", _k_slider, _n_slider],
    outputs=["text", "plot"],
    allow_flagging='never',
)

# Start the app and open it in the default browser.
interface.launch(inbrowser=True)