mykeysid10 commited on
Commit
f6172d7
1 Parent(s): e757fca

Update cloud_coverage_pipeline.py

Browse files
Files changed (1) hide show
  1. cloud_coverage_pipeline.py +12 -6
cloud_coverage_pipeline.py CHANGED
@@ -11,6 +11,7 @@ from torch.utils.data import Dataset, DataLoader
11
  from tqdm.autonotebook import tqdm
12
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
13
 
 
14
  # Trained Model Configurations
15
  CFG = {
16
  "debug": False,
@@ -40,6 +41,7 @@ CFG = {
40
  "dropout": 0.1
41
  }
42
 
 
43
  # Loading Finetuned Clip Model to the below class format
44
  class CLIPModel(nn.Module):
45
  def __init__(
@@ -55,6 +57,7 @@ class CLIPModel(nn.Module):
55
  self.text_projection = ProjectionHead(embedding_dim = text_embedding)
56
  self.temperature = temperature
57
 
 
58
  # Image Encoder Class to extract features using finetuned clip's Resnet Image Encoder
59
  class ImageEncoder(nn.Module):
60
  def __init__(self, model_name = CFG["model_name"], pretrained = CFG["pretrained"], trainable = CFG["trainable"]):
@@ -66,6 +69,7 @@ class ImageEncoder(nn.Module):
66
  def forward(self, x):
67
  return self.model(x)
68
 
 
69
  # Text Encoder - Optional in inference
70
  class TextEncoder(nn.Module):
71
  def __init__(self, model_name = CFG["text_encoder_model"], pretrained = CFG["pretrained"], trainable = CFG["trainable"]):
@@ -85,6 +89,7 @@ class TextEncoder(nn.Module):
85
  last_hidden_state = output.last_hidden_state
86
  return last_hidden_state[:, self.target_token_idx, :]
87
 
 
88
  # Projection Class - Optional in inference
89
  class ProjectionHead(nn.Module):
90
  def __init__(
@@ -109,6 +114,7 @@ class ProjectionHead(nn.Module):
109
  x = self.layer_norm(x)
110
  return x
111
 
 
112
  # Class to transform image to custom data format
113
  class SkyImage(Dataset):
114
  def __init__(self, img, label):
@@ -124,16 +130,15 @@ class SkyImage(Dataset):
124
  label = self.img_label[idx]
125
  return image, label
126
 
127
- # Method to initialize CatBoost model
128
- def initialize_catboost_model():
129
- return pickle.load(open("catboost_model.sav", 'rb'))
130
 
131
- # Method to initialize CLIP model
132
- def initialize_clip_model():
 
133
  clip_model = CLIPModel().to(CFG["device"])
134
  clip_model.load_state_dict(torch.load("clip_model.pt", map_location = CFG["device"]))
135
  clip_model.eval()
136
- return clip_model
 
137
 
138
  # Method to extract features from finetuned clip model
139
  def get_features(clip_model, dataset):
@@ -146,6 +151,7 @@ def get_features(clip_model, dataset):
146
  label.append(labels)
147
  return torch.cat(features), torch.cat(label).cpu()
148
 
 
149
  # Method to calculate cloud coverage
150
  def predict_cloud_coverage(image, clip_model, CTBR_model):
151
  img, lbl = [image], [0]
 
11
  from tqdm.autonotebook import tqdm
12
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
13
 
14
+
15
  # Trained Model Configurations
16
  CFG = {
17
  "debug": False,
 
41
  "dropout": 0.1
42
  }
43
 
44
+
45
  # Loading Finetuned Clip Model to the below class format
46
  class CLIPModel(nn.Module):
47
  def __init__(
 
57
  self.text_projection = ProjectionHead(embedding_dim = text_embedding)
58
  self.temperature = temperature
59
 
60
+
61
  # Image Encoder Class to extract features using finetuned clip's Resnet Image Encoder
62
  class ImageEncoder(nn.Module):
63
  def __init__(self, model_name = CFG["model_name"], pretrained = CFG["pretrained"], trainable = CFG["trainable"]):
 
69
  def forward(self, x):
70
  return self.model(x)
71
 
72
+
73
  # Text Encoder - Optional in inference
74
  class TextEncoder(nn.Module):
75
  def __init__(self, model_name = CFG["text_encoder_model"], pretrained = CFG["pretrained"], trainable = CFG["trainable"]):
 
89
  last_hidden_state = output.last_hidden_state
90
  return last_hidden_state[:, self.target_token_idx, :]
91
 
92
+
93
  # Projection Class - Optional in inference
94
  class ProjectionHead(nn.Module):
95
  def __init__(
 
114
  x = self.layer_norm(x)
115
  return x
116
 
117
+
118
  # Class to transform image to custom data format
119
  class SkyImage(Dataset):
120
  def __init__(self, img, label):
 
130
  label = self.img_label[idx]
131
  return image, label
132
 
 
 
 
133
 
134
+ # Method to initialize CatBoost model
135
+ def initialize_models():
136
+ cbt_model = pickle.load(open("catboost_model.sav", 'rb'))
137
  clip_model = CLIPModel().to(CFG["device"])
138
  clip_model.load_state_dict(torch.load("clip_model.pt", map_location = CFG["device"]))
139
  clip_model.eval()
140
+ return cbt_model, clip_model
141
+
142
 
143
  # Method to extract features from finetuned clip model
144
  def get_features(clip_model, dataset):
 
151
  label.append(labels)
152
  return torch.cat(features), torch.cat(label).cpu()
153
 
154
+
155
  # Method to calculate cloud coverage
156
  def predict_cloud_coverage(image, clip_model, CTBR_model):
157
  img, lbl = [image], [0]