yahiab committed on
Commit 471d95f · 1 Parent(s): fb8456d
Files changed (2)
  1. app_bk.py +111 -0
  2. app.py +99 -64
app_bk.py ADDED
@@ -0,0 +1,111 @@
+import gradio as gr
+import numpy as np
+from PIL import Image, ImageDraw
+import torch
+from torchvision import transforms
+from transformers import AutoModelForImageClassification, AutoFeatureExtractor
+
+# Define all available models
+MODEL_LIST = {
+    'beit': "microsoft/beit-base-patch16-224-pt22k-ft22k",
+    'vit': "google/vit-base-patch16-224",
+    'convnext': "facebook/convnext-tiny-224",
+}
+
+# Global variables
+current_model = None
+current_preprocessor = None
+device = "cuda" if torch.cuda.is_available() else "cpu"  # Dynamically set device
+
+# Load model and preprocessor
+def load_model_and_preprocessor(model_name):
+    """Load model and preprocessor for a given model name."""
+    global current_model, current_preprocessor
+    print(f"Loading model and preprocessor for: {model_name} on {device}")
+    current_model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).to(device).eval()
+    current_preprocessor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
+    return f"Model {model_name} loaded successfully on {device}."
+
+# Predict function
+def predict(image, model, preprocessor):
+    """Make a prediction on the given image patch using the loaded model."""
+    if model is None or preprocessor is None:
+        raise ValueError("Model and preprocessor are not loaded.")
+    inputs = preprocessor(images=image, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = model(**inputs)
+        predicted_class = torch.argmax(outputs.logits, dim=1).item()
+    return model.config.id2label[predicted_class]
+
+# Function to draw a rectangle on the image
+def draw_rectangle(image, x, y, size=224):
+    """Draw a rectangle on the image."""
+    image_pil = image.copy()  # Create a copy to avoid modifying the original image
+    draw = ImageDraw.Draw(image_pil)
+    x1, y1 = x, y
+    x2, y2 = x + size, y + size
+    draw.rectangle([x1, y1, x2, y2], outline="red", width=5)
+    return image_pil
+
+# Function to crop the image
+def crop_image(image, x, y, size=224):
+    """Crop a region from the image."""
+    image_np = np.array(image)
+    h, w, _ = image_np.shape
+    x = min(max(x, 0), w - size)
+    y = min(max(y, 0), h - size)
+    cropped = image_np[y:y+size, x:x+size]
+    return Image.fromarray(cropped)
+
+# Gradio Interface
+with gr.Blocks() as demo:
+    gr.Markdown("## Test Public Models for Coral Classification")
+
+    with gr.Row():
+        with gr.Column():
+            model_selector = gr.Dropdown(choices=list(MODEL_LIST.keys()), value='beit', label="Select Model")
+            image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
+            x_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="X Coordinate")
+            y_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Y Coordinate")
+        with gr.Column():
+            interactive_image = gr.Image(label="Interactive Image with Selection")
+            cropped_image = gr.Image(label="Cropped Patch")
+            label_output = gr.Textbox(label="Predicted Label")
+
+    # Update the model and preprocessor
+    def update_model(model_name):
+        return load_model_and_preprocessor(model_name)
+
+    # Update the rectangle and crop the patch
+    def update_selection(image, x, y):
+        overlay_image = draw_rectangle(image, x, y)
+        cropped = crop_image(image, x, y)
+        return overlay_image, cropped
+
+    # Predict the label from the cropped patch
+    def predict_from_cropped(cropped):
+        print(f"Type of cropped_image before prediction: {type(cropped)}")
+        return predict(cropped, current_model, current_preprocessor)
+
+    # Buttons and interactions
+    crop_button = gr.Button("Crop")
+    crop_button.click(fn=update_selection, inputs=[image_input, x_slider, y_slider], outputs=[interactive_image, cropped_image])
+
+    predict_button = gr.Button("Predict")
+    predict_button.click(fn=predict_from_cropped, inputs=cropped_image, outputs=label_output)
+
+    model_selector.change(fn=update_model, inputs=model_selector, outputs=None)
+
+    # Update sliders dynamically based on uploaded image size
+    def update_sliders(image):
+        if image is not None:
+            width, height = image.size
+            return gr.update(maximum=width - 224), gr.update(maximum=height - 224)
+        return gr.update(), gr.update()
+
+    image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])
+
+    # Initialize model on app start
+    demo.load(fn=lambda: load_model_and_preprocessor('beit'), inputs=None, outputs=None)
+
+demo.launch(server_name="0.0.0.0", server_port=7860)
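For reference, the classification path preserved in app_bk.py can be exercised outside Gradio. A minimal sketch, assuming the 'beit' checkpoint from MODEL_LIST above and a hypothetical local file coral.jpg (newer transformers releases expose the same preprocessor via AutoImageProcessor):

import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoFeatureExtractor

model_id = "microsoft/beit-base-patch16-224-pt22k-ft22k"  # 'beit' entry in MODEL_LIST
model = AutoModelForImageClassification.from_pretrained(model_id).eval()
preprocessor = AutoFeatureExtractor.from_pretrained(model_id)

image = Image.open("coral.jpg").convert("RGB")  # hypothetical test image
inputs = preprocessor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(dim=1).item()])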
app.py CHANGED
@@ -2,54 +2,104 @@ import gradio as gr
 import numpy as np
 from PIL import Image, ImageDraw
 import torch
-from torchvision import transforms
-from transformers import AutoModelForImageClassification, AutoFeatureExtractor
-
-# Define all available models
-MODEL_LIST = {
-    'beit': "microsoft/beit-base-patch16-224-pt22k-ft22k",
-    'vit': "google/vit-base-patch16-224",
-    'convnext': "facebook/convnext-tiny-224",
-}
-
-# Global variables
-current_model = None
-current_preprocessor = None
-device = "cuda" if torch.cuda.is_available() else "cpu"  # Dynamically set device
-
-# Load model and preprocessor
-def load_model_and_preprocessor(model_name):
-    """Load model and preprocessor for a given model name."""
-    global current_model, current_preprocessor
-    print(f"Loading model and preprocessor for: {model_name} on {device}")
-    current_model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).to(device).eval()
-    current_preprocessor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
-    return f"Model {model_name} loaded successfully on {device}."
-
-# Predict function
-def predict(image, model, preprocessor):
-    """Make a prediction on the given image patch using the loaded model."""
-    if model is None or preprocessor is None:
-        raise ValueError("Model and preprocessor are not loaded.")
-    inputs = preprocessor(images=image, return_tensors="pt").to(device)
+import torchvision.transforms as transforms
+import timm
+
+# URL for the Hugging Face checkpoint
+CHECKPOINT_URL = "https://huggingface.co/ReefNet/beit_global/resolve/main/checkpoint-60.pth"
+
+# Class labels
+all_classes = [
+    'Acanthastrea', 'Acropora', 'Agaricia', 'Alveopora', 'Astrea', 'Astreopora',
+    'Caulastraea', 'Coeloseris', 'Colpophyllia', 'Coscinaraea', 'Ctenactis',
+    'Cycloseris', 'Cyphastrea', 'Dendrogyra', 'Dichocoenia', 'Diploastrea',
+    'Diploria', 'Dipsastraea', 'Echinophyllia', 'Echinopora', 'Euphyllia',
+    'Eusmilia', 'Favia', 'Favites', 'Fungia', 'Galaxea', 'Gardineroseris',
+    'Goniastrea', 'Goniopora', 'Halomitra', 'Herpolitha', 'Hydnophora',
+    'Isophyllia', 'Isopora', 'Leptastrea', 'Leptoria', 'Leptoseris',
+    'Lithophyllon', 'Lobactis', 'Lobophyllia', 'Madracis', 'Meandrina', 'Merulina',
+    'Montastraea', 'Montipora', 'Mussa', 'Mussismilia', 'Mycedium', 'Orbicella',
+    'Oulastrea', 'Oulophyllia', 'Oxypora', 'Pachyseris', 'Pavona', 'Pectinia',
+    'Physogyra', 'Platygyra', 'Plerogyra', 'Plesiastrea', 'Pocillopora',
+    'Podabacia', 'Porites', 'Psammocora', 'Pseudodiploria', 'Sandalolitha',
+    'Scolymia', 'Seriatopora', 'Siderastrea', 'Stephanocoenia', 'Stylocoeniella',
+    'Stylophora', 'Tubastraea', 'Turbinaria'
+]
+
+# Function to load the BeiT model
+def load_model(model_name):
+    print(f"Loading {model_name} model...")
+    if model_name == 'beit':
+        args = type('', (), {})()
+        args.model = 'beitv2_large_patch16_224.in1k_ft_in22k_in1k'
+        args.nb_classes = len(all_classes)
+        args.drop_path = 0.1
+
+        # Create model
+        model = timm.create_model(
+            args.model,
+            pretrained=False,
+            num_classes=args.nb_classes,
+            drop_path_rate=args.drop_path,
+            use_rel_pos_bias=True,
+            use_abs_pos_emb=True,
+        )
+
+        # Load checkpoint from Hugging Face
+        checkpoint = torch.hub.load_state_dict_from_url(CHECKPOINT_URL, map_location="cpu")
+        state_dict = checkpoint.get('model', checkpoint)
+
+        # Filter state dict
+        filtered_state_dict = {k: v for k, v in state_dict.items() if "relative_position_index" not in k}
+        model.load_state_dict(filtered_state_dict, strict=False)
+    else:
+        raise ValueError(f"Model {model_name} not implemented!")
+
+    # Move model to CUDA if available
+    model.eval()
+    if torch.cuda.is_available():
+        model.cuda()
+    return model
+
+# Preprocessing transforms
+preprocess = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+])
+
+# Initialize selected model
+selected_model_name = 'beit'
+model = load_model(selected_model_name)
+
+def predict_label(image):
+    """Predict the label for the given image."""
+    # Ensure the image is a PIL Image
+    if isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+    elif not isinstance(image, Image.Image):
+        raise TypeError(f"Unexpected type {type(image)}, expected PIL.Image or numpy.ndarray.")
+
+    input_tensor = preprocess(image).unsqueeze(0)
+    if torch.cuda.is_available():
+        input_tensor = input_tensor.cuda()
+
     with torch.no_grad():
-        outputs = model(**inputs)
-        predicted_class = torch.argmax(outputs.logits, dim=1).item()
-    return model.config.id2label[predicted_class]
+        outputs = model(input_tensor)
+        predicted_class = torch.argmax(outputs, dim=1).item()
+
+    return all_classes[predicted_class]
+
 
 # Function to draw a rectangle on the image
 def draw_rectangle(image, x, y, size=224):
-    """Draw a rectangle on the image."""
-    image_pil = image.copy()  # Create a copy to avoid modifying the original image
+    image_pil = image.copy()
     draw = ImageDraw.Draw(image_pil)
-    x1, y1 = x, y
-    x2, y2 = x + size, y + size
-    draw.rectangle([x1, y1, x2, y2], outline="red", width=5)
+    draw.rectangle([x, y, x + size, y + size], outline="red", width=3)
     return image_pil
 
-# Function to crop the image
+# Crop a region of interest
 def crop_image(image, x, y, size=224):
-    """Crop a region from the image."""
     image_np = np.array(image)
     h, w, _ = image_np.shape
     x = min(max(x, 0), w - size)
@@ -57,55 +107,40 @@ def crop_image(image, x, y, size=224):
     cropped = image_np[y:y+size, x:x+size]
     return Image.fromarray(cropped)
 
-# Gradio Interface
+# Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("## Test Public Models for Coral Classification")
-
+    gr.Markdown("## Coral Classification with BeiT Model")
     with gr.Row():
         with gr.Column():
-            model_selector = gr.Dropdown(choices=list(MODEL_LIST.keys()), value='beit', label="Select Model")
             image_input = gr.Image(type="pil", label="Upload Image", interactive=True)
-            x_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="X Coordinate")
-            y_slider = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Y Coordinate")
+            x_slider = gr.Slider(0, 1000, step=1, value=0, label="X Coordinate")
+            y_slider = gr.Slider(0, 1000, step=1, value=0, label="Y Coordinate")
         with gr.Column():
-            interactive_image = gr.Image(label="Interactive Image with Selection")
+            interactive_image = gr.Image(label="Interactive Image")
             cropped_image = gr.Image(label="Cropped Patch")
             label_output = gr.Textbox(label="Predicted Label")
-
-    # Update the model and preprocessor
-    def update_model(model_name):
-        return load_model_and_preprocessor(model_name)
-
-    # Update the rectangle and crop the patch
+
+    # Interactions
     def update_selection(image, x, y):
        overlay_image = draw_rectangle(image, x, y)
        cropped = crop_image(image, x, y)
        return overlay_image, cropped
 
-    # Predict the label from the cropped patch
     def predict_from_cropped(cropped):
-        print(f"Type of cropped_image before prediction: {type(cropped)}")
-        return predict(cropped, current_model, current_preprocessor)
+        return predict_label(cropped)
 
-    # Buttons and interactions
     crop_button = gr.Button("Crop")
     crop_button.click(fn=update_selection, inputs=[image_input, x_slider, y_slider], outputs=[interactive_image, cropped_image])
 
     predict_button = gr.Button("Predict")
     predict_button.click(fn=predict_from_cropped, inputs=cropped_image, outputs=label_output)
 
-    model_selector.change(fn=update_model, inputs=model_selector, outputs=None)
-
-    # Update sliders dynamically based on uploaded image size
     def update_sliders(image):
-        if image is not None:
+        if image:
             width, height = image.size
             return gr.update(maximum=width - 224), gr.update(maximum=height - 224)
         return gr.update(), gr.update()
 
     image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])
 
-    # Initialize model on app start
-    demo.load(fn=lambda: load_model_and_preprocessor('beit'), inputs=None, outputs=None)
-
 demo.launch(server_name="0.0.0.0", server_port=7860)
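The rewritten app.py pins a single fine-tuned checkpoint instead of the public models. A minimal smoke-test sketch for that loading path, assuming the same CHECKPOINT_URL and timm model name as above, with a random tensor standing in for a real preprocessed patch (73 is len(all_classes)):

import timm
import torch

CHECKPOINT_URL = "https://huggingface.co/ReefNet/beit_global/resolve/main/checkpoint-60.pth"

model = timm.create_model(
    'beitv2_large_patch16_224.in1k_ft_in22k_in1k',
    pretrained=False,
    num_classes=73,  # len(all_classes) in app.py
    use_rel_pos_bias=True,
    use_abs_pos_emb=True,
)
checkpoint = torch.hub.load_state_dict_from_url(CHECKPOINT_URL, map_location="cpu")
state_dict = checkpoint.get('model', checkpoint)
missing, unexpected = model.load_state_dict(
    {k: v for k, v in state_dict.items() if "relative_position_index" not in k},
    strict=False,
)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # stand-in for a preprocessed 224x224 patch
print(out.shape)  # expected: torch.Size([1, 73])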