# jasonwuyl92 — initial commit after cleanup (2ab45c8)
from sentence_transformers import SentenceTransformer, util as st_util
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import requests
import os
import torch
# Print tensors with full precision — useful when eyeballing embedding values.
torch.set_printoptions(precision=10)
from tqdm import tqdm
import s3fs
from io import BytesIO
import vector_db
"sentence-transformer-clip-ViT-L-14"
"openai-clip"
model_names = ["fashion"]
model_name_to_ids = {
"sentence-transformer-clip-ViT-L-14": "clip-ViT-L-14",
"fashion": "patrickjohncyh/fashion-clip",
"openai-clip": "openai/clip-vit-base-patch32",
}
# AWS credentials come from the environment; a KeyError here means they
# are not configured for this process.
AWS_ACCESS_KEY_ID = os.environ["AWS_ACCESS_KEY_ID"]
AWS_SECRET_ACCESS_KEY = os.environ["AWS_SECRET_ACCESS_KEY"]
# Define your bucket and dataset name.
S3_BUCKET = "s3://disco-io"
# Shared filesystem handle used by every S3 read/write in this module.
fs = s3fs.S3FileSystem(
    key=AWS_ACCESS_KEY_ID,
    secret=AWS_SECRET_ACCESS_KEY,
)
# All datasets live under <bucket>/data.
ROOT_DATA_PATH = os.path.join(S3_BUCKET, 'data')
def get_data_path():
    """Return the S3 prefix holding everything for the currently selected dataset."""
    dataset_root = os.path.join(ROOT_DATA_PATH, cur_dataset)
    return dataset_root
def get_image_path():
    """Return the S3 prefix holding the current dataset's raw images."""
    base = get_data_path()
    return os.path.join(base, 'images')
def get_metadata_path():
    """Return the S3 prefix holding the current dataset's metadata files."""
    base = get_data_path()
    return os.path.join(base, 'metadata')
def get_embeddings_path():
    """Return the S3 path of the current dataset's embeddings parquet file."""
    filename = cur_dataset + '_embeddings.pq'
    return os.path.join(get_metadata_path(), filename)
# Cache of loaded models keyed by alias; populated by load_models().
model_dict = {}
def download_to_s3(url, s3_path):
    """Stream the file at `url` into `s3_path` on S3.

    Args:
        url (str): HTTP(S) URL of the file to fetch.
        s3_path (str): Destination S3 path, including the bucket.

    Raises:
        requests.HTTPError: if the server returns an error status.
        requests.RequestException: on connection failure or timeout.
    """
    # Stream in chunks so large files are never held fully in memory.
    # `timeout` prevents a dead server from hanging the pipeline forever,
    # and the `with` block closes the connection (the original leaked it).
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        # Upload the file to the S3 path
        with fs.open(s3_path, "wb") as s3_file:
            for chunk in response.iter_content(chunk_size=8192):
                s3_file.write(chunk)
def remove_all_files_from_s3_directory(s3_directory):
    """Best-effort deletion of every object directly under `s3_directory`.

    Failures on individual objects are logged and skipped so one bad object
    does not abort the whole cleanup.
    """
    # List all objects in the S3 directory
    objects = fs.ls(s3_directory)
    # Remove each object
    for obj in objects:
        try:
            fs.rm(obj)
        except Exception as e:
            # Narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit; include the cause in the log.
            print('Error removing file: ' + obj + ' (' + str(e) + ')')
def download_images(df, img_folder):
    """Download every image referenced in `df` into `img_folder` on S3.

    Expects columns 'IMG_URL' and 'title'; the title (with '/' replaced and
    newlines stripped, since both break S3 keys/filenames) becomes the .jpg
    filename. The destination folder is cleared first. Per-row failures are
    logged and skipped so one bad row does not stop the batch.
    """
    remove_all_files_from_s3_directory(img_folder)
    for index, row in df.iterrows():
        try:
            safe_name = row['title'].replace('/', '_').replace('\n', '') + '.jpg'
            download_to_s3(row['IMG_URL'], os.path.join(img_folder, safe_name))
        except Exception:
            # Narrowed from a bare `except:` (which also caught
            # KeyboardInterrupt); keep the best-effort semantics.
            print('Error downloading image: ' + str(index) + row['title'])
def load_models():
    """Load every configured model into `model_dict`, skipping aliases already loaded.

    sentence-transformer aliases get a single SentenceTransformer under 'model';
    Hugging Face CLIP aliases get 'hf_dir', 'model' and 'processor' entries.
    """
    for name in model_name_to_ids:
        if name in model_dict:
            continue
        model_dict[name] = dict()
        entry = model_dict[name]
        hf_id = model_name_to_ids[name]
        if name.startswith('sentence-transformer'):
            entry['model'] = SentenceTransformer(hf_id)
        else:
            entry['hf_dir'] = hf_id
            entry['model'] = CLIPModel.from_pretrained(hf_id)
            entry['processor'] = CLIPProcessor.from_pretrained(hf_id)
# Eagerly load all models once, at import time.
if not model_dict:
    print('Loading models...')
    load_models()
def get_image_embedding(model_name, image):
    """
    Embed `image` with the model registered under `model_name` and return
    the embedding vector (a 1-D numpy array).
    """
    entry = model_dict[model_name]
    if model_name.startswith('sentence-transformer'):
        # SentenceTransformer handles preprocessing internally.
        return entry['model'].encode(image)
    # Hugging Face CLIP path: preprocess, then take the first (only) row.
    inputs = entry['processor'](images=image, return_tensors="pt")
    features = entry['model'].get_image_features(**inputs)
    return features.detach().numpy()[0]
def s3_path_to_image(fs, s3_path):
    """
    Takes an S3 path as input and returns a PIL Image object.
    Args:
        s3_path (str): The path to the image in the S3 bucket, including the bucket name (e.g., "bucket_name/path/to/image.jpg").
    Returns:
        Image: A PIL Image object.
    """
    # Read the whole object first so the S3 handle can close before decoding.
    with fs.open(s3_path, "rb") as f:
        raw = f.read()
    return Image.open(BytesIO(raw))
def generate_and_save_embeddings():
    """Embed every .jpg in the current dataset's image folder with each
    configured model and persist the embeddings to the vector DB.

    Pure inference, so gradients are disabled for speed/memory.
    """
    with torch.no_grad():
        for fp in tqdm(fs.ls(get_image_path()), desc="Generate embeddings for Images"):
            if not fp.endswith('.jpg'):
                continue
            name = fp.split('/')[-1]
            s3_path = 's3://' + fp
            # Fetch the image from S3 once — the original re-downloaded it
            # inside the model loop, once per model.
            image = s3_path_to_image(fs, s3_path)
            for model_name in model_name_to_ids:
                vector_db.add_image_embedding_to_db(
                    embedding=get_image_embedding(model_name, image),
                    model_name=model_name,
                    dataset_name=cur_dataset,
                    path_to_image=s3_path,
                    image_name=name,
                )
def get_immediate_subdirectories(s3_path):
    """Return the names (last path component) of directories directly under `s3_path`."""
    children = fs.glob(f"{s3_path}/*")
    return [child.split('/')[-1] for child in children if fs.isdir(child)]
# Discover available datasets on S3 at import time and default to the first.
# NOTE(review): raises IndexError at import if no datasets exist — confirm
# that failing fast here is the intended behavior.
all_datasets = get_immediate_subdirectories(ROOT_DATA_PATH)
cur_dataset = all_datasets[0]
def set_cur_dataset(dataset):
    """Make `dataset` the module-level current dataset, refreshing the dataset list first."""
    global cur_dataset
    refresh_all_datasets()
    print(f"Setting current dataset to {dataset}")
    cur_dataset = dataset
def refresh_all_datasets():
    """Re-scan S3 for datasets and update the module-level `all_datasets` list."""
    global all_datasets
    all_datasets = get_immediate_subdirectories(ROOT_DATA_PATH)
    print(f"Refreshing all datasets: {all_datasets}")
def url_to_image(url):
    """Fetch `url` over HTTP and return it as a PIL Image, or None on request failure.

    Only request/transport errors are converted to None; a response body that
    is not a valid image will still raise from Image.open.
    """
    try:
        # timeout prevents an unresponsive server from blocking forever.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        return img
    except requests.exceptions.RequestException as e:
        # `e` was previously bound but never logged — include the cause.
        print(f"Error fetching image from URL: {url} ({e})")
        return None