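# Header residue from the (apparently Hugging Face Hub) file page this script was copied from:
# author hadrakey, commit 9c909e3 (verified), message "Training in progress, step 1000", file size 4.47 kB.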
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import pandas as pd
from PIL import Image
from torchmetrics.text import CharErrorRate
# --- Models under comparison -------------------------------------------------
# The six fine-tuned AlphaPen checkpoints, loaded in order and bound to the
# module-level names model_finetune_1 ... model_finetune_6.
(
    model_finetune_1,
    model_finetune_2,
    model_finetune_3,
    model_finetune_4,
    model_finetune_5,
    model_finetune_6,
) = (
    VisionEncoderDecoderModel.from_pretrained(checkpoint)
    for checkpoint in (
        "hadrakey/alphapen_new_large_1",
        "hadrakey/alphapen_new_large_15000",
        "hadrakey/alphapen_new_large_30000",
        "hadrakey/alphapen_new_large_45000",
        "hadrakey/alphapen_new_large_60000",
        "hadrakey/alphapen_new_large_70000",
    )
)

# Baseline: the public TrOCR handwritten checkpoint. Its processor is the one
# used to preprocess images and decode token ids for every model in this script.
# NOTE(review): the fine-tuned repos are named "large" while this processor comes
# from the "base" checkpoint — confirm they share tokenizer and image settings.
model_base = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
# --- Evaluation data ---------------------------------------------------------
# Labelled test set (presumably the manually checked labels — confirm); the
# inference loop reads its "filename" and "text" columns.
df_path = "/mnt/data1/Datasets/AlphaPen/" + "testing_data.csv"
data = pd.read_csv(df_path)
# Drop incomplete rows, then renumber them 0..n-1 so positional and label-based
# lookups agree. reset_index keeps the old index as an "index" column, which is
# carried through to the output CSV.
data.dropna(inplace=True)
data.reset_index(inplace=True)
# Evaluate only the first 50 rows. .copy() makes `sample` an independent frame:
# result columns are added to it later, and doing that on a bare iloc slice of
# `data` is a chained assignment that raises pandas' SettingWithCopyWarning.
sample = data.iloc[:50, :].copy()
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/"
inf_baseline = []
inf_finetune_1 = []
inf_finetune_2 = []
inf_finetune_3 = []
inf_finetune_4 = []
inf_finetune_5 = []
inf_finetune_6 = []
cer_fine_1 = []
cer_fine_2 = []
cer_fine_3 = []
cer_fine_4 = []
cer_fine_5 = []
cer_fine_6 = []
cer_base = []
cer_metric = CharErrorRate()
# --- Inference loop ----------------------------------------------------------
# For each sample: preprocess the image once, run every model on the same pixel
# tensor, keep the decoded text, and score it against the label with a
# case-insensitive character error rate.
# Each entry pairs a model with the lists its predictions and scores go into.
_runs = (
    (model_base, inf_baseline, cer_base),
    (model_finetune_1, inf_finetune_1, cer_fine_1),
    (model_finetune_2, inf_finetune_2, cer_fine_2),
    (model_finetune_3, inf_finetune_3, cer_fine_3),
    (model_finetune_4, inf_finetune_4, cer_fine_4),
    (model_finetune_5, inf_finetune_5, cer_fine_5),
    (model_finetune_6, inf_finetune_6, cer_fine_6),
)
for filename, label in zip(sample["filename"], sample["text"]):
    image_path = root_dir + "final_cropped_rotated_" + filename
    # The context manager closes the file handle once the RGB copy exists,
    # instead of leaving one open handle per sample until garbage collection.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values
    # Lower-case the reference once per sample. str() guards against labels
    # that pandas parsed as numbers, which have no .lower().
    reference = str(label).lower()
    for model, predictions, scores in _runs:
        generated_ids = model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # Score this single (prediction, label) pair, case-insensitively.
        scores.append(cer_metric(generated_text.lower(), reference).detach().numpy())
        predictions.append(generated_text)
# --- Collect results and write them out --------------------------------------
# Output column name -> per-sample values. Dict insertion order fixes the order
# in which columns are appended to `sample`, and hence their order in the CSV.
_result_columns = {
    "Baseline": inf_baseline,
    "Finetune_1": inf_finetune_1,
    "Finetune_2": inf_finetune_2,
    "Finetune_3": inf_finetune_3,
    "Finetune_4": inf_finetune_4,
    "Finetune_5": inf_finetune_5,
    "Finetune_6": inf_finetune_6,
    "cer_1": cer_fine_1,
    "cer_2": cer_fine_2,
    "cer_3": cer_fine_3,
    "cer_4": cer_fine_4,
    "cer_5": cer_fine_5,
    "cer_6": cer_fine_6,
    "cer_base": cer_base,
}
for column_name, column_values in _result_columns.items():
    sample[column_name] = column_values
sample.to_csv("/mnt/data1/Datasets/AlphaPen/" + "inference_results.csv")