pharmapsychotic commited on
Commit
53924ae
1 Parent(s): 9db4527
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -57,7 +57,7 @@ class LabelTable():
57
  self.embeds = []
58
  chunks = np.array_split(self.labels, max(1, len(self.labels)/chunk_size))
59
  for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None):
60
- text_tokens = clip.tokenize(chunk).cuda()
61
  with torch.no_grad():
62
  text_features = clip_model.encode_text(text_tokens).float()
63
  text_features /= text_features.norm(dim=-1, keepdim=True)
@@ -113,7 +113,7 @@ def load_list(filename):
113
  return items
114
 
115
  def rank_top(image_features, text_array):
116
- text_tokens = clip.tokenize([text for text in text_array]).cuda()
117
  with torch.no_grad():
118
  text_features = clip_model.encode_text(text_tokens).float()
119
  text_features /= text_features.norm(dim=-1, keepdim=True)
@@ -126,7 +126,7 @@ def rank_top(image_features, text_array):
126
  return text_array[top_labels[0][0].numpy()]
127
 
128
  def similarity(image_features, text):
129
- text_tokens = clip.tokenize([text]).cuda()
130
  with torch.no_grad():
131
  text_features = clip_model.encode_text(text_tokens).float()
132
  text_features /= text_features.norm(dim=-1, keepdim=True)
@@ -136,7 +136,7 @@ def similarity(image_features, text):
136
  def interrogate(image):
137
  caption = generate_caption(image)
138
 
139
- images = clip_preprocess(image).unsqueeze(0).cuda()
140
  with torch.no_grad():
141
  image_features = clip_model.encode_image(images).float()
142
  image_features /= image_features.norm(dim=-1, keepdim=True)
 
57
  self.embeds = []
58
  chunks = np.array_split(self.labels, max(1, len(self.labels)/chunk_size))
59
  for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None):
60
+ text_tokens = clip.tokenize(chunk).to(device)
61
  with torch.no_grad():
62
  text_features = clip_model.encode_text(text_tokens).float()
63
  text_features /= text_features.norm(dim=-1, keepdim=True)
 
113
  return items
114
 
115
  def rank_top(image_features, text_array):
116
+ text_tokens = clip.tokenize([text for text in text_array]).to(device)
117
  with torch.no_grad():
118
  text_features = clip_model.encode_text(text_tokens).float()
119
  text_features /= text_features.norm(dim=-1, keepdim=True)
 
126
  return text_array[top_labels[0][0].numpy()]
127
 
128
  def similarity(image_features, text):
129
+ text_tokens = clip.tokenize([text]).to(device)
130
  with torch.no_grad():
131
  text_features = clip_model.encode_text(text_tokens).float()
132
  text_features /= text_features.norm(dim=-1, keepdim=True)
 
136
  def interrogate(image):
137
  caption = generate_caption(image)
138
 
139
+ images = clip_preprocess(image).unsqueeze(0).to(device)
140
  with torch.no_grad():
141
  image_features = clip_model.encode_image(images).float()
142
  image_features /= image_features.norm(dim=-1, keepdim=True)