"""Pipeline utilities: download image datasets to S3 and generate CLIP embeddings.

Workflow: images are downloaded into ``s3://disco-io/data/<dataset>/images``,
then embedded with one or more CLIP variants and stored via ``vector_db``.

Module import has side effects: it reads AWS credentials from the environment,
opens an S3 filesystem, loads every model in ``model_name_to_ids``, and lists
the available datasets (``cur_dataset`` defaults to the first one found).
"""

from io import BytesIO
import os

import requests
import s3fs
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer, util as st_util
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor

import vector_db

torch.set_printoptions(precision=10)

# Default timeout (seconds) for outbound HTTP requests; without one,
# requests.get can block forever on a stalled server.
_HTTP_TIMEOUT = 30

# Other supported keys (see model_name_to_ids):
# "sentence-transformer-clip-ViT-L-14", "openai-clip".
model_names = ["fashion"]

# Maps a short model key to its Hugging Face / sentence-transformers model id.
model_name_to_ids = {
    "sentence-transformer-clip-ViT-L-14": "clip-ViT-L-14",
    "fashion": "patrickjohncyh/fashion-clip",
    "openai-clip": "openai/clip-vit-base-patch32",
}

# AWS credentials are required; fail fast at import if they are missing.
AWS_ACCESS_KEY_ID = os.environ["AWS_ACCESS_KEY_ID"]
AWS_SECRET_ACCESS_KEY = os.environ["AWS_SECRET_ACCESS_KEY"]

# Define your bucket and dataset name.
S3_BUCKET = "s3://disco-io"

fs = s3fs.S3FileSystem(
    key=AWS_ACCESS_KEY_ID,
    secret=AWS_SECRET_ACCESS_KEY,
)

ROOT_DATA_PATH = os.path.join(S3_BUCKET, 'data')


def get_data_path():
    """Return the S3 root path of the current dataset (module-global ``cur_dataset``)."""
    return os.path.join(ROOT_DATA_PATH, cur_dataset)


def get_image_path():
    """Return the S3 path of the current dataset's images directory."""
    return os.path.join(get_data_path(), 'images')


def get_metadata_path():
    """Return the S3 path of the current dataset's metadata directory."""
    return os.path.join(get_data_path(), 'metadata')


def get_embeddings_path():
    """Return the S3 path of the current dataset's parquet embeddings file."""
    return os.path.join(get_metadata_path(), cur_dataset + '_embeddings.pq')


# Lazily-populated registry: model key -> {'model': ..., 'processor': ..., 'hf_dir': ...}.
model_dict = dict()


def download_to_s3(url, s3_path):
    """Stream the file at ``url`` directly into ``s3_path`` on S3.

    Raises:
        requests.HTTPError: if the HTTP response status is an error.
    """
    # Download the file from the URL (streamed so large files are not
    # buffered entirely in memory).
    response = requests.get(url, stream=True, timeout=_HTTP_TIMEOUT)
    response.raise_for_status()

    # Upload the file to the S3 path chunk by chunk.
    with fs.open(s3_path, "wb") as s3_file:
        for chunk in response.iter_content(chunk_size=8192):
            s3_file.write(chunk)


def remove_all_files_from_s3_directory(s3_directory):
    """Best-effort removal of every object directly under ``s3_directory``."""
    # List all objects in the S3 directory.
    objects = fs.ls(s3_directory)
    # Remove each object; a failure on one object must not stop the rest.
    for obj in objects:
        try:
            fs.rm(obj)
        except Exception:
            print('Error removing file: ' + obj)


def download_images(df, img_folder):
    """Clear ``img_folder`` on S3, then download every image in ``df``.

    Args:
        df: DataFrame with an 'IMG_URL' column and a 'title' column used
            (sanitized) as the destination file name.
        img_folder: S3 directory to populate with ``<title>.jpg`` files.
    """
    remove_all_files_from_s3_directory(img_folder)
    for index, row in df.iterrows():
        # Sanitize the title for use as an S3 key: '/' would create a
        # pseudo-directory and newlines are invalid in key names.
        try:
            safe_name = row['title'].replace('/', '_').replace('\n', '') + '.jpg'
            download_to_s3(row['IMG_URL'], os.path.join(img_folder, safe_name))
        except Exception:
            # Best-effort: log and continue with the remaining rows.
            print('Error downloading image: ' + str(index) + row['title'])


def load_models():
    """Populate ``model_dict`` with every model listed in ``model_name_to_ids``.

    sentence-transformers models are loaded via SentenceTransformer; all
    others are Hugging Face CLIP checkpoints, which also need a processor.
    Already-loaded models are left untouched (idempotent).
    """
    for model_name in model_name_to_ids:
        if model_name not in model_dict:
            model_dict[model_name] = dict()
            if model_name.startswith('sentence-transformer'):
                model_dict[model_name]['model'] = SentenceTransformer(
                    model_name_to_ids[model_name])
            else:
                model_dict[model_name]['hf_dir'] = model_name_to_ids[model_name]
                model_dict[model_name]['model'] = CLIPModel.from_pretrained(
                    model_name_to_ids[model_name])
                model_dict[model_name]['processor'] = CLIPProcessor.from_pretrained(
                    model_name_to_ids[model_name])


# Eagerly load models at import so embedding calls are ready immediately.
if len(model_dict) == 0:
    print('Loading models...')
    load_models()


def get_image_embedding(model_name, image):
    """
    Takes an image as input and returns an embedding vector.
    """
    model = model_dict[model_name]['model']
    if model_name.startswith('sentence-transformer'):
        return model.encode(image)
    # Hugging Face CLIP path: preprocess, embed, and return a 1-D numpy vector.
    inputs = model_dict[model_name]['processor'](images=image, return_tensors="pt")
    image_features = model.get_image_features(**inputs).detach().numpy()[0]
    return image_features


def s3_path_to_image(fs, s3_path):
    """
    Takes an S3 path as input and returns a PIL Image object.

    Args:
        s3_path (str): The path to the image in the S3 bucket, including the
            bucket name (e.g., "bucket_name/path/to/image.jpg").

    Returns:
        Image: A PIL Image object.
    """
    # Read the full object into memory; PIL decodes lazily from the buffer.
    with fs.open(s3_path, "rb") as f:
        image_data = BytesIO(f.read())
        img = Image.open(image_data)
    return img


def generate_and_save_embeddings():
    """Embed every .jpg in the current dataset with every model, storing via vector_db."""
    # Get image embeddings; no_grad avoids building autograd graphs (inference only).
    with torch.no_grad():
        for fp in tqdm(fs.ls(get_image_path()), desc="Generate embeddings for Images"):
            if fp.endswith('.jpg'):
                name = fp.split('/')[-1]
                for model_name in model_name_to_ids.keys():
                    # fs.ls returns bucket-relative paths; re-add the scheme.
                    s3_path = 's3://' + fp
                    vector_db.add_image_embedding_to_db(
                        embedding=get_image_embedding(
                            model_name, s3_path_to_image(fs, s3_path)),
                        model_name=model_name,
                        dataset_name=cur_dataset,
                        path_to_image=s3_path,
                        image_name=name,
                    )


def get_immediate_subdirectories(s3_path):
    """Return the names (not full paths) of the directories directly under ``s3_path``."""
    return [obj.split('/')[-1] for obj in fs.glob(f"{s3_path}/*") if fs.isdir(obj)]


# Discover datasets at import time; default to the first one found.
all_datasets = get_immediate_subdirectories(ROOT_DATA_PATH)
cur_dataset = all_datasets[0]


def set_cur_dataset(dataset):
    """Switch the module-global current dataset (refreshing the dataset list first)."""
    global cur_dataset
    refresh_all_datasets()
    print(f"Setting current dataset to {dataset}")
    cur_dataset = dataset


def refresh_all_datasets():
    """Re-list the datasets under ROOT_DATA_PATH into the module-global ``all_datasets``."""
    global all_datasets
    all_datasets = get_immediate_subdirectories(ROOT_DATA_PATH)
    print(f"Refreshing all datasets: {all_datasets}")


def url_to_image(url):
    """Fetch ``url`` and return it as a PIL Image, or None on any request failure."""
    try:
        response = requests.get(url, timeout=_HTTP_TIMEOUT)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        return img
    except requests.exceptions.RequestException as e:
        print(f"Error fetching image from URL: {url}")
        return None