7eu7d7 commited on
Commit
f535d79
·
1 Parent(s): 75f5c93
Files changed (4) hide show
  1. app.py +24 -13
  2. cap.py +6 -3
  3. models/__init__.py +1 -1
  4. models/enc_dec.py +22 -0
app.py CHANGED
@@ -13,33 +13,33 @@ def load_predictor(model):
13
  predictor = Predictor(hf_hub_download(
14
  f'7eu7d7/CAPTCHA_recognize',
15
  model,
16
- ))
17
  return predictor
18
 
19
 
20
- def process_image(image):
21
  """
22
- Process the uploaded image - this is an example function
23
- You can modify this function to implement specific image processing logic
24
  """
25
  if image is None:
26
  return "Please upload an image first"
27
 
28
- # Example processing: convert image to grayscale
29
  if isinstance(image, np.ndarray):
30
- # If it's a numpy array, convert to PIL Image
31
  img = Image.fromarray(image.astype('uint8')).convert('RGB')
32
  else:
33
  img = image.convert('RGB')
34
 
35
- predictor = load_predictor('captcha-7400.safetensors')
36
- text = predictor.pred_img(img, show=False)
37
- return text
 
 
 
38
 
39
 
40
  # Create Gradio interface
41
  with gr.Blocks(title="CAPTCHA Recognize") as demo:
42
-
43
  with gr.Row():
44
  # Left column - Input area
45
  with gr.Column(scale=1):
@@ -49,6 +49,18 @@ with gr.Blocks(title="CAPTCHA Recognize") as demo:
49
  height=300
50
  )
51
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  # Run button
53
  process_btn = gr.Button(
54
  "Run",
@@ -67,11 +79,10 @@ with gr.Blocks(title="CAPTCHA Recognize") as demo:
67
  # Bind events
68
  process_btn.click(
69
  fn=process_image,
70
- inputs=image_input,
71
  outputs=[text_output]
72
  )
73
 
74
-
75
  # Launch the application
76
  if __name__ == "__main__":
77
- demo.launch()
 
13
  predictor = Predictor(hf_hub_download(
14
  f'7eu7d7/CAPTCHA_recognize',
15
  model,
16
+ ), ckpt_name=model)
17
  return predictor
18
 
19
 
20
+ def process_image(image, model_name):
21
  """
22
+ Process the uploaded image with selected model
 
23
  """
24
  if image is None:
25
  return "Please upload an image first"
26
 
27
+ # Convert image to PIL format if needed
28
  if isinstance(image, np.ndarray):
 
29
  img = Image.fromarray(image.astype('uint8')).convert('RGB')
30
  else:
31
  img = image.convert('RGB')
32
 
33
+ try:
34
+ predictor = load_predictor(model_name)
35
+ text = predictor.pred_img(img, show=False)
36
+ return text
37
+ except Exception as e:
38
+ return f"Error processing image: {str(e)}"
39
 
40
 
41
  # Create Gradio interface
42
  with gr.Blocks(title="CAPTCHA Recognize") as demo:
 
43
  with gr.Row():
44
  # Left column - Input area
45
  with gr.Column(scale=1):
 
49
  height=300
50
  )
51
 
52
+ # Model selection dropdown
53
+ model_dropdown = gr.Dropdown(
54
+ label="Select Model",
55
+ choices=[
56
+ "captcha-2000.safetensors",
57
+ "captcha-7400.safetensors",
58
+ "captcha-caformer-v2-6200.safetensors",
59
+ ],
60
+ value="captcha-caformer-v2-6200.safetensors", # 默认选择
61
+ interactive=True
62
+ )
63
+
64
  # Run button
65
  process_btn = gr.Button(
66
  "Run",
 
79
  # Bind events
80
  process_btn.click(
81
  fn=process_image,
82
+ inputs=[image_input, model_dropdown],
83
  outputs=[text_output]
84
  )
85
 
 
86
  # Launch the application
87
  if __name__ == "__main__":
88
+ demo.launch()
cap.py CHANGED
@@ -1,7 +1,7 @@
1
  # -*- coding: utf-8 -*-
2
  import torch
3
  import argparse
4
- from models import ResnetEncoderDecoder
5
  from utils import remove_rptch
6
  from safetensors import safe_open
7
  from torchvision import transforms as T
@@ -14,8 +14,11 @@ char_dict_pp = '_0123456789abcdefghijklmnopqrstuvwxyz()+-*/='
14
 
15
 
16
  class Predictor:
17
- def __init__(self, model_path, char_dict=char_dict_pp):
18
- self.model = ResnetEncoderDecoder(char_dict).to(device)
 
 
 
19
  self.model.eval()
20
  if str(device)=='cpu':
21
  check_point = self.load_safetensor(model_path, map_location='cpu')
 
1
  # -*- coding: utf-8 -*-
2
  import torch
3
  import argparse
4
+ from models import ResnetEncoderDecoder, CaformerEncoderDecoder
5
  from utils import remove_rptch
6
  from safetensors import safe_open
7
  from torchvision import transforms as T
 
14
 
15
 
16
  class Predictor:
17
+ def __init__(self, model_path, ckpt_name, char_dict=char_dict_pp):
18
+ if 'caformer' in ckpt_name:
19
+ self.model = CaformerEncoderDecoder(char_dict).to(device)
20
+ else:
21
+ self.model = ResnetEncoderDecoder(char_dict).to(device)
22
  self.model.eval()
23
  if str(device)=='cpu':
24
  check_point = self.load_safetensor(model_path, map_location='cpu')
models/__init__.py CHANGED
@@ -1 +1 @@
1
- from .enc_dec import ResnetEncoderDecoder
 
1
+ from .enc_dec import ResnetEncoderDecoder, CaformerEncoderDecoder
models/enc_dec.py CHANGED
@@ -26,3 +26,25 @@ class ResnetEncoderDecoder(nn.Module):
26
  input = F.softmax(self.out(input), dim=-1)
27
 
28
  return input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  input = F.softmax(self.out(input), dim=-1)
27
 
28
  return input
29
+
30
+ class CaformerEncoderDecoder(nn.Module):
31
+ def __init__(self, char_dict, drop_rate=0.2, drop_path_rate=0.3):
32
+ super().__init__()
33
+ self.bn = nn.BatchNorm2d(64)
34
+ backbone = timm.create_model('caformer_s18.sail_in22k_ft_in1k', pretrained=True, drop_rate=drop_rate, drop_path_rate=drop_path_rate)
35
+ backbone.set_grad_checkpointing(True)
36
+ self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1)
37
+ self.cnn = nn.Sequential(*list(backbone.children())[1:-1])
38
+ self.out = nn.Linear(512, len(char_dict))
39
+
40
+ self.char_dict = char_dict
41
+
42
+ def forward(self, input):
43
+ input = F.silu(self.bn(self.conv(input)), True)
44
+ input = F.max_pool2d(input, kernel_size=(2, 2), stride=(2, 2))
45
+ input = self.cnn(input)
46
+
47
+ input = input.permute(0, 2, 3, 1)
48
+ input = F.softmax(self.out(input), dim=-1)
49
+
50
+ return input