Gizachew committed on
Commit
3ec26e4
·
verified ·
1 Parent(s): ae0e027

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -27
app.py CHANGED
@@ -2,20 +2,20 @@
2
 
3
  import gradio as gr
4
  import torch
5
- from PIL import Image, ImageDraw, ImageFont
6
  from model import load_model
7
  from utils import preprocess_image, decode_predictions
8
  import os
9
 
10
  # Load the model (ensure the path is correct)
11
  MODEL_PATH = "finetuned_recog_model.pth"
12
- FONT_PATH = "NotoSansEthiopic-Regular.ttf" # Update the path to your font
13
 
14
  # Check if model file exists
15
  if not os.path.exists(MODEL_PATH):
16
  raise FileNotFoundError(f"Model file not found at {MODEL_PATH}. Please provide the correct path.")
17
 
18
- # Check if font file exists
19
  if not os.path.exists(FONT_PATH):
20
  raise FileNotFoundError(f"Font file not found at {FONT_PATH}. Please provide the correct path.")
21
 
@@ -23,17 +23,13 @@ if not os.path.exists(FONT_PATH):
23
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24
  model = load_model(MODEL_PATH, device=device)
25
 
26
- # Load the font for rendering Amharic text
27
- from matplotlib import font_manager as fm
28
- import matplotlib.pyplot as plt
29
-
30
- ethiopic_font = fm.FontProperties(fname=FONT_PATH, size=15)
31
- pil_font = ImageFont.truetype(FONT_PATH, size=20)
32
-
33
  def recognize_text(image: Image.Image) -> str:
34
  """
35
  Function to recognize text from an image.
36
  """
 
 
 
37
  # Preprocess the image
38
  input_tensor = preprocess_image(image).unsqueeze(0).to(device) # [1, 3, 224, 224]
39
 
@@ -44,29 +40,16 @@ def recognize_text(image: Image.Image) -> str:
44
  # Decode predictions
45
  recognized_texts = decode_predictions(log_probs)
46
 
 
47
  return recognized_texts[0]
48
 
49
- def recognize_and_overlay(image: Image.Image) -> Image.Image:
50
- """
51
- Function to recognize text and overlay it on the image.
52
- """
53
- recognized_text = recognize_text(image)
54
-
55
- # Overlay text on the image
56
- draw = ImageDraw.Draw(image)
57
- text_position = (10, 10) # Top-left corner
58
- text_color = (255, 0, 0) # Red color
59
- draw.text(text_position, f"Recognized: {recognized_text}", font=pil_font, fill=text_color)
60
-
61
- return image
62
-
63
  # Define Gradio Interface
64
  iface = gr.Interface(
65
- fn=recognize_and_overlay,
66
  inputs=gr.Image(type="pil", label="Upload Image"),
67
- outputs=gr.Image(type="pil", label="Image with Recognized Text"),
68
  title="Amharic Text Recognition",
69
- description="Upload an image containing Amharic text. The app will recognize and overlay the text on the image."
70
  )
71
 
72
  # Launch the Gradio app
 
2
 
3
  import gradio as gr
4
  import torch
5
+ from PIL import Image
6
  from model import load_model
7
  from utils import preprocess_image, decode_predictions
8
  import os
9
 
10
  # Load the model (ensure the path is correct)
11
  MODEL_PATH = "finetuned_recog_model.pth"
12
+ FONT_PATH = "NotoSansEthiopic-Regular.ttf" # Path to your font
13
 
14
  # Check if model file exists
15
  if not os.path.exists(MODEL_PATH):
16
  raise FileNotFoundError(f"Model file not found at {MODEL_PATH}. Please provide the correct path.")
17
 
18
+ # Check if font file exists (if you plan to use it for any visualization)
19
  if not os.path.exists(FONT_PATH):
20
  raise FileNotFoundError(f"Font file not found at {FONT_PATH}. Please provide the correct path.")
21
 
 
23
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24
  model = load_model(MODEL_PATH, device=device)
25
 
 
 
 
 
 
 
 
26
  def recognize_text(image: Image.Image) -> str:
27
  """
28
  Function to recognize text from an image.
29
  """
30
+ if image is None:
31
+ return "No image provided."
32
+
33
  # Preprocess the image
34
  input_tensor = preprocess_image(image).unsqueeze(0).to(device) # [1, 3, 224, 224]
35
 
 
40
  # Decode predictions
41
  recognized_texts = decode_predictions(log_probs)
42
 
43
+ # Assuming batch size of 1
44
  return recognized_texts[0]
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # Define Gradio Interface
47
  iface = gr.Interface(
48
+ fn=recognize_text,
49
  inputs=gr.Image(type="pil", label="Upload Image"),
50
+ outputs=gr.Textbox(label="Recognized Amharic Text"),
51
  title="Amharic Text Recognition",
52
+ description="Upload an image containing Amharic text, and the model will recognize and display the text."
53
  )
54
 
55
  # Launch the Gradio app