{ "cells": [ { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from CLIP.clip import clip" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "clip_model_modify, clip_preprocess_modify = clip.load(\"../pretrained/clip_best.pth\", device=torch.device('cpu'), jit=False)\n", "# ./clip_weights/best_model_all_feature.pt\n", "clip_model_ori, clip_preprocess_ori = clip.load(\"../ViT-B-32.pt\", device=torch.device('cpu'), jit=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def compare_weights(model1, model2):\n", " different_layers = []\n", " for name1, param1 in model1.named_parameters():\n", " param2 = model2.state_dict()[name1]\n", " # print(param2)\n", " if not torch.equal(param1, param2):\n", " different_layers.append(name1)\n", " return different_layers\n", "compare_weights(clip_model_modify, clip_model_ori)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def topK_process(model, text):\n", " # Encode and normalize the search query using CLIP\n", " text_token = clip.tokenize(text, truncate=True)\n", " tokens = text.split(' ')\n", " text_encoded, weight = model.encode_text(text_token)\n", "\n", " text_encoded /= text_encoded.norm(dim=-1, keepdim=True)\n", " attention_weights = weight[-1][0][1+len(tokens)][:2+len(tokens)][1:][:-1]\n", " # attention_weights = weight[-1][range(len(weight[-1])), tokens_lens][:, :1+max(tokens_lens)][:, 1:][:, :-1]\n", " return text_encoded, attention_weights\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# clip_text = 'a person passes something to the right'\n", "# clip_text_perb = 'a native passes something to the right'\n", "\n", "# clip_text = 'person is walking normally in a circle'\n", "# clip_text_perb = 'human is walking usually in a loop'\n", "\n", "# clip_text = 'a man kicks something or someone with his left leg'\n", "# clip_text_perb = 'a human boots something or someone with his left leg'\n", "\n", "# Walking forward in an even pace \n", "# Going ahead in an even pace\n", "\n", "# clip_text = 'a man jumps forward and swings his arms'\n", "# clip_text_perb = 'a native bounds ahead and waves his arms'\n" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "clip_text = 'person is sitting down and looking around'\n", "clip_text_perb = 'native is seating down and looking around'" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "import numpy\n", "def visual_weights(model, clip_text=clip_text, clip_text_perb=clip_text_perb):\n", " model.eval()\n", " text, weight_ori = topK_process(model, clip_text)\n", " text_perb, weight_perb = topK_process(model, clip_text_perb)\n", " weight_ori_v = weight_ori.detach().cpu().numpy()/weight_ori.detach().cpu().numpy().sum()\n", " # print(weight_ori_v.sum())\n", " weight_perb_v = weight_perb.detach().cpu().numpy()/weight_perb.detach().cpu().numpy().sum()\n", " # print(f\"text:{clip_text}, \\n weight_ori:{weight_ori_v},\\n text_perb:{clip_text_perb}, \\n weight_perb:{weight_perb_v}\")\n", " return clip_text, weight_ori_v, clip_text_perb, weight_perb_v" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from matplotlib.patches import Rectangle\n", "\n", "def generate_colored_text_image(text, attention_weights):\n", " # text = text.split(' ')\n", " # attention_weights = [float(i) for i in attention_weights]\n", " fig, ax = plt.subplots(figsize=(len(text)+1, 1))\n", " ax.set_axis_off()\n", " \n", " # 计算文本块的数量\n", " num_words = len(text)\n", " \n", " # 计算每个文本块的宽度\n", " word_width = 0.95 / num_words # 减少间距\n", " \n", " # 计算最小和最大的权重值\n", " min_weight = min(attention_weights)\n", " max_weight = max(attention_weights)\n", " \n", " # 设置颜色\n", " base_color = (1, 0.5, 0.5) # 基础颜色为浅红色\n", " color_map = [(1, 0.95 - 0.3 * (weight - min_weight) / (max_weight - min_weight), 0.95 - 0.3 * (weight - min_weight) / (max_weight - min_weight), 0.8) for weight in attention_weights] # 根据权重计算颜色\n", " \n", " # 生成文本并设置背景颜色\n", " x_position = 0\n", " for word, color in zip(text, color_map):\n", " rect = Rectangle((x_position, 0), word_width, 1, facecolor=color)\n", " ax.add_patch(rect)\n", " ax.text(x_position + word_width / 2, 0.5, word, ha='center', va='center', fontsize=14, color='black') # 增大字体\n", " x_position += word_width\n", " \n", " plt.xlim(0, 1)\n", " plt.ylim(0, 1)\n", " plt.savefig('text_attention.png', dpi=300)\n", " plt.show()\n", "\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "clip_text, weight_ori_v, clip_text_perb, weight_perb_v = visual_weights(clip_model_ori)\n", "text = clip_text.split(' ')\n", "attention_weights = [float(i) for i in weight_ori_v]\n", "\n", "generate_colored_text_image(text, attention_weights)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "text = clip_text_perb.split(' ')\n", "attention_weights = [float(i) for i in weight_perb_v]\n", "\n", "generate_colored_text_image(text, attention_weights)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "clip_text, weight_ori_v, clip_text_perb, weight_perb_v = visual_weights(clip_model_modify)\n", "text = clip_text.split(' ')\n", "attention_weights = [float(i) for i in weight_ori_v]\n", "generate_colored_text_image(text, attention_weights)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "text = clip_text_perb.split(' ')\n", "attention_weights = [float(i) for i in weight_perb_v]\n", "generate_colored_text_image(text, attention_weights)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.3352644313389532" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from scipy.spatial.distance import jensenshannon\n", "def jsd_cal(model, clip_text, clip_text_perb):\n", " clip_text, weight_ori_v, clip_text_perb, weight_perb_v = visual_weights(model, clip_text, clip_text_perb)\n", " normalized_attention = weight_ori_v / weight_ori_v.sum()\n", " normalized_attention_perb = weight_perb_v / weight_perb_v.sum()\n", " jsd = jensenshannon(normalized_attention, normalized_attention_perb, base=2)\n", " return jsd\n", "\n", "jsd_cal(clip_model_ori, clip_text, clip_text_perb)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.08214502506975593" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jsd_cal(clip_model_modify, clip_text, clip_text_perb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "llm2", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.11" } }, "nbformat": 4, "nbformat_minor": 2 }