{ "cells": [ { "cell_type": "code", "execution_count": 4, "id": "32b7d029-64ce-4361-acde-dc72d67637b7", "metadata": { "tags": [] }, "outputs": [], "source": [ "import copy\n", "import torch\n", "import torch.nn as nn\n", "import clip\n", "from transformers import CLIPProcessor\n", "from huggingface_hub import PyTorchModelHubMixin\n", "from transformers import PretrainedConfig\n", "\n", "class CSDCLIPConfig(PretrainedConfig):\n", " model_type = \"csd_clip\"\n", "\n", " def __init__(\n", " self,\n", " name=\"csd_large\",\n", " embedding_dim=1024,\n", " feature_dim=1024,\n", " content_dim=768,\n", " style_dim=768,\n", " content_proj_head=\"default\",\n", " **kwargs\n", " ):\n", " super().__init__(**kwargs)\n", " self.name = name\n", " self.embedding_dim = embedding_dim\n", " self.content_proj_head = content_proj_head\n", " self.task_specific_params = None # Add this line\n", "\n", "class CSD_CLIP(nn.Module, PyTorchModelHubMixin):\n", " \"\"\"backbone + projection head\"\"\"\n", " def __init__(self, name='vit_large',content_proj_head='default'):\n", " super(CSD_CLIP, self).__init__()\n", " self.content_proj_head = content_proj_head\n", " if name == 'vit_large':\n", " clipmodel, _ = clip.load(\"ViT-L/14\")\n", " self.backbone = clipmodel.visual\n", " self.embedding_dim = 1024\n", " self.feature_dim = 1024\n", " self.content_dim = 768\n", " self.style_dim = 768\n", " self.name = \"csd_large\"\n", " elif name == 'vit_base':\n", " clipmodel, _ = clip.load(\"ViT-B/16\")\n", " self.backbone = clipmodel.visual\n", " self.embedding_dim = 768 \n", " self.feature_dim = 512\n", " self.content_dim = 512\n", " self.style_dim = 512\n", " self.name = \"csd_base\"\n", " else:\n", " raise Exception('This model is not implemented')\n", "\n", " self.last_layer_style = copy.deepcopy(self.backbone.proj)\n", " self.last_layer_content = copy.deepcopy(self.backbone.proj)\n", "\n", " self.backbone.proj = None\n", " \n", " self.config = CSDCLIPConfig(\n", " name=self.name,\n", " embedding_dim=self.embedding_dim,\n", " feature_dim=self.feature_dim,\n", " content_dim=self.content_dim,\n", " style_dim=self.style_dim,\n", " content_proj_head=self.content_proj_head\n", " )\n", "\n", " def get_config(self):\n", " return self.config.to_dict()\n", "\n", " @property\n", " def dtype(self):\n", " return self.backbone.conv1.weight.dtype\n", " \n", " @property\n", " def device(self):\n", " return next(self.parameters()).device\n", "\n", " def forward(self, input_data):\n", " \n", " feature = self.backbone(input_data)\n", "\n", " style_output = feature @ self.last_layer_style\n", " style_output = nn.functional.normalize(style_output, dim=1, p=2)\n", "\n", " content_output = feature @ self.last_layer_content\n", " content_output = nn.functional.normalize(content_output, dim=1, p=2)\n", " \n", " return feature, content_output, style_output\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "model = CSD_CLIP.from_pretrained(\"yuxi-liu-wired/CSD\")\n", "model.to(device);" ] }, { "cell_type": "code", "execution_count": null, "id": "bbd750f6-fde9-48ed-a7d8-42ee5d31429d", "metadata": { "tags": [] }, "outputs": [], "source": [ "import torch\n", "from transformers import Pipeline\n", "from typing import Union, List\n", "from PIL import Image\n", "\n", "class CSDCLIPPipeline(Pipeline):\n", " def __init__(self, model, processor, device=None):\n", " if device is None:\n", " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " super().__init__(model=model, tokenizer=None, device=device)\n", " self.processor = 
processor\n", "\n", " def _sanitize_parameters(self, **kwargs):\n", " return {}, {}, {}\n", "\n", " def preprocess(self, images):\n", " if isinstance(images, (str, Image.Image)):\n", " images = [images]\n", " \n", " processed = self.processor(images=images, return_tensors=\"pt\", padding=True, truncation=True)\n", " return {k: v.to(self.device) for k, v in processed.items()}\n", "\n", " def _forward(self, model_inputs):\n", " pixel_values = model_inputs['pixel_values'].to(self.model.dtype)\n", " with torch.no_grad():\n", " features, content_output, style_output = self.model(pixel_values)\n", " return {\"features\": features, \"content_output\": content_output, \"style_output\": style_output}\n", "\n", " def postprocess(self, model_outputs):\n", " return {\n", " \"features\": model_outputs[\"features\"].cpu().numpy(),\n", " \"content_output\": model_outputs[\"content_output\"].cpu().numpy(),\n", " \"style_output\": model_outputs[\"style_output\"].cpu().numpy()\n", " }\n", "\n", " def __call__(self, images: Union[str, List[str], Image.Image, List[Image.Image]]):\n", " return super().__call__(images)\n", "\n", "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\")\n", "pipeline = CSDCLIPPipeline(model=model, processor=processor, device=device)" ] }, { "cell_type": "code", "execution_count": 3, "id": "4107999a-c48c-4cb4-9247-9836dfb27e98", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Processing images: 100%|█████████████████████████████████████████████████████████████| 900/900 [01:09<00:00, 12.86it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Processing complete. Results saved to 'processed_dataset.parquet'\n" ] } ], "source": [ "import io\n", "from PIL import Image\n", "from datasets import load_dataset\n", "import pandas as pd\n", "from tqdm import tqdm\n", "\n", "def to_jpeg(image):\n", " buffered = io.BytesIO()\n", " if image.mode not in (\"RGB\"):\n", " image = image.convert(\"RGB\")\n", " image.save(buffered, format='JPEG')\n", " return buffered.getvalue() \n", "\n", "def scale_image(image, max_resolution):\n", " if max(image.width, image.height) > max_resolution:\n", " image = image.resize((max_resolution, int(image.height * max_resolution / image.width)))\n", " return image\n", "\n", "def process_dataset(pipeline, dataset_name, dataset_size=900, max_resolution=192):\n", " dataset = load_dataset(dataset_name, split='train')\n", " dataset = dataset.select(range(dataset_size))\n", " \n", " # Print the column names\n", " print(\"Dataset columns:\", dataset.column_names)\n", " \n", " # Initialize lists to store results\n", " embeddings = []\n", " jpeg_images = []\n", " \n", " # Process each item in the dataset\n", " for item in tqdm(dataset, desc=\"Processing images\"):\n", " try:\n", " img = item['image']\n", " \n", " # If img is a string (file path), load the image\n", " if isinstance(img, str):\n", " img = Image.open(img)\n", "\n", "\n", " output = pipeline(img)\n", " style_output = output[\"style_output\"].squeeze(0)\n", " \n", " img = scale_image(img, max_resolution)\n", " jpeg_img = to_jpeg(img)\n", " \n", " # Append results to lists\n", " embeddings.append(style_output)\n", " jpeg_images.append(jpeg_img)\n", " except Exception as e:\n", " print(f\"Error processing item: {e}\")\n", " \n", " # Create a DataFrame with the results\n", " df = pd.DataFrame({\n", " 'embedding': embeddings,\n", " 'image': jpeg_images\n", " })\n", " \n", " df.to_parquet('processed_dataset.parquet')\n", " 
print(\"Processing complete. Results saved to 'processed_dataset.parquet'\")\n", "\n", "process_dataset(pipeline, \"yuxi-liu-wired/style-content-grid-SDXL\", \n", " dataset_size=900, max_resolution=192)" ] }, { "cell_type": "code", "execution_count": null, "id": "066ec067-edb1-4110-a0fe-8d7c97311790", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:diffgan]", "language": "python", "name": "conda-env-diffgan-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }