Spaces:
Runtime error
Runtime error
Make device dependable of the machine capacity
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ import random
|
|
3 |
from datasets import load_dataset
|
4 |
from sentence_transformers import SentenceTransformer, util
|
5 |
import logging
|
|
|
6 |
from PIL import Image
|
7 |
# Create a custom logger
|
8 |
logger = logging.getLogger(__name__)
|
@@ -22,9 +23,10 @@ c_handler.setFormatter(c_format)
|
|
22 |
logger.addHandler(c_handler)
|
23 |
|
24 |
class SearchEngine:
|
25 |
-
def __init__(self):
|
|
|
26 |
self.model = SentenceTransformer('clip-ViT-B-32')
|
27 |
-
self.embedding_dataset = load_dataset("JLD/unsplash25k-image-embeddings", trust_remote_code=True, split="train").with_format("torch", device=
|
28 |
image_dataset = load_dataset("jamescalam/unsplash-25k-photos", trust_remote_code=True, revision="refs/pr/3")
|
29 |
self.image_dataset = {image["photo_id"]: image["photo_image_url"] for image in image_dataset["train"]}
|
30 |
|
@@ -35,12 +37,12 @@ class SearchEngine:
|
|
35 |
|
36 |
def search_images_from_text(self, text):
|
37 |
logger.info("Searching images from text")
|
38 |
-
emb = self.model.encode(text, convert_to_tensor=True, device=
|
39 |
return self.get_candidates(query_embedding=emb)
|
40 |
|
41 |
def search_images_from_image(self, image):
|
42 |
logger.info("Searching images from image")
|
43 |
-
emb = self.model.encode(Image.fromarray(image), convert_to_tensor=True, device=
|
44 |
return self.get_candidates(query_embedding=emb)
|
45 |
|
46 |
def main():
|
|
|
3 |
from datasets import load_dataset
|
4 |
from sentence_transformers import SentenceTransformer, util
|
5 |
import logging
|
6 |
+
import torch
|
7 |
from PIL import Image
|
8 |
# Create a custom logger
|
9 |
logger = logging.getLogger(__name__)
|
|
|
23 |
logger.addHandler(c_handler)
|
24 |
|
25 |
class SearchEngine:
|
26 |
+
def __init__(self, device="cpu"):
|
27 |
+
self.device = device if torch.cuda.is_available() else "cpu"
|
28 |
self.model = SentenceTransformer('clip-ViT-B-32')
|
29 |
+
self.embedding_dataset = load_dataset("JLD/unsplash25k-image-embeddings", trust_remote_code=True, split="train").with_format("torch", device=self.device)
|
30 |
image_dataset = load_dataset("jamescalam/unsplash-25k-photos", trust_remote_code=True, revision="refs/pr/3")
|
31 |
self.image_dataset = {image["photo_id"]: image["photo_image_url"] for image in image_dataset["train"]}
|
32 |
|
|
|
37 |
|
38 |
def search_images_from_text(self, text):
|
39 |
logger.info("Searching images from text")
|
40 |
+
emb = self.model.encode(text, convert_to_tensor=True, device=self.device)
|
41 |
return self.get_candidates(query_embedding=emb)
|
42 |
|
43 |
def search_images_from_image(self, image):
|
44 |
logger.info("Searching images from image")
|
45 |
+
emb = self.model.encode(Image.fromarray(image), convert_to_tensor=True, device=self.device)
|
46 |
return self.get_candidates(query_embedding=emb)
|
47 |
|
48 |
def main():
|