triopood committed on
Commit
222d3e4
1 Parent(s): 15b9ea3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -35,16 +35,16 @@ dataset = project.version(1).download("folder")
35
  subprocess.run(['wget', '--no-check-certificate', 'https://docs.google.com/uc?export=download&id=12reT7rxiRqTERYqeKYx7WGz5deMXjnEo', '-O', 'filetxt'])
36
  subprocess.run(['unzip', 'filetxt'])
37
 
38
- # def seed_everything(seed_value):
39
- # np.random.seed(seed_value)
40
- # torch.manual_seed(seed_value)
41
- # torch.cuda.manual_seed_all(seed_value)
42
- # torch.backends.cudnn.deterministic = True
43
- # torch.backends.cudnn.benchmark = False
44
 
45
- # seed_everything(42)
46
 
47
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
48
 
49
  def download_and_unzip(url, save_path):
50
  print(f"Downloading and extracting assets....", end="")
@@ -164,7 +164,7 @@ valid_dataset = CustomOCRDataset(
164
  )
165
 
166
  model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
167
- #model.to(device)
168
  print(model)
169
  # Total parameters and trainable parameters.
170
  total_params = sum(p.numel() for p in model.parameters())
@@ -237,8 +237,7 @@ trainer = Seq2SeqTrainer(
237
  res = trainer.train()
238
 
239
  processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
240
- trained_model = VisionEncoderDecoderModel.from_pretrained('seq2seq_model_printed/checkpoint-'+str(res.global_step))
241
- #.to(device)
242
 
243
  def read_and_show(image_path):
244
  """
@@ -262,8 +261,7 @@ def ocr(image, processor, model):
262
  generated_text: the OCR'd text string.
263
  """
264
  # We can directly perform OCR on cropped images.
265
- pixel_values = processor(image, return_tensors='pt').pixel_values
266
- #.to(device)
267
  generated_ids = model.generate(pixel_values)
268
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
269
  return generated_text
 
35
  subprocess.run(['wget', '--no-check-certificate', 'https://docs.google.com/uc?export=download&id=12reT7rxiRqTERYqeKYx7WGz5deMXjnEo', '-O', 'filetxt'])
36
  subprocess.run(['unzip', 'filetxt'])
37
 
38
+ def seed_everything(seed_value):
39
+ np.random.seed(seed_value)
40
+ torch.manual_seed(seed_value)
41
+ torch.cuda.manual_seed_all(seed_value)
42
+ torch.backends.cudnn.deterministic = True
43
+ torch.backends.cudnn.benchmark = False
44
 
45
+ seed_everything(42)
46
 
47
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
48
 
49
  def download_and_unzip(url, save_path):
50
  print(f"Downloading and extracting assets....", end="")
 
164
  )
165
 
166
  model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
167
+ model.to(device)
168
  print(model)
169
  # Total parameters and trainable parameters.
170
  total_params = sum(p.numel() for p in model.parameters())
 
237
  res = trainer.train()
238
 
239
  processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
240
+ trained_model = VisionEncoderDecoderModel.from_pretrained('seq2seq_model_printed/checkpoint-'+str(res.global_step)).to(device)
 
241
 
242
  def read_and_show(image_path):
243
  """
 
261
  generated_text: the OCR'd text string.
262
  """
263
  # We can directly perform OCR on cropped images.
264
+ pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
 
265
  generated_ids = model.generate(pixel_values)
266
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
267
  return generated_text