cfoli commited on
Commit
7a3bf69
·
verified ·
1 Parent(s): 3e94470

Rename zero_shot_classification.py to app.py

Browse files
zero_shot_classification.py → app.py RENAMED
@@ -57,7 +57,7 @@ LABELS_MAP = ["Bat (baseball)", "Bat (mammal)",
57
 
58
  """
59
 
60
- model_key = "CLIP-large"
61
 
62
  # Load model (cache for speed)
63
  if model_key not in MODEL_CACHE:
@@ -73,8 +73,8 @@ output = classifier(
73
  candidate_labels = CANDIDATE_LABELS,
74
  hypothesis_template = "This image shows {}")
75
 
76
- print("\n\n=============================================================================")
77
- print(f"\nPrediction: This image shows {output[0]["label"]} | Confidence (probability): {100*output[0]["score"]: .1f}%")
78
 
79
  def run_classifer(model_key, image_path, prob_threshold = None):
80
  # model_key: name of backbone zero-shot-image-classification model to use
@@ -107,17 +107,15 @@ def run_classifer(model_key, image_path, prob_threshold = None):
107
 
108
  return predicted_label_str, prob_dict
109
 
110
- # example run
111
- model_key = "CLIP-large"
112
- BASE_DIR = '/content/drive/MyDrive/ML Projects/Zero-shot Image Classification/Images'
113
- image_path = os.path.join(BASE_DIR, 'Nail2_1.png')
114
 
115
- predicted_label_str, prob_dict = run_classifer(model_key, image_path, prob_threshold = 0.4)
116
- print("\n\n=============================================================================")
117
- # print(f"\nPrediction: {predicted_label_str} | Confidence (probability): {100*output[0]['score']:.1f}%")
118
- print(f"\nPrediction: {predicted_label_str}")
119
-
120
- prob_dict
121
 
122
  """### Gradio App
123
 
 
57
 
58
  """
59
 
60
+ # model_key = "CLIP-large"
61
 
62
  # Load model (cache for speed)
63
  if model_key not in MODEL_CACHE:
 
73
  candidate_labels = CANDIDATE_LABELS,
74
  hypothesis_template = "This image shows {}")
75
 
76
+ # print("\n\n=============================================================================")
77
+ # print(f"\nPrediction: This image shows {output[0]["label"]} | Confidence (probability): {100*output[0]["score"]: .1f}%")
78
 
79
  def run_classifer(model_key, image_path, prob_threshold = None):
80
  # model_key: name of backbone zero-shot-image-classification model to use
 
107
 
108
  return predicted_label_str, prob_dict
109
 
110
+ # # example run
111
+ # model_key = "CLIP-large"
112
+ # BASE_DIR = '/content/drive/MyDrive/ML Projects/Zero-shot Image Classification/Images'
113
+ # image_path = os.path.join(BASE_DIR, 'Nail2_1.png')
114
 
115
+ # predicted_label_str, prob_dict = run_classifer(model_key, image_path, prob_threshold = 0.4)
116
+ # print("\n\n=============================================================================")
117
+ # # print(f"\nPrediction: {predicted_label_str} | Confidence (probability): {100*output[0]['score']:.1f}%")
118
+ # print(f"\nPrediction: {predicted_label_str}")
 
 
119
 
120
  """### Gradio App
121