squirelmail committed on
Commit
2f66b50
·
verified ·
1 Parent(s): ca1e5ec

Create check_model.py

Browse files
Files changed (1) hide show
  1. check_model.py +193 -0
check_model.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # python3 check_model.py --weights /workspace/captcha_final.weights.h5 --image /workspace/dataset_500/style7/K9NO2.png
5
+ # python3 check_model.py --weights /workspace/captcha_final.weights.h5 --data-root /workspace/dataset_500 --samples 64
6
+ # python3 check_model.py --weights captcha_final.weights.h5 --data-root /datasets/dataset_500 --samples 64
7
+
8
+
9
+ import os, re, glob, argparse, sys, time
10
+ from pathlib import Path
11
+ import numpy as np
12
+ from PIL import Image
13
+ import tensorflow as tf
14
+ from tensorflow.keras import layers, models, backend as K
15
+
16
+ # ---------------- Args ----------------
17
def parse_args(argv=None):
    """Parse the command-line options of the CRNN+CTC inference check.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, which is what the script
            entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``weights``, ``image``, ``gt``,
        ``data_root``, ``samples``, ``height``, ``width``, ``ext`` and ``show``.
    """
    # The summary sentence must be given as ``description``: the first
    # positional parameter of ArgumentParser is ``prog``, so passing it
    # positionally would print the whole sentence as the program name in the
    # usage line.
    parser = argparse.ArgumentParser(
        description="Test inference CRNN+CTC dari weights Keras 3 (model_with_ctc.save_weights).",
    )
    parser.add_argument("--weights", required=True, help="Path ke *.weights.h5 (hasil save_weights).")
    parser.add_argument("--image", help="Uji 1 gambar (PNG/JPG). Nama file jadi GT jika --gt tidak diisi.")
    parser.add_argument("--gt", help="Ground truth untuk --image (opsional, default dari nama file).")
    parser.add_argument("--data-root", help="Root dataset berisi style0..style59/LABEL.png untuk batch test.")
    parser.add_argument("--samples", type=int, default=64, help="Jumlah sampel di batch test.")
    parser.add_argument("--height", type=int, default=50)
    parser.add_argument("--width", type=int, default=250)
    parser.add_argument("--ext", type=str, default="png")
    parser.add_argument("--show", type=int, default=12, help="Banyak baris contoh yang ditampilkan.")
    return parser.parse_args(argv)
29
+
30
+ # ------------- Charset & util -------------
31
# Symbols the network can emit, in class-index order: digits first, then A-Z.
CHARSET = [*"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"]
# The CTC blank takes the index right after the last real symbol (36).
BLANK_ID = len(CHARSET)
# Array view of the charset so a list of class ids can be mapped in one lookup.
ID2CHAR = np.asarray(CHARSET)
34
+
35
def collapse_and_strip_blanks(seq_ids, blank_id=BLANK_ID):
    """Greedy CTC collapse: merge consecutive repeats, then drop blanks.

    Args:
        seq_ids: Iterable of per-timestep class ids.
        blank_id: Id treated as the CTC blank symbol.

    Returns:
        List of the surviving ids, in order. A symbol repeated across a blank
        is kept twice because the blank resets the "previous id" tracker.
    """
    decoded = []
    last_seen = -1  # sentinel that never equals a real class id
    for current in seq_ids:
        is_repeat_or_blank = current == last_seen or current == blank_id
        if not is_repeat_or_blank:
            decoded.append(current)
        # Updated on every step, blanks included, so "A, blank, A" -> "AA".
        last_seen = current
    return decoded
42
+
43
def ids_to_text(ids):
    """Turn a sequence of class ids into the corresponding string.

    Ids outside ``0 .. len(CHARSET) - 1`` (the blank id included) are dropped
    silently. Returns an empty string when nothing valid remains.
    """
    n_symbols = len(CHARSET)
    valid = [idx for idx in ids if 0 <= idx < n_symbols]
    if not valid:
        return ""
    return "".join(ID2CHAR[valid])
46
+
47
def cer(pred, gt):
    """Character error rate: Levenshtein distance of pred vs gt over len(gt).

    Args:
        pred: Predicted sequence (typically a string).
        gt: Ground-truth sequence.

    Returns:
        Edit distance divided by ``len(gt)``. For an empty ground truth the
        result is 0.0 when the prediction is empty too, and 1.0 otherwise.
    """
    n_pred, n_gt = len(pred), len(gt)
    if n_gt == 0:
        if n_pred == 0:
            return 0.0
        return 1.0

    # table[r, c] = edit distance between pred[:r] and gt[:c]
    table = np.zeros((n_pred + 1, n_gt + 1), dtype=np.int32)
    table[0, :] = np.arange(n_gt + 1)
    table[:, 0] = np.arange(n_pred + 1)

    for r, p_char in enumerate(pred, start=1):
        for c, g_char in enumerate(gt, start=1):
            deletion = table[r - 1, c] + 1
            insertion = table[r, c - 1] + 1
            substitution = table[r - 1, c - 1] + (p_char != g_char)
            table[r, c] = min(deletion, insertion, substitution)

    return table[n_pred, n_gt] / n_gt
56
+
57
+ # ------------- Model builders -------------
58
def build_models(h=50, w=250, num_classes=len(CHARSET)+1):
    """Build the CRNN and its CTC-loss wrapper.

    Args:
        h: Input image height in pixels.
        w: Input image width in pixels.
        num_classes: Size of the softmax output (charset plus one extra class
            by default).

    Returns:
        Tuple ``(model_with_ctc, base_model)``. ``base_model`` maps an image
        to the per-timestep softmax; ``model_with_ctc`` takes the image plus
        labels / input_length / label_length and outputs the CTC cost. Both
        are built from the same layer objects, so weights loaded into one are
        also the weights of the other.

    The shape comments below hold for the default ``h=50, w=250``.
    """
    inp = layers.Input(shape=(h, w, 1), name="input")
    x = layers.Conv2D(32, (3,3), activation="relu", padding="same")(inp)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2,2))(x) # 50x250 -> 25x125

    x = layers.Conv2D(64, (3,3), activation="relu", padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2,2))(x) # 25x125 -> 12x62

    x = layers.Conv2D(128, (3,3), activation="relu", padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2,2))(x) # 12x62 -> 6x31

    # Fold the feature map into a sequence for the recurrent layers.
    # NOTE(review): the Reshape is applied directly, with no Permute in front,
    # so each of the 31 "timesteps" is a slice of the row-major flattened map
    # rather than one width column. Presumably this mirrors the training
    # graph; confirm before changing, since the saved weights depend on it.
    shp = K.int_shape(x) # (None, 6, 31, 128)
    x = layers.Reshape((shp[2], shp[1]*shp[3]))(x) # (None, 31, 768)

    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.0, recurrent_dropout=0.0))(x)
    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.0, recurrent_dropout=0.0))(x)

    # Per-timestep class distribution; this is the output of base_model.
    pred = layers.Dense(num_classes, activation="softmax", name="predictions")(x)

    # CTC inputs, so this graph mirrors the training graph
    labels = layers.Input(name="labels", shape=(None,), dtype="int32")
    input_len = layers.Input(name="input_length", shape=(1,), dtype="int32")
    label_len = layers.Input(name="label_length", shape=(1,), dtype="int32")
    def ctc_fn(args):
        # Lambda passes the tensors as [pred, labels, input_len, label_len];
        # ctc_batch_cost wants the labels first.
        y_pred, labels_t, in_l, lab_l = args
        return K.ctc_batch_cost(labels_t, y_pred, in_l, lab_l)
    ctc = layers.Lambda(ctc_fn, output_shape=(1,), name="ctc_loss", dtype="float32")([pred, labels, input_len, label_len])

    model_with_ctc = models.Model(inputs=[inp, labels, input_len, label_len], outputs=ctc, name="crnn_ctc_train")
    base_model = models.Model(inputs=inp, outputs=pred, name="crnn_ctc_base")
    return model_with_ctc, base_model
92
+
93
+ # ------------- IO & preprocess -------------
94
def preprocess_gray(img_pil, h=50, w=250):
    """Convert a PIL image into the network's input tensor.

    The image is turned into 8-bit grayscale, resized to ``w`` x ``h`` with
    bilinear resampling, scaled from 0..255 to -1..1, and given a trailing
    channel axis.

    Returns:
        float32 array of shape ``(h, w, 1)``.
    """
    gray = img_pil.convert("L")
    resized = gray.resize((w, h), Image.BILINEAR)
    pixels = np.asarray(resized, dtype=np.float32)
    unit = pixels / 255.0            # 0..1
    centered = (unit - 0.5) / 0.5    # -1..1
    return centered[..., None] # (H,W,1)
99
+
100
def list_files(root, ext="png", max_n=64):
    """Collect up to ``max_n`` labelled images from ``root/style0..style59``.

    A file qualifies when its stem, upper-cased, consists of exactly five
    characters from ``A-Z`` / ``0-9``; that upper-cased stem is its label.

    Args:
        root: Dataset root directory containing the ``style<N>`` folders.
        ext: File extension to look for, without the leading dot.
        max_n: Maximum number of samples to return; values <= 0 yield ``[]``.

    Returns:
        List of ``(path, label)`` tuples with ``path`` as ``str``. Style
        folders are visited in numeric order and files inside a folder in
        sorted path order, so the selection is reproducible between runs.
    """
    if max_n <= 0:
        return []

    root_dir = Path(root)
    # fullmatch (instead of ^...$ with match) also rejects a trailing newline.
    label_re = re.compile(r"[A-Z0-9]{5}")
    pairs = []
    for sid in range(60):
        style_dir = root_dir / f"style{sid}"
        if not style_dir.is_dir():
            continue
        # Escape the directory part so glob metacharacters in the dataset
        # path (e.g. "[" or "*") are matched literally.
        pattern = os.path.join(glob.escape(str(style_dir)), f"*.{ext}")
        # glob returns entries in arbitrary order; sort for determinism.
        for file_path in sorted(glob.glob(pattern)):
            label = Path(file_path).stem.upper()
            if label_re.fullmatch(label):
                pairs.append((file_path, label))
                if len(pairs) >= max_n:
                    return pairs
    return pairs
114
+
115
+ # ------------- Predict helpers -------------
116
def predict_batch(base_model, batch_imgs):
    """Run the base model on a batch and greedy-decode each sequence to text.

    Args:
        base_model: Model whose ``predict`` returns per-timestep class scores.
        batch_imgs: np.array (B,H,W,1) float32 [-1,1].

    Returns:
        List with one decoded string per batch item.
    """
    scores = base_model.predict(batch_imgs, verbose=0)  # (B, 31, 37)
    best_ids = np.argmax(scores, axis=-1)  # (B, 31)
    return [
        ids_to_text(collapse_and_strip_blanks(sequence, blank_id=BLANK_ID))
        for sequence in best_ids
    ]
125
+
126
def main():
    """Script entry point: load the weights, then test one image or a sample.

    Flow: parse the CLI, build the models and load the weights file, then
    either decode a single ``--image`` or evaluate up to ``--samples`` files
    found under ``--data-root`` (exact-match rate and mean CER are printed).
    ``--image`` takes precedence when both are given.

    Exit codes: 1 weights file missing, 2 weights failed to load,
    3 image file missing, 0 in every other case.
    """
    args = parse_args()

    # (optional) limit threads when the container is tight on resources
    for env_name in ("TF_NUM_INTRAOP_THREADS", "TF_NUM_INTEROP_THREADS", "OMP_NUM_THREADS"):
        os.environ.setdefault(env_name, "1")

    # 1) Build the model and load the weights
    wpath = Path(args.weights)
    if not wpath.exists():
        print("Weights not found:", wpath)
        sys.exit(1)
    st = wpath.stat()
    print(f"Found weights: {wpath} | size: {st.st_size/1024:.1f} KB | mtime: {time.ctime(st.st_mtime)}")
    print("TF GPUs:", tf.config.list_physical_devices('GPU'))

    model_with_ctc, base_model = build_models(h=args.height, w=args.width, num_classes=len(CHARSET)+1)
    try:
        # Weights were saved from the CTC-wrapped model, so load them there;
        # base_model is built from the same layers.
        model_with_ctc.load_weights(str(wpath))
        print("OK: weights loaded.")
    except Exception as e:
        print("Failed to load weights:", e)
        sys.exit(2)

    print("Base output shape:", base_model.output_shape) # Expect (None, 31, 37)

    if args.image:
        # 2A) Single image test
        f = Path(args.image)
        if not f.exists():
            print("Image not found:", f)
            sys.exit(3)
        with Image.open(f) as im:
            x = preprocess_gray(im, h=args.height, w=args.width)
        pred = predict_batch(base_model, x[np.newaxis, ...])[0]
        # Fall back to the file name as ground truth when --gt is not given.
        gt = args.gt or f.stem.upper()
        print(f"\nSingle image:")
        print(f"GT : {gt}")
        print(f"PRED: {pred}")
    elif args.data_root:
        # 2B) Batch test from the dataset
        pairs = list_files(args.data_root, ext=args.ext, max_n=args.samples)
        if pairs:
            print(f"Testing on {len(pairs)} samples from {args.data_root} ...")
            GT = [label for _, label in pairs]
            frames = []
            for img_path, _ in pairs:
                with Image.open(img_path) as im:
                    frames.append(preprocess_gray(im, h=args.height, w=args.width))
            X = np.stack(frames, 0).astype(np.float32)

            PRED = predict_batch(base_model, X)
            exact = np.mean([int(p == g) for p, g in zip(PRED, GT)])
            cer_vals = [cer(p, g) for p, g in zip(PRED, GT)]

            for i in range(min(args.show, len(PRED))):
                print(f"{i:02d} GT: {GT[i]} | Pred: {PRED[i]}")

            print(f"\nExact match: {exact*100:.2f}% | Mean CER: {float(np.mean(cer_vals)):.4f}\n")
            print(f"Total images tested: {len(PRED)}\n")
        else:
            print("No valid files in dataset root.")
    else:
        print("Nothing to test. Provide --image or --data-root.")

    # Every path that reaches this point is a normal termination.
    sys.exit(0)


if __name__ == "__main__":
    main()