Ransaka committed on
Commit
93ea391
1 Parent(s): 7f200cc

Update app.py

Files changed (1)
  1. app.py +6 -4
app.py CHANGED
@@ -6,6 +6,7 @@ from torchvision import transforms
 from torchvision.transforms import functional as TF
 from PIL import Image
 from sinlib import Tokenizer
+from pathlib import Path
 
 MAX_LENGTH = 32
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -61,14 +62,12 @@ class CRNN(nn.Module):
         return output
 
 @st.cache_resource
-def load_model():
+def load_model(selected_model_path):
     model = CRNN(num_chars=len(tokenizer))
-    model.load_state_dict(torch.load('checkpoint-with-cer-0.18952566385269165.pth', map_location=torch.device('cpu')))
+    model.load_state_dict(torch.load(f'{selected_model_path}', map_location=torch.device('cpu')))
     model.eval()
     return model
 
-model = load_model()
-
 def preprocess_image(image):
     transform = transforms.Compose([
         transforms.Grayscale(),
@@ -97,6 +96,9 @@ st.warning("**Note**: This model was trained on images with these settings, \
     with width ranging from 800 to 2600 pixels and height ranging from 128 to 600 pixels. \
     For better results, use images within these limitations."
 )
+fp = Path(".").glob("*.pth")
+selected_model_path = st.selectbox(label="Select Model...", options=fp)
+model = load_model(selected_model_path)
 uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
 if uploaded_file is not None:
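
In short, the commit parameterizes the checkpoint path: load_model() now takes the path as an argument, and the app lists every *.pth file in its working directory in a selectbox instead of hard-coding a single checkpoint. A minimal sketch of that flow is below; CRNN and tokenizer stand in for the class and tokenizer defined earlier in app.py (not repeated here), and sorted(...) is this sketch's addition, not part of the commit, used only to make the dropdown order deterministic.

from pathlib import Path

import streamlit as st
import torch

@st.cache_resource
def load_model(selected_model_path):
    # st.cache_resource keys the cache on the function arguments, so each
    # selected checkpoint is loaded once per session and reused across reruns.
    model = CRNN(num_chars=len(tokenizer))  # defined earlier in app.py
    model.load_state_dict(torch.load(f'{selected_model_path}', map_location=torch.device('cpu')))
    model.eval()
    return model

# Offer every checkpoint sitting next to the app; sorted() only fixes display order.
checkpoints = sorted(Path(".").glob("*.pth"))
selected_model_path = st.selectbox(label="Select Model...", options=checkpoints)
model = load_model(selected_model_path)

The f'{selected_model_path}' wrapper just converts the selected Path to a string before torch.load; passing the Path object directly would also work, since torch.load accepts path-like arguments.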