ayoubkirouane commited on
Commit
e58a58b
1 Parent(s): 8e0c22f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gdown
2
+
3
+ def download_file_from_google_drive(file_id, output_file):
4
+ """
5
+ Download a file from Google Drive.
6
+
7
+ :param file_id: The Google Drive file ID.
8
+ :param output_file: The name of the file to save.
9
+ """
10
+ url = f"https://drive.google.com/uc?id={file_id}"
11
+ gdown.download(url, output_file, quiet=False)
12
+
13
+ # Example usage:
14
+ file_id = "1Wgh9dWT6SbakJhvuNkSaIa1ydFtkfUW6"
15
+ out = "average_model.pth"
16
+ download_file_from_google_drive(file_id,out)
17
+
18
+ from super_gradients.training import models
19
+ import torch
20
+ import supervision as sv
21
+ import gradio as gr
22
+
23
+ DEVICE = 'cuda' if torch.cuda.is_available() else "cpu"
24
+ MODEL_ARCH = 'yolo_nas_l'
25
+ clasess = ["Airplane"]
26
+ checkpoint_path= "average_model.pth"
27
+
28
+
29
+ def run(image , CONFIDENCE_TRESHOLD) :
30
+ best_model = models.get(
31
+ MODEL_ARCH,
32
+ num_classes=len(clasess),
33
+ checkpoint_path= checkpoint_path
34
+ ).to(DEVICE)
35
+ result = list(best_model.predict(image, conf=CONFIDENCE_TRESHOLD))[0]
36
+ detections = sv.Detections(
37
+ xyxy=result.prediction.bboxes_xyxy,
38
+ confidence=result.prediction.confidence,
39
+ class_id=result.prediction.labels.astype(int)
40
+ )
41
+ box_annotator = sv.BoxAnnotator()
42
+ annotated_frame = box_annotator.annotate(
43
+ scene=image.copy(),
44
+ detections=detections,
45
+ labels=clasess
46
+ )
47
+ return annotated_frame
48
+ iface = gr.Interface(
49
+ fn=run,
50
+ inputs=[gr.Image(label="Input image", type="numpy") , gr.Slider(0, 1, value=0.5, label="Select your CONFIDENCE_TRESHOLD")],
51
+ outputs=gr.Image(label="The Prediction Output :", type="numpy"),
52
+ title="Aerial Airport YOLO Nas object detection",
53
+ allow_flagging=False ,
54
+ description="I conducted fine-tuning on the YOLO-NAS (YOLO Neural Architecture Search) model, a cutting-edge object detection architecture developed by Deci-AI. My objective was to enhance its ability to detect airplanes in the 'Aerial Airport' dataset",
55
+ )
56
+ iface.launch(debug=True)