import gradio as gr from classes.Perturber import Perturber from classes.Renderer import Renderer from classes.LegibilityPlot import LegibilityPlot from transformers import TrOCRProcessor, AutoModel # preprocessor provides image normalization and resizing preprocessor = TrOCRProcessor.from_pretrained( "microsoft/trocr-base-handwritten") # load the model schema and pretrained weights # (this may take some time to download) model = AutoModel.from_pretrained("dvsth/LEGIT-TrOCR-MT", revision='main', trust_remote_code=True) perturber = Perturber('trocr', 50) renderer = Renderer('unifont.ttf') plotter = LegibilityPlot() def demo(word_to_perturb, k, n): if ' ' in word_to_perturb: return 'Please enter a single word.' perturbations, metadatas, images, scores = [], [], [], [] for i in range(10): perturbation, metadata = perturber.perturb_word(word_to_perturb, k, n) inputimg = renderer.render_image(perturbation, word_to_perturb) score = model(preprocessor(inputimg, return_tensors='pt').pixel_values).item() metadata['score'] = score outputimg = renderer.render_image(perturbation, '') perturbations.append(perturbation) images.append(outputimg) metadatas.append(metadata) scores.append(score) # sort perturbations by score perturbations = [perturbation for perturbation, score in sorted(zip(perturbations, scores), key=lambda x: x[1])] scores = sorted(scores) images = [image for image, score in sorted(zip(images, scores), key=lambda x: x[1])] metadatas = [metadata for metadata, score in sorted(zip(metadatas, scores), key=lambda x: x[1])] # return as a single string in the format # perturbation1 (score1) # perturbation2 (score2) # ... # perturbationN (scoreN) # with all scores rounded to 2 decimal places ret_str = '' for i in range(len(perturbations)): ret_str += f'{perturbations[i]} ({round(scores[i], 2)}) -- ' + ("legible" if scores[i] > 0 else "not legible") + '\n' # plot the perturbations and scores fig = plotter.plot(scores, perturbations) return ret_str, fig interface = gr.Interface(fn=demo, inputs=["text", gr.Slider(1, 50, 20, step=1), gr.Slider(0., 1., 0.5)], outputs=["text", "plot"], allow_flagging='never') interface.launch(inbrowser=True)