๐คtransformers, ๐คdatasets, FAISS๋ฅผ ์ฌ์ฉํ ๋ฉํฐ๋ชจ๋ฌ ๋ฐ์ดํฐ ์๋ฒ ๋ฉ ๋ฐ ์ ์ฌ์ฑ ๊ฒ์
์์ฑ์: Merve Noyan, ์ด์ ์ธ
์๋ฒ ๋ฉ์ ์๋ฏธ๋ก ์ ์ผ๋ก ์ค์ํ ์ ๋ณด์ ์์ถ์ ๋๋ค. ์ด๋ ์ ์ฌ์ฑ ๊ฒ์, ์ ๋ก์ท ๋ถ๋ฅ ๋๋ ์๋ก์ด ๋ชจ๋ธ์ ํ๋ จํ๋ ๋ฐ ์ฌ์ฉ๋ ์ ์์ต๋๋ค. ์ ์ฌ์ฑ ๊ฒ์์ ํ์ฉ ์ฌ๋ก๋ก๋ ์ ์์๊ฑฐ๋์์ ์ ์ฌํ ์ ํ ๊ฒ์, ์์ ๋ฏธ๋์ด์์์ ์ฝํ ์ธ ๊ฒ์ ๋ฑ์ด ์์ต๋๋ค. ์ด ๋ ธํธ๋ถ์ ๐คTransformers, ๐คDatasets ๋ฐ FAISS๋ฅผ ์ฌ์ฉํ์ฌ ํน์ง ์ถ์ถ ๋ชจ๋ธ๋ก๋ถํฐ ์๋ฒ ๋ฉ์ ์์ฑํ๊ณ ์ธ๋ฑ์ฑํ์ฌ ์ดํ ์ ์ฌ์ฑ ๊ฒ์์ ํ์ฉํ๋ ๋ฐฉ๋ฒ์ ์๋ดํฉ๋๋ค. ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ค์นํด๋ด ์๋ค.
!pip install -q datasets faiss-gpu transformers sentencepiece
์ด ํํ ๋ฆฌ์ผ์์๋ CLIP ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ํน์ง์ ์ถ์ถํ ๊ฒ์ ๋๋ค. CLIP์ ํ ์คํธ ์ธ์ฝ๋์ ์ด๋ฏธ์ง ์ธ์ฝ๋๋ฅผ ํจ๊ป ํ์ต์์ผ ๋ ๊ฐ์ง ๋ชจ๋ฌ๋ฆฌํฐ๋ฅผ ์ฐ๊ฒฐํ๋ ํ์ ์ ์ธ ๋ชจ๋ธ์ ๋๋ค.
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
import faiss
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch16")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")
๋ฐ์ดํฐ์ ์ ๋ก๋ํฉ๋๋ค. ๊ฐ๋ณ๊ฒ ์ด ์์ ๋ฅผ ํด ๋ณด๊ธฐ ์ํด, ์์ ์บก์ ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํด๋ด ์๋ค, jmhessel/newyorker_caption_contest.
from datasets import load_dataset
ds = load_dataset("jmhessel/newyorker_caption_contest", "explanation")
์์ ๋ฅผ ํ๋ ๋ด ์๋ค.
>>> ds["train"][0]["image"]
ds["train"][0]["image_description"]
์ฐ๋ฆฌ๋ ์์ ๋ฅผ ์๋ฒ ๋ฉํ๊ฑฐ๋ ์ธ๋ฑ์ค๋ฅผ ์์ฑํ๊ธฐ ์ํด ์ด๋ค ํจ์๋ ์์ฑํ ํ์๊ฐ ์์ต๋๋ค. ๐คDatasets ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ FAISS ํตํฉ์ด ์ด๋ฌํ ๊ณผ์ ์ ์ถ์ํํด์ค๋๋ค. ์๋์ ๊ฐ์ด ๋ฐ์ดํฐ์
์ map
๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ ์์ ์ ๋ํ ์๋ฒ ๋ฉ์ ํฌํจํ๋ ์๋ก์ด ์ด์ ๊ฐ๋จํ๊ฒ ์์ฑํ ์ ์์ต๋๋ค. ์ด์ ํ๋กฌํํธ ์ด์์ ํ
์คํธ ํน์ง์ ์ํ ์๋ฒ ๋ฉ์ ๋ง๋ค์ด๋ด
์๋ค.
dataset = ds["train"]
ds_with_embeddings = dataset.map(
lambda example: {
"embeddings": model.get_text_features(
**tokenizer([example["image_description"]], truncation=True, return_tensors="pt").to("cuda")
)[0]
.detach()
.cpu()
.numpy()
}
)
๋์ผํ ๋ฐฉ์์ผ๋ก ์ด๋ฏธ์ง ์๋ฒ ๋ฉ๋ ์ป์ ์ ์์ต๋๋ค.
ds_with_embeddings = ds_with_embeddings.map(
lambda example: {
"image_embeddings": model.get_image_features(**processor([example["image"]], return_tensors="pt").to("cuda"))[
0
]
.detach()
.cpu()
.numpy()
}
)
์ด์ ์ฐ๋ฆฌ๋ ๊ฐ ์ด์ ๋ํ ์ธ๋ฑ์ค๋ฅผ ์ถ๊ฐํฉ๋๋ค.
# ํ
์คํธ ์๋ฒ ๋ฉ์ ์ํ FAISS ์ธ๋ฑ์ค๋ฅผ ๋ง๋ญ๋๋ค.
ds_with_embeddings.add_faiss_index(column="embeddings")
# ์ด๋ฏธ์ง ์๋ฒ ๋ฉ์ ์ํ FAISS ์ธ๋ฑ์ค๋ฅผ ๋ง๋ญ๋๋ค.
ds_with_embeddings.add_faiss_index(column="image_embeddings")
ํ ์คํธ ํ๋กฌํํธ๋ก ๋ฐ์ดํฐ ์ง๋ฌธํ๊ธฐ
์ด์ ํ ์คํธ๋ ์ด๋ฏธ์ง๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ์ ์ง๋ฌธ์ ๋์ง๊ณ , ์ ์ฌํ ํญ๋ชฉ์ ์ป์ ์ ์์ต๋๋ค.
prmt = "a snowy day"
prmt_embedding = (
model.get_text_features(**tokenizer([prmt], return_tensors="pt", truncation=True).to("cuda"))[0]
.detach()
.cpu()
.numpy()
)
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples("embeddings", prmt_embedding, k=1)
>>> def downscale_images(image):
... width = 200
... ratio = width / float(image.size[0])
... height = int((float(image.size[1]) * float(ratio)))
... img = image.resize((width, height), Image.Resampling.LANCZOS)
... return img
>>> images = [downscale_images(image) for image in retrieved_examples["image"]]
>>> # ์ ์ฌํ ํ
์คํธ์ ์ด๋ฏธ์ง๋ฅผ ํ์ธํฉ๋๋ค.
>>> print(retrieved_examples["image_description"])
>>> display(images[0])
['A man is in the snow. A boy with a huge snow shovel is there too. They are outside a house.']
์ด๋ฏธ์ง ํ๋กฌํํธ๋ก ๋ฐ์ดํฐ ์ง๋ฌธํ๊ธฐ
์ด๋ฏธ์ง ์ ์ฌ์ฑ ์ถ๋ก ๋ ๋ง์ฐฌ๊ฐ์ง๋ก, get_image_features
๋ฅผ ํธ์ถํ๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค.
>>> import requests
>>> # image of a beaver
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/beaver.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> display(downscale_images(image))
์ด ๋น๋ฒ ์ด๋ฏธ์ง์ ๋น์ทํ ์ด๋ฏธ์ง๋ฅผ ๊ฒ์ ํด ๋ด ์๋ค.
img_embedding = (
model.get_image_features(**processor([image], return_tensors="pt", truncation=True).to("cuda"))[0]
.detach()
.cpu()
.numpy()
)
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples("image_embeddings", img_embedding, k=1)
๋น๋ฒ ์ด๋ฏธ์ง์ ๊ฐ์ฅ ๋น์ทํ ์ด๋ฏธ์ง๊ฐ ํ๋ฉด์ ํ์๋ฉ๋๋ค.
>>> images = [downscale_images(image) for image in retrieved_examples["image"]]
>>> # ์ ์ฌํ ํ
์คํธ์ ์ด๋ฏธ์ง๋ฅผ ํ์ธํฉ๋๋ค.
>>> print(retrieved_examples["image_description"])
>>> display(images[0])
['Salmon swim upstream but they see a grizzly bear and are in shock. The bear has a smug look on his face when he sees the salmon.']
์๋ฒ ๋ฉ์ ์ ์ฅํ๊ณ , ์ฌ๋ฆฌ๊ณ , ๊ฐ์ ธ์ค๊ธฐ
์๋ฒ ๋ฉ์ด ํฌํจ๋ ๋ฐ์ดํฐ์
์ save_faiss_index
๋ฅผ ์ฌ์ฉํ์ฌ ์ ์ฅํ ์ ์์ต๋๋ค.
ds_with_embeddings.save_faiss_index("embeddings", "embeddings/embeddings.faiss")
ds_with_embeddings.save_faiss_index("image_embeddings", "embeddings/image_embeddings.faiss")
์๋ฒ ๋ฉ์ ๋ฐ์ดํฐ์
์ ์ฅ์์ ์ ์ฅํ๋ ๊ฒ์ ์ข์ ์ต๊ด์
๋๋ค. ๋ฐ๋ผ์ ์ฐ๋ฆฌ๋ Hugging Face Hub์ ๋ก๊ทธ์ธํ๊ณ , ๋ฐ์ดํฐ์
์ ์ฅ์๋ฅผ ์์ฑํ ํ, ๊ทธ๊ณณ์ ์๋ฒ ๋ฉ ์ธ๋ฑ์ค๋ฅผ ์ฌ๋ฆด ๊ฒ์
๋๋ค. ์ดํ์๋ snapshot_download
๋ฅผ ์ฌ์ฉํ์ฌ ํด๋น ์ธ๋ฑ์ค๋ฅผ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค.
from huggingface_hub import HfApi, notebook_login, snapshot_download
notebook_login()
from huggingface_hub import HfApi
hf_id = "๋น์ ์ ํ๊น
ํ์ด์ค ํ๋ธ ์์ด๋๋ฅผ ์
๋ ฅํ์ธ์."
api = HfApi()
api.create_repo(f"{hf_id}/faiss_embeddings", repo_type="dataset")
api.upload_folder(
folder_path="./embeddings",
repo_id=f"{hf_id}/faiss_embeddings",
repo_type="dataset",
)
snapshot_download(repo_id=f"{hf_id}/faiss_embeddings", repo_type="dataset", local_dir="downloaded_embeddings")
load_faiss_index
๋ฅผ ์ฌ์ฉํ์ฌ ์๋ฒ ๋ฉ์ด ์๋ ๋ฐ์ดํฐ์
์ ์๋ฒ ๋ฉ์ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค.
ds = ds["train"]
ds.load_faiss_index("embeddings", "./downloaded_embeddings/embeddings.faiss")
# ๋ค์ ์ถ๋ก ํฉ๋๋ค.
prmt = "people under the rain"
prmt_embedding = (
model.get_text_features(**tokenizer([prmt], return_tensors="pt", truncation=True).to("cuda"))[0]
.detach()
.cpu()
.numpy()
)
scores, retrieved_examples = ds.get_nearest_examples("embeddings", prmt_embedding, k=1)
>>> display(retrieved_examples["image"][0])