JLD commited on
Commit
446f144
1 Parent(s): 87e70e9

Make device dependable of the machine capacity

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -3,6 +3,7 @@ import random
3
  from datasets import load_dataset
4
  from sentence_transformers import SentenceTransformer, util
5
  import logging
 
6
  from PIL import Image
7
  # Create a custom logger
8
  logger = logging.getLogger(__name__)
@@ -22,9 +23,10 @@ c_handler.setFormatter(c_format)
22
  logger.addHandler(c_handler)
23
 
24
  class SearchEngine:
25
- def __init__(self):
 
26
  self.model = SentenceTransformer('clip-ViT-B-32')
27
- self.embedding_dataset = load_dataset("JLD/unsplash25k-image-embeddings", trust_remote_code=True, split="train").with_format("torch", device="cuda:0")
28
  image_dataset = load_dataset("jamescalam/unsplash-25k-photos", trust_remote_code=True, revision="refs/pr/3")
29
  self.image_dataset = {image["photo_id"]: image["photo_image_url"] for image in image_dataset["train"]}
30
 
@@ -35,12 +37,12 @@ class SearchEngine:
35
 
36
  def search_images_from_text(self, text):
37
  logger.info("Searching images from text")
38
- emb = self.model.encode(text, convert_to_tensor=True, device="cuda:0")
39
  return self.get_candidates(query_embedding=emb)
40
 
41
  def search_images_from_image(self, image):
42
  logger.info("Searching images from image")
43
- emb = self.model.encode(Image.fromarray(image), convert_to_tensor=True, device="cuda:0")
44
  return self.get_candidates(query_embedding=emb)
45
 
46
  def main():
 
3
  from datasets import load_dataset
4
  from sentence_transformers import SentenceTransformer, util
5
  import logging
6
+ import torch
7
  from PIL import Image
8
  # Create a custom logger
9
  logger = logging.getLogger(__name__)
 
23
  logger.addHandler(c_handler)
24
 
25
  class SearchEngine:
26
+ def __init__(self, device="cpu"):
27
+ self.device = device if torch.cuda.is_available() else "cpu"
28
  self.model = SentenceTransformer('clip-ViT-B-32')
29
+ self.embedding_dataset = load_dataset("JLD/unsplash25k-image-embeddings", trust_remote_code=True, split="train").with_format("torch", device=self.device)
30
  image_dataset = load_dataset("jamescalam/unsplash-25k-photos", trust_remote_code=True, revision="refs/pr/3")
31
  self.image_dataset = {image["photo_id"]: image["photo_image_url"] for image in image_dataset["train"]}
32
 
 
37
 
38
  def search_images_from_text(self, text):
39
  logger.info("Searching images from text")
40
+ emb = self.model.encode(text, convert_to_tensor=True, device=self.device)
41
  return self.get_candidates(query_embedding=emb)
42
 
43
  def search_images_from_image(self, image):
44
  logger.info("Searching images from image")
45
+ emb = self.model.encode(Image.fromarray(image), convert_to_tensor=True, device=self.device)
46
  return self.get_candidates(query_embedding=emb)
47
 
48
  def main():