{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "62c37427", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModel, AutoTokenizer\n", "from sentence_transformers import SentenceTransformer\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import Dataset, DataLoader\n", "import pandas as pd\n", "import numpy as np\n", "import torch.nn.functional as F\n", "from tqdm import tqdm" ] }, { "cell_type": "code", "execution_count": 2, "id": "ca8d35e3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "sentence_model = SentenceTransformer('all-MiniLM-L6-v2')\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(device)" ] }, { "cell_type": "code", "execution_count": 3, "id": "27935f40", "metadata": {}, "outputs": [], "source": [ "class ImageDataset(Dataset):\n", " def __init__(self, csv_file, transform=None):\n", " self.annotations = csv_file\n", " self.transform=transform\n", " \n", " def __len__(self):\n", " return len(self.annotations)\n", " \n", " def __getitem__(self,index):\n", " img_desc = self.annotations.iloc[index, 2]\n", "\n", " label=torch.tensor(int(self.annotations.iloc[index, 3]))\n", " \n", " if self.transform:\n", " img_desc = self.transform(img_desc)\n", " \n", " return (img_desc, label)" ] }, { "cell_type": "code", "execution_count": 4, "id": "d96cfab6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "81890\n" ] } ], "source": [ "df = pd.read_csv('image_desc.csv')\n", "dataset = ImageDataset(df)\n", "train_size = int(0.85 * len(dataset))\n", "test_size = len(dataset) - train_size\n", "train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])\n", "print(len(dataset))" ] }, { "cell_type": "code", "execution_count": 12, "id": "d12e4992", "metadata": {}, "outputs": [], "source": [ "batch_size=16\n", "train_loader=DataLoader(train_set, batch_size=batch_size, shuffle=True)\n", "test_loader=DataLoader(test_set, batch_size=batch_size, shuffle=True)" ] }, { "cell_type": "code", "execution_count": 54, "id": "4e1f90e0", "metadata": {}, "outputs": [], "source": [ "class MyModel(nn.Module):\n", " def __init__(self, sentence_model, hidden_dim, output_dim):\n", " super(MyModel, self).__init__()\n", " self.sentence_model = sentence_model\n", " self.fc1 = nn.Linear(384, hidden_dim)\n", " self.fc2 = nn.Linear(hidden_dim, output_dim)\n", " self.sig = nn.Sigmoid()\n", "\n", " def forward(self, x):\n", " sentence_embeddings = self.sentence_model.encode(x, convert_to_tensor=True)\n", " sentence_embeddings = sentence_embeddings.to(device)\n", " hidden = self.fc1(sentence_embeddings)\n", " hidden = F.relu(hidden)\n", " logits = self.fc2(hidden)\n", "# logits = torch.clamp(logits, min=1e-5)\n", " logits = self.sig(logits)\n", " return logits\n", "\n", "output_dim = 102\n", "hidden_dim = 256\n", "\n", "model = MyModel(sentence_model, hidden_dim, output_dim).to(device)" ] }, { "cell_type": "code", "execution_count": 37, "id": "85c41c63", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:42<00:00, 42.36it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tensor(1.0000, device='cuda:0', grad_fn=)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "min = torch.tensor(1).to(device)\n", "similarity = 
{ "cell_type": "code", "execution_count": 37, "id": "85c41c63", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:42<00:00, 42.36it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tensor(1.0000, device='cuda:0', grad_fn=)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Track the smallest cosine similarity between the model outputs for the first two descriptions of each training batch\n", "min_sim = torch.tensor(1).to(device)\n", "similarity = nn.CosineSimilarity(dim=0)\n", "for sample_batch, sample_label in tqdm(train_loader):\n", "    i = sample_batch[0]\n", "    j = sample_batch[1]\n", "    output_i = model(i)\n", "    output_j = model(j)\n", "    sim_i_j = similarity(output_i, output_j)\n", "    if sim_i_j < min_sim:\n", "        min_sim = sim_i_j\n", "\n", "print(min_sim)" ] },
{ "cell_type": "code", "execution_count": 55, "id": "e99a5150", "metadata": {}, "outputs": [], "source": [ "# CrossEntropyLoss applies its own log-softmax, so it receives the model's sigmoid scores as if they were raw logits\n", "criterion = nn.CrossEntropyLoss()\n", "# criterion = nn.MSELoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.005)" ] },
{ "cell_type": "code", "execution_count": 56, "id": "7957341a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:06<00:00, 65.25it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1/4, Loss: 1116.4719812870026\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:06<00:00, 65.90it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 2/4, Loss: 1087.523635149002\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:06<00:00, 65.20it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 3/4, Loss: 1079.509438186884\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:07<00:00, 64.31it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 4/4, Loss: 1074.7653084248304\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "num_epochs = 4\n", "for epoch in range(num_epochs):\n", "    model.train()\n", "    losses = []\n", "\n", "    for i, (sentences_batch, labels_batch) in enumerate(tqdm(train_loader)):\n", "        labels_batch = labels_batch.to(device)\n", "        # one-hot targets are treated by CrossEntropyLoss as class probabilities (PyTorch >= 1.10)\n", "        labels_batch = F.one_hot(labels_batch, num_classes=102).float()\n", "        optimizer.zero_grad()\n", "        # Forward pass\n", "        logits = model(sentences_batch).float()\n", "        loss = criterion(logits, labels_batch)\n", "\n", "        # Backward pass and optimization\n", "        loss.backward()\n", "        optimizer.step()\n", "        curr_loss = loss.item()\n", "        losses.append(curr_loss)\n", "\n", "    running_loss = sum(losses)\n", "\n", "    # Epoch loss reported as the summed batch losses divided by batch_size\n", "    # (divide by len(train_loader) instead for the mean loss per batch)\n", "    epoch_loss = running_loss / batch_size\n", "    print(f\"Epoch: {epoch+1}/{num_epochs}, Loss: {epoch_loss}\")\n" ] },
{ "cell_type": "code", "execution_count": 47, "id": "4ecaab5d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:33<00:00, 46.66it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tensor(0., device='cuda:0', grad_fn=)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Same check as before training; min_sim and similarity are reused from the earlier cell, so min_sim can only decrease\n", "for sample_batch, sample_label in tqdm(train_loader):\n", "    i = sample_batch[0]\n", "    j = sample_batch[1]\n", "    output_i = model(i)\n", "    output_j = model(j)\n", "    sim_i_j = similarity(output_i, output_j)\n", "    if sim_i_j < min_sim:\n", "        min_sim = sim_i_j\n", "\n", "print(min_sim)" ] },
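{ "cell_type": "markdown", "id": "heldout-eval-md", "metadata": {}, "source": [ "**Added held-out evaluation (sketch).** The next code cell measures accuracy on the training split; the sketch below applies the same loop to the held-out `test_loader` produced by `random_split` earlier. It has not been executed here, so no output is stored." ] },
{ "cell_type": "code", "execution_count": null, "id": "heldout-eval-code", "metadata": {}, "outputs": [], "source": [ "# Minimal sketch mirroring the training-set accuracy loop in the next cell,\n", "# but run over the held-out split (test_loader)\n", "model.eval()\n", "total_correct = 0\n", "total_samples = 0\n", "\n", "with torch.no_grad():\n", "    for sentences_batch, labels_batch in tqdm(test_loader):\n", "        labels_batch = labels_batch.to(device)\n", "        logits = model(sentences_batch)\n", "        predicted = torch.argmax(logits, dim=1)\n", "        total_samples += labels_batch.size(0)\n", "        total_correct += (predicted == labels_batch).sum().item()\n", "\n", "print(\"Held-out accuracy:\", total_correct / total_samples)" ] },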
"100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:04<00:00, 67.76it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.2659109846852283\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "model.eval()\n", "total_correct = 0\n", "total_samples = 0\n", "\n", "with torch.no_grad():\n", " for i, (sentences_batch, labels_batch) in enumerate(tqdm(train_loader)):\n", " labels_batch = labels_batch.to(device)\n", "# labels_batch = F.one_hot(labels_batch, num_classes = 102).float()\n", " logits = model(sentences_batch).float()\n", " predicted = torch.argmax(logits, dim = 1)\n", " total_samples += labels_batch.size(0)\n", " total_correct += (predicted == labels_batch).sum().item()\n", "\n", "accuracy = total_correct / total_samples\n", "print(\"Accuracy:\", accuracy)" ] }, { "cell_type": "code", "execution_count": 58, "id": "da1763b7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ENTER DESCRIPTION pink\n", "tensor([9.6708e-15, 1.0179e-04, 4.2242e-08, 1.3063e-15, 8.8056e-03, 0.0000e+00,\n", " 1.6553e-14, 8.9271e-33, 9.0644e-27, 5.9910e-19, 2.2721e-24, 7.7432e-03,\n", " 3.9587e-36, 7.1618e-07, 2.7430e-08, 0.0000e+00, 0.0000e+00, 1.4562e-03,\n", " 9.8114e-06, 9.2844e-24, 7.8520e-33, 2.9296e-22, 3.5067e-13, 1.3316e-05,\n", " 7.7768e-11, 9.2201e-39, 5.0639e-22, 1.6904e-19, 3.2689e-35, 1.0034e-14,\n", " 9.8686e-01, 4.1330e-05, 6.3048e-01, 9.5960e-23, 1.2662e-14, 2.4540e-22,\n", " 1.4413e-08, 9.9928e-01, 2.8299e-02, 4.9763e-10, 2.7364e-04, 9.9878e-01,\n", " 0.0000e+00, 9.9998e-01, 6.7328e-02, 2.9939e-13, 1.9145e-17, 0.0000e+00,\n", " 0.0000e+00, 0.0000e+00, 0.0000e+00, 9.9998e-01, 1.1818e-30, 2.2513e-22,\n", " 0.0000e+00, 1.0346e-32, 8.8656e-21, 9.9353e-01, 4.3037e-03, 8.6023e-39,\n", " 3.6964e-10, 3.3164e-21, 1.9611e-15, 0.0000e+00, 3.7135e-38, 1.3163e-34,\n", " 1.8906e-07, 7.0084e-30, 1.0882e-20, 2.6501e-33, 8.9597e-39, 5.0791e-37,\n", " 1.0000e+00, 5.7929e-03, 1.3252e-03, 1.4498e-23, 1.3656e-02, 2.0226e-07,\n", " 8.3005e-01, 8.4326e-14, 2.1941e-03, 3.8749e-28, 9.8803e-01, 9.9992e-01,\n", " 4.3195e-11, 7.0360e-01, 1.0000e+00, 1.5408e-02, 9.9689e-01, 8.0569e-15,\n", " 1.4282e-22, 9.6706e-03, 4.9712e-03, 4.8348e-05, 1.2486e-05, 9.9923e-01,\n", " 6.3526e-06, 1.7522e-01, 8.8239e-01, 2.0713e-11, 2.2530e-20, 2.1032e-05],\n", " device='cuda:0', grad_fn=)\n", "tensor(72, device='cuda:0')\n" ] } ], "source": [ "sentence = input(\"ENTER DESCRIPTION \")\n", "output = model(sentence)\n", "predicted = torch.argmax(output)\n", "print(output)\n", "print(predicted)" ] }, { "cell_type": "code", "execution_count": 59, "id": "bcf4856c", "metadata": {}, "outputs": [], "source": [ "torch.save(model.state_dict(), \"sentence_embedding.pth\")" ] }, { "cell_type": "code", "execution_count": null, "id": "5f6f22f4", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 5 }