Update app.py
Browse files
app.py
CHANGED
@@ -164,7 +164,7 @@ valid_dataset = CustomOCRDataset(
|
|
164 |
)
|
165 |
|
166 |
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
|
167 |
-
model.to(device)
|
168 |
print(model)
|
169 |
# Total parameters and trainable parameters.
|
170 |
total_params = sum(p.numel() for p in model.parameters())
|
@@ -214,7 +214,7 @@ training_args = Seq2SeqTrainingArguments(
|
|
214 |
evaluation_strategy='epoch',
|
215 |
per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
|
216 |
per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
|
217 |
-
fp16=True,
|
218 |
output_dir='seq2seq_model_printed/',
|
219 |
logging_strategy='epoch',
|
220 |
save_strategy='epoch',
|
@@ -237,7 +237,8 @@ trainer = Seq2SeqTrainer(
|
|
237 |
res = trainer.train()
|
238 |
|
239 |
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
|
240 |
-
trained_model = VisionEncoderDecoderModel.from_pretrained('seq2seq_model_printed/checkpoint-'+str(res.global_step))
|
|
|
241 |
|
242 |
def read_and_show(image_path):
|
243 |
"""
|
@@ -261,7 +262,8 @@ def ocr(image, processor, model):
|
|
261 |
generated_text: the OCR'd text string.
|
262 |
"""
|
263 |
# We can directly perform OCR on cropped images.
|
264 |
-
pixel_values = processor(image, return_tensors='pt').pixel_values
|
|
|
265 |
generated_ids = model.generate(pixel_values)
|
266 |
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
267 |
return generated_text
|
|
|
164 |
)
|
165 |
|
166 |
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
|
167 |
+
#model.to(device)
|
168 |
print(model)
|
169 |
# Total parameters and trainable parameters.
|
170 |
total_params = sum(p.numel() for p in model.parameters())
|
|
|
214 |
evaluation_strategy='epoch',
|
215 |
per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
|
216 |
per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
|
217 |
+
#fp16=True,
|
218 |
output_dir='seq2seq_model_printed/',
|
219 |
logging_strategy='epoch',
|
220 |
save_strategy='epoch',
|
|
|
237 |
res = trainer.train()
|
238 |
|
239 |
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
|
240 |
+
trained_model = VisionEncoderDecoderModel.from_pretrained('seq2seq_model_printed/checkpoint-'+str(res.global_step))
|
241 |
+
#.to(device)
|
242 |
|
243 |
def read_and_show(image_path):
|
244 |
"""
|
|
|
262 |
generated_text: the OCR'd text string.
|
263 |
"""
|
264 |
# We can directly perform OCR on cropped images.
|
265 |
+
pixel_values = processor(image, return_tensors='pt').pixel_values
|
266 |
+
#.to(device)
|
267 |
generated_ids = model.generate(pixel_values)
|
268 |
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
269 |
return generated_text
|