Gabor Cselle commited on
Commit
2e58968
1 Parent(s): ea56d2d

It does help if we save the model :-)

Browse files
Files changed (1) hide show
  1. train_font_identifier.py +4 -1
train_font_identifier.py CHANGED
@@ -13,7 +13,7 @@ data_dir = './train_test_images'
13
 
14
  # Transformations for the image data
15
  data_transforms = transforms.Compose([
16
- s transforms.Grayscale(num_output_channels=3), # Convert images to grayscale with 3 channels
17
  transforms.Resize((224, 224)), # Resize images to the expected input size of the model
18
  transforms.ToTensor(), # Convert images to PyTorch tensors
19
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize with ImageNet stats
@@ -88,3 +88,6 @@ for epoch in range(num_epochs):
88
  train_loss = train_step(model, dataloaders["train"], criterion, optimizer)
89
  val_loss, val_accuracy = validate(model, dataloaders["test"], criterion)
90
  print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
 
 
 
 
13
 
14
  # Transformations for the image data
15
  data_transforms = transforms.Compose([
16
+ transforms.Grayscale(num_output_channels=3), # Convert images to grayscale with 3 channels
17
  transforms.Resize((224, 224)), # Resize images to the expected input size of the model
18
  transforms.ToTensor(), # Convert images to PyTorch tensors
19
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize with ImageNet stats
 
88
  train_loss = train_step(model, dataloaders["train"], criterion, optimizer)
89
  val_loss, val_accuracy = validate(model, dataloaders["test"], criterion)
90
  print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
91
+
92
+ # Save the model to disk
93
+ torch.save(model.state_dict(), 'font_identifier_model.pth')