triopood committed on
Commit
15b9ea3
1 Parent(s): 1f963ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -164,7 +164,7 @@ valid_dataset = CustomOCRDataset(
164
  )
165
 
166
  model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
167
- model.to(device)
168
  print(model)
169
  # Total parameters and trainable parameters.
170
  total_params = sum(p.numel() for p in model.parameters())
@@ -214,7 +214,7 @@ training_args = Seq2SeqTrainingArguments(
214
  evaluation_strategy='epoch',
215
  per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
216
  per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
217
- fp16=True,
218
  output_dir='seq2seq_model_printed/',
219
  logging_strategy='epoch',
220
  save_strategy='epoch',
@@ -237,7 +237,8 @@ trainer = Seq2SeqTrainer(
237
  res = trainer.train()
238
 
239
  processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
240
- trained_model = VisionEncoderDecoderModel.from_pretrained('seq2seq_model_printed/checkpoint-'+str(res.global_step)).to(device)
 
241
 
242
  def read_and_show(image_path):
243
  """
@@ -261,7 +262,8 @@ def ocr(image, processor, model):
261
  generated_text: the OCR'd text string.
262
  """
263
  # We can directly perform OCR on cropped images.
264
- pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
 
265
  generated_ids = model.generate(pixel_values)
266
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
267
  return generated_text
 
164
  )
165
 
166
  model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
167
+ #model.to(device)
168
  print(model)
169
  # Total parameters and trainable parameters.
170
  total_params = sum(p.numel() for p in model.parameters())
 
214
  evaluation_strategy='epoch',
215
  per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
216
  per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
217
+ #fp16=True,
218
  output_dir='seq2seq_model_printed/',
219
  logging_strategy='epoch',
220
  save_strategy='epoch',
 
237
  res = trainer.train()
238
 
239
  processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
240
+ trained_model = VisionEncoderDecoderModel.from_pretrained('seq2seq_model_printed/checkpoint-'+str(res.global_step))
241
+ #.to(device)
242
 
243
  def read_and_show(image_path):
244
  """
 
262
  generated_text: the OCR'd text string.
263
  """
264
  # We can directly perform OCR on cropped images.
265
+ pixel_values = processor(image, return_tensors='pt').pixel_values
266
+ #.to(device)
267
  generated_ids = model.generate(pixel_values)
268
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
269
  return generated_text