{ "cells": [ { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "from typing import List\n", "import requests\n", "from PIL import Image\n", "from transformers import CLIPModel, CLIPProcessor, CLIPFeatureExtractor\n", "import torch" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n", "image = Image.open(requests.get(url, stream=True).raw)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ClipWrapper:\n", " def __init__(self):\n", " self.model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n", " self.processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n", "\n", " def images2vec(self, images: List[Image.Image]) -> torch.Tensor:\n", " inputs = self.processor(images=images, return_tensors=\"pt\")\n", " with torch.no_grad():\n", " model_inputs = {k: v.to(self.model.device) for k, v in inputs.items()}\n", " image_embeds = self.model.vision_model(**model_inputs)\n", " clip_vectors = self.model.visual_projection(image_embeds[1])\n", " return clip_vectors / clip_vectors.norm(dim=-1, keepdim=True)\n", "\n", " def texts2vec(self, texts: List[str]) -> torch.Tensor:\n", " inputs = self.processor(text=texts, return_tensors=\"pt\", padding=True)\n", " with torch.no_grad():\n", " model_inputs = {k: v.to(self.model.device) for k, v in inputs.items()}\n", " text_embeds = self.model.text_model(**model_inputs)\n", " text_vectors = self.model.text_projection(text_embeds[1])\n", " return text_vectors / text_vectors.norm(dim=-1, keepdim=True)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n", "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 512])" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def images2vec(images: List[Image.Image]) -> torch.Tensor:\n", " inputs = processor(images=images, return_tensors=\"pt\")\n", " with torch.no_grad():\n", " model_inputs = {k: v.to(model.device) for k, v in inputs.items()}\n", " image_embeds = model.vision_model(**model_inputs)\n", " clip_vectors = model.visual_projection(image_embeds[1])\n", " return clip_vectors / clip_vectors.norm(dim=-1, keepdim=True)\n", "\n", "\n", "result = images2vec([image, image])\n", "result.shape" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 512])" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def texts2vec(texts: List[str]) -> torch.Tensor:\n", " inputs = processor(text=texts, return_tensors=\"pt\", padding=True)\n", " with torch.no_grad():\n", " model_inputs = {k: v.to(model.device) for k, v in inputs.items()}\n", " text_embeds = model.text_model(**model_inputs)\n", " text_vectors = model.text_projection(text_embeds[1])\n", " return text_vectors / text_vectors.norm(dim=-1, keepdim=True)\n", "\n", "\n", "texts2vec([\"a photo of a cat\", \"a photo of a dog\"]).shape" ] } ], "metadata": { "kernelspec": { "display_name": "semvideo-hackathon-230523", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, 
"file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }