HebaAllah committed on
Commit 021b464 · verified · 1 Parent(s): 3f59239

Create app.py

Files changed (1)
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
+ import gradio as gr
+
+ import torch
+ from PIL import Image
+ from transformers import CLIPModel, CLIPProcessor
+
+ # UC Merced land-use classes used as candidate labels for zero-shot classification
+ LIST_LABELS = ['agricultural land', 'airplane', 'baseball diamond', 'beach', 'buildings', 'chaparral', 'dense residential area', 'forest', 'freeway', 'golf course', 'harbor', 'intersection', 'medium residential area', 'mobilehome park', 'overpass', 'parking lot', 'river', 'runway', 'sparse residential area', 'storage tanks', 'tennis court']
+
+ # Prompt template wrapped around each label for CLIP's text encoder
+ CLIP_LABELS = [f"A satellite image of {label}" for label in LIST_LABELS]
+
+ MODEL_NAME = "NemesisAlm/clip-fine-tuned-satellite"
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Baseline CLIP model, used for comparison
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ # CLIP fine-tuned on the UC Merced satellite dataset
+ fine_tuned_model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
+ fine_tuned_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
+
+
+ def classify(image_path, model_number):
+     """Return per-label probabilities for the selected model ("CLIP" or the fine-tuned one)."""
+     if model_number == "CLIP":
+         processor = clip_processor
+         model = clip_model
+     else:
+         processor = fine_tuned_processor
+         model = fine_tuned_model
+     image = Image.open(image_path).convert('RGB')
+     inputs = processor(text=CLIP_LABELS, images=image, return_tensors="pt", padding=True).to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     logits_per_image = outputs.logits_per_image
+     prediction = logits_per_image.softmax(dim=1)
+     confidences = {LIST_LABELS[i]: float(prediction[0][i].item()) for i in range(len(LIST_LABELS))}
+     return confidences
+
+
+ DESCRIPTION = """
+ <div style="font-family: Arial, sans-serif; line-height: 1.6; margin: auto; text-align: center;">
+     <h2 style="color: #333;">CLIP Fine-Tuned Satellite Model Demo</h2>
+     <p>
+         This space demonstrates the capabilities of a <strong>fine-tuned CLIP-based model</strong>
+         in classifying satellite images. The model has been specifically trained on the
+         <em>UC Merced</em> satellite image dataset.
+     </p>
+     <p>
+         After just <strong>2 epochs of training</strong>, adjusting only 30% of the model parameters,
+         the model's accuracy in classifying satellite images has significantly improved, from an
+         initial accuracy of <strong>58.8%</strong> to <strong>96.9%</strong> on the test set.
+     </p>
+     <p>
+         Explore this space to see its performance and compare it with the initial CLIP model.
+     </p>
+ </div>
+ """
+
+ FOOTER = """
+ <div style="margin-top:50px">
+     Link to model: <a href='https://huggingface.co/NemesisAlm/clip-fine-tuned-satellite'>https://huggingface.co/NemesisAlm/clip-fine-tuned-satellite</a><br>
+     Link to dataset: <a href='https://huggingface.co/datasets/blanchon/UC_Merced'>https://huggingface.co/datasets/blanchon/UC_Merced</a>
+ </div>
+ """
+
+ with gr.Blocks(title="Satellite image classification", css="") as demo:
+     logo = gr.HTML("<img src='file/logo_gradio.png' style='margin:auto'/>")
+     description = gr.HTML(DESCRIPTION)
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(type='filepath', label='Input image')
+             submit_btn = gr.Button("Submit", variant="primary")
+         with gr.Column():
+             title_1 = gr.HTML("<h1 style='text-align:center'>Original CLIP Model</h1>")
+             # Hidden textbox used only to route the model choice into classify()
+             model_1 = gr.Textbox("CLIP", visible=False)
+             output_labels_clip = gr.Label(num_top_classes=10, label="Top 10 classes")
+         with gr.Column():
+             title_2 = gr.HTML("<h1 style='text-align:center'>Fine-tuned Model</h1>")
+             model_2 = gr.Textbox("Fine-tuned", visible=False)
+             output_labels_finetuned = gr.Label(num_top_classes=10, label="Top 10 classes")
+     examples = gr.Examples([["0.jpg"], ["1.jpg"], ["2.jpg"], ["3.jpg"]], input_image)
+     footer = gr.HTML(FOOTER)
+     # Classify with the baseline model first, then with the fine-tuned model on the same image
+     submit_btn.click(fn=classify, inputs=[input_image, model_1], outputs=output_labels_clip).then(classify, inputs=[input_image, model_2], outputs=[output_labels_finetuned])
+
+
+ demo.queue()
+ demo.launch(server_name="0.0.0.0", favicon_path='favicon.ico', allowed_paths=["logo_gradio.png", "0.jpg", "1.jpg", "2.jpg", "3.jpg"])
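
For quick verification outside the Gradio UI, the same zero-shot inference can be reproduced with a short standalone script. A minimal sketch, assuming the fine-tuned checkpoint named above is reachable on the Hub and that one of the Space's example images (here 0.jpg) is in the working directory; the label subset below is illustrative rather than the app's full LIST_LABELS:

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Illustrative subset of the app's label list
LABELS = ["agricultural land", "airplane", "beach", "forest", "harbor"]
PROMPTS = [f"A satellite image of {label}" for label in LABELS]

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("NemesisAlm/clip-fine-tuned-satellite").to(device)
processor = CLIPProcessor.from_pretrained("NemesisAlm/clip-fine-tuned-satellite")

# One of the Space's example images; any RGB satellite image works
image = Image.open("0.jpg").convert("RGB")
inputs = processor(text=PROMPTS, images=image, return_tensors="pt", padding=True).to(device)
with torch.no_grad():
    probs = model(**inputs).logits_per_image.softmax(dim=1)[0]

# Print labels from most to least probable
for label, p in sorted(zip(LABELS, probs.tolist()), key=lambda x: -x[1]):
    print(f"{label}: {p:.3f}")

Running the same loop with the complete LIST_LABELS would mirror the probability distribution shown by the gr.Label components in the app.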