Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .gitignore +3 -0
- .ipynb_checkpoints/FLOWER_GAN_64-checkpoint.ipynb +0 -0
- .ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
- .ipynb_checkpoints/VAE_64-checkpoint.ipynb +0 -0
- Conditional_Diffusion_64.ipynb +3 -0
- FLOWER_Conditional_GAN_64.ipynb +0 -0
- FLOWER_GAN_128.ipynb +3 -0
- FLOWER_GAN_64.ipynb +0 -0
- Model_Saved_States/CGAN_64_discriminator.pth +3 -0
- Model_Saved_States/CGAN_64_generator.pth +3 -0
- Model_Saved_States/GAN_128_discriminator.pth +3 -0
- Model_Saved_States/GAN_128_generator.pth +3 -0
- Model_Saved_States/conditional_diffusion_64.pth +3 -0
- Model_Saved_States/diffusion_64.pth +3 -0
- Model_Saved_States/sentence_embedding.pth +3 -0
- README.md +2 -8
- Sentence_Embeddings.ipynb +454 -0
- Unconditional_Diffusion_64.ipynb +0 -0
- Untitled.ipynb +263 -0
- VAE_64.ipynb +0 -0
- app.py +182 -0
- flagged/log.csv +2 -0
- image_desc.csv +0 -0
- jpg/flowers/image_00001.jpg +0 -0
- jpg/flowers/image_00002.jpg +0 -0
- jpg/flowers/image_00003.jpg +0 -0
- jpg/flowers/image_00004.jpg +0 -0
- jpg/flowers/image_00005.jpg +0 -0
- jpg/flowers/image_00006.jpg +0 -0
- jpg/flowers/image_00007.jpg +0 -0
- jpg/flowers/image_00008.jpg +0 -0
- jpg/flowers/image_00009.jpg +0 -0
- jpg/flowers/image_00010.jpg +0 -0
- jpg/flowers/image_00011.jpg +0 -0
- jpg/flowers/image_00012.jpg +0 -0
- jpg/flowers/image_00013.jpg +0 -0
- jpg/flowers/image_00014.jpg +0 -0
- jpg/flowers/image_00015.jpg +0 -0
- jpg/flowers/image_00016.jpg +0 -0
- jpg/flowers/image_00017.jpg +0 -0
- jpg/flowers/image_00018.jpg +0 -0
- jpg/flowers/image_00019.jpg +0 -0
- jpg/flowers/image_00020.jpg +0 -0
- jpg/flowers/image_00021.jpg +0 -0
- jpg/flowers/image_00022.jpg +0 -0
- jpg/flowers/image_00023.jpg +0 -0
- jpg/flowers/image_00024.jpg +0 -0
- jpg/flowers/image_00025.jpg +0 -0
- jpg/flowers/image_00026.jpg +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
Conditional_Diffusion_64.ipynb filter=lfs diff=lfs merge=lfs -text
|
37 |
+
FLOWER_GAN_128.ipynb filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
jpg/
|
2 |
+
Model_Saved_State/
|
3 |
+
image*
|
.ipynb_checkpoints/FLOWER_GAN_64-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [],
|
3 |
+
"metadata": {},
|
4 |
+
"nbformat": 4,
|
5 |
+
"nbformat_minor": 5
|
6 |
+
}
|
.ipynb_checkpoints/VAE_64-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Conditional_Diffusion_64.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5e03d955344baf2568ba3de198643ce2b3dc16402fc0b3ee309e869c4ced195a
|
3 |
+
size 17604999
|
FLOWER_Conditional_GAN_64.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
FLOWER_GAN_128.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8449a476099b632280ffb66c13414f45a98e196040895f870188a97b16ade374
|
3 |
+
size 133325097
|
FLOWER_GAN_64.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Model_Saved_States/CGAN_64_discriminator.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:43acdc673ec2922267e3e2b8b398629c94f2bdd8004faf4e47308b66c12fd8eb
|
3 |
+
size 16814806
|
Model_Saved_States/CGAN_64_generator.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:27dd926206fd0acdf0071d71cd5263f634ba7f9e778f5bd8ce20ab1a9ed51e5c
|
3 |
+
size 25469566
|
Model_Saved_States/GAN_128_discriminator.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:69607ce6d1e68b66a80907c32b6e0adaebae049da7a51d8447a85514cfdc9f58
|
3 |
+
size 17113810
|
Model_Saved_States/GAN_128_generator.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b9a804f21b9b4a63ef72ca7e11bfeb767d636403c238d36ee06d7b0127d2eccf
|
3 |
+
size 34145406
|
Model_Saved_States/conditional_diffusion_64.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4acb997e447b573366c76f441abafa6f56afd1dfedd0de288bc384a5d329b256
|
3 |
+
size 181400873
|
Model_Saved_States/diffusion_64.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21a862c67c65539ec93da68a5a278d1205c89d5d0c858bc9d3614aa2427ecf7d
|
3 |
+
size 89200917
|
Model_Saved_States/sentence_embedding.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:de5f9fa514895fa02d90200de434906018832a2725bb30e98f81c67ed57c2059
|
3 |
+
size 91405603
|
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: red
|
5 |
-
colorTo: purple
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.38.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: first_space
|
3 |
+
app_file: app.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 3.38.0
|
|
|
|
|
6 |
---
|
|
|
|
Sentence_Embeddings.ipynb
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "62c37427",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"from transformers import AutoModel, AutoTokenizer\n",
|
11 |
+
"from sentence_transformers import SentenceTransformer\n",
|
12 |
+
"import torch\n",
|
13 |
+
"import torch.nn as nn\n",
|
14 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
15 |
+
"import pandas as pd\n",
|
16 |
+
"import numpy as np\n",
|
17 |
+
"import torch.nn.functional as F\n",
|
18 |
+
"from tqdm import tqdm"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": 2,
|
24 |
+
"id": "ca8d35e3",
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [
|
27 |
+
{
|
28 |
+
"name": "stdout",
|
29 |
+
"output_type": "stream",
|
30 |
+
"text": [
|
31 |
+
"cuda\n"
|
32 |
+
]
|
33 |
+
}
|
34 |
+
],
|
35 |
+
"source": [
|
36 |
+
"sentence_model = SentenceTransformer('all-MiniLM-L6-v2')\n",
|
37 |
+
"\n",
|
38 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
39 |
+
"print(device)"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "code",
|
44 |
+
"execution_count": 3,
|
45 |
+
"id": "27935f40",
|
46 |
+
"metadata": {},
|
47 |
+
"outputs": [],
|
48 |
+
"source": [
|
49 |
+
"class ImageDataset(Dataset):\n",
|
50 |
+
" def __init__(self, csv_file, transform=None):\n",
|
51 |
+
" self.annotations = csv_file\n",
|
52 |
+
" self.transform=transform\n",
|
53 |
+
" \n",
|
54 |
+
" def __len__(self):\n",
|
55 |
+
" return len(self.annotations)\n",
|
56 |
+
" \n",
|
57 |
+
" def __getitem__(self,index):\n",
|
58 |
+
" img_desc = self.annotations.iloc[index, 2]\n",
|
59 |
+
"\n",
|
60 |
+
" label=torch.tensor(int(self.annotations.iloc[index, 3]))\n",
|
61 |
+
" \n",
|
62 |
+
" if self.transform:\n",
|
63 |
+
" img_desc = self.transform(img_desc)\n",
|
64 |
+
" \n",
|
65 |
+
" return (img_desc, label)"
|
66 |
+
]
|
67 |
+
},
|
68 |
+
{
|
69 |
+
"cell_type": "code",
|
70 |
+
"execution_count": 4,
|
71 |
+
"id": "d96cfab6",
|
72 |
+
"metadata": {},
|
73 |
+
"outputs": [
|
74 |
+
{
|
75 |
+
"name": "stdout",
|
76 |
+
"output_type": "stream",
|
77 |
+
"text": [
|
78 |
+
"81890\n"
|
79 |
+
]
|
80 |
+
}
|
81 |
+
],
|
82 |
+
"source": [
|
83 |
+
"df = pd.read_csv('image_desc.csv')\n",
|
84 |
+
"dataset = ImageDataset(df)\n",
|
85 |
+
"train_size = int(0.85 * len(dataset))\n",
|
86 |
+
"test_size = len(dataset) - train_size\n",
|
87 |
+
"train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])\n",
|
88 |
+
"print(len(dataset))"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": 12,
|
94 |
+
"id": "d12e4992",
|
95 |
+
"metadata": {},
|
96 |
+
"outputs": [],
|
97 |
+
"source": [
|
98 |
+
"batch_size=16\n",
|
99 |
+
"train_loader=DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
|
100 |
+
"test_loader=DataLoader(test_set, batch_size=batch_size, shuffle=True)"
|
101 |
+
]
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"cell_type": "code",
|
105 |
+
"execution_count": 54,
|
106 |
+
"id": "4e1f90e0",
|
107 |
+
"metadata": {},
|
108 |
+
"outputs": [],
|
109 |
+
"source": [
|
110 |
+
"class MyModel(nn.Module):\n",
|
111 |
+
" def __init__(self, sentence_model, hidden_dim, output_dim):\n",
|
112 |
+
" super(MyModel, self).__init__()\n",
|
113 |
+
" self.sentence_model = sentence_model\n",
|
114 |
+
" self.fc1 = nn.Linear(384, hidden_dim)\n",
|
115 |
+
" self.fc2 = nn.Linear(hidden_dim, output_dim)\n",
|
116 |
+
" self.sig = nn.Sigmoid()\n",
|
117 |
+
"\n",
|
118 |
+
" def forward(self, x):\n",
|
119 |
+
" sentence_embeddings = self.sentence_model.encode(x, convert_to_tensor=True)\n",
|
120 |
+
" sentence_embeddings = sentence_embeddings.to(device)\n",
|
121 |
+
" hidden = self.fc1(sentence_embeddings)\n",
|
122 |
+
" hidden = F.relu(hidden)\n",
|
123 |
+
" logits = self.fc2(hidden)\n",
|
124 |
+
"# logits = torch.clamp(logits, min=1e-5)\n",
|
125 |
+
" logits = self.sig(logits)\n",
|
126 |
+
" return logits\n",
|
127 |
+
"\n",
|
128 |
+
"output_dim = 102\n",
|
129 |
+
"hidden_dim = 256\n",
|
130 |
+
"\n",
|
131 |
+
"model = MyModel(sentence_model, hidden_dim, output_dim).to(device)"
|
132 |
+
]
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"cell_type": "code",
|
136 |
+
"execution_count": 37,
|
137 |
+
"id": "85c41c63",
|
138 |
+
"metadata": {},
|
139 |
+
"outputs": [
|
140 |
+
{
|
141 |
+
"name": "stderr",
|
142 |
+
"output_type": "stream",
|
143 |
+
"text": [
|
144 |
+
"100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:42<00:00, 42.36it/s]"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"name": "stdout",
|
149 |
+
"output_type": "stream",
|
150 |
+
"text": [
|
151 |
+
"tensor(1.0000, device='cuda:0', grad_fn=<SumBackward1>)\n"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"name": "stderr",
|
156 |
+
"output_type": "stream",
|
157 |
+
"text": [
|
158 |
+
"\n"
|
159 |
+
]
|
160 |
+
}
|
161 |
+
],
|
162 |
+
"source": [
|
163 |
+
"min = torch.tensor(1).to(device)\n",
|
164 |
+
"similarity = nn.CosineSimilarity(dim = 0)\n",
|
165 |
+
"for sample_batch, sample_label in tqdm(train_loader):\n",
|
166 |
+
" i = sample_batch[0]\n",
|
167 |
+
" j = sample_batch[1]\n",
|
168 |
+
" output_i = model(i)\n",
|
169 |
+
" output_j = model(j)\n",
|
170 |
+
" sim_i_j = similarity(output_i, output_j)\n",
|
171 |
+
" if sim_i_j < min:\n",
|
172 |
+
" min = sim_i_j\n",
|
173 |
+
" \n",
|
174 |
+
"print(min)"
|
175 |
+
]
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "code",
|
179 |
+
"execution_count": 55,
|
180 |
+
"id": "e99a5150",
|
181 |
+
"metadata": {},
|
182 |
+
"outputs": [],
|
183 |
+
"source": [
|
184 |
+
"criterion = nn.CrossEntropyLoss()\n",
|
185 |
+
"# criterion = nn.MSELoss()\n",
|
186 |
+
"optimizer = torch.optim.Adam(model.parameters(), lr=0.005)"
|
187 |
+
]
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"cell_type": "code",
|
191 |
+
"execution_count": 56,
|
192 |
+
"id": "7957341a",
|
193 |
+
"metadata": {},
|
194 |
+
"outputs": [
|
195 |
+
{
|
196 |
+
"name": "stderr",
|
197 |
+
"output_type": "stream",
|
198 |
+
"text": [
|
199 |
+
"100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:06<00:00, 65.25it/s]\n"
|
200 |
+
]
|
201 |
+
},
|
202 |
+
{
|
203 |
+
"name": "stdout",
|
204 |
+
"output_type": "stream",
|
205 |
+
"text": [
|
206 |
+
"Epoch: 1/4, Loss: 1116.4719812870026\n"
|
207 |
+
]
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"name": "stderr",
|
211 |
+
"output_type": "stream",
|
212 |
+
"text": [
|
213 |
+
"100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:06<00:00, 65.90it/s]\n"
|
214 |
+
]
|
215 |
+
},
|
216 |
+
{
|
217 |
+
"name": "stdout",
|
218 |
+
"output_type": "stream",
|
219 |
+
"text": [
|
220 |
+
"Epoch: 2/4, Loss: 1087.523635149002\n"
|
221 |
+
]
|
222 |
+
},
|
223 |
+
{
|
224 |
+
"name": "stderr",
|
225 |
+
"output_type": "stream",
|
226 |
+
"text": [
|
227 |
+
"100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:06<00:00, 65.20it/s]\n"
|
228 |
+
]
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"name": "stdout",
|
232 |
+
"output_type": "stream",
|
233 |
+
"text": [
|
234 |
+
"Epoch: 3/4, Loss: 1079.509438186884\n"
|
235 |
+
]
|
236 |
+
},
|
237 |
+
{
|
238 |
+
"name": "stderr",
|
239 |
+
"output_type": "stream",
|
240 |
+
"text": [
|
241 |
+
"100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:07<00:00, 64.31it/s]"
|
242 |
+
]
|
243 |
+
},
|
244 |
+
{
|
245 |
+
"name": "stdout",
|
246 |
+
"output_type": "stream",
|
247 |
+
"text": [
|
248 |
+
"Epoch: 4/4, Loss: 1074.7653084248304\n"
|
249 |
+
]
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"name": "stderr",
|
253 |
+
"output_type": "stream",
|
254 |
+
"text": [
|
255 |
+
"\n"
|
256 |
+
]
|
257 |
+
}
|
258 |
+
],
|
259 |
+
"source": [
|
260 |
+
"num_epochs = 4\n",
|
261 |
+
"for epoch in range(num_epochs):\n",
|
262 |
+
" model.train()\n",
|
263 |
+
" losses = []\n",
|
264 |
+
"\n",
|
265 |
+
" for i, (sentences_batch, labels_batch) in enumerate(tqdm(train_loader)):\n",
|
266 |
+
" labels_batch = labels_batch.to(device)\n",
|
267 |
+
" labels_batch = F.one_hot(labels_batch, num_classes = 102).float()\n",
|
268 |
+
" optimizer.zero_grad()\n",
|
269 |
+
" # Forward pass\n",
|
270 |
+
" logits = model(sentences_batch).float()\n",
|
271 |
+
" loss = criterion(logits, labels_batch)\n",
|
272 |
+
" \n",
|
273 |
+
" # Backward pass and optimization\n",
|
274 |
+
" loss.backward()\n",
|
275 |
+
" optimizer.step()\n",
|
276 |
+
" curr_loss = loss.item()\n",
|
277 |
+
" losses.append(curr_loss)\n",
|
278 |
+
" \n",
|
279 |
+
" running_loss = sum(losses)\n",
|
280 |
+
" \n",
|
281 |
+
" # Print the average loss for every epoch\n",
|
282 |
+
" epoch_loss = running_loss / batch_size\n",
|
283 |
+
" print(f\"Epoch: {epoch+1}/{num_epochs}, Loss: {epoch_loss}\")\n"
|
284 |
+
]
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"cell_type": "code",
|
288 |
+
"execution_count": 47,
|
289 |
+
"id": "4ecaab5d",
|
290 |
+
"metadata": {},
|
291 |
+
"outputs": [
|
292 |
+
{
|
293 |
+
"name": "stderr",
|
294 |
+
"output_type": "stream",
|
295 |
+
"text": [
|
296 |
+
"100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:33<00:00, 46.66it/s]"
|
297 |
+
]
|
298 |
+
},
|
299 |
+
{
|
300 |
+
"name": "stdout",
|
301 |
+
"output_type": "stream",
|
302 |
+
"text": [
|
303 |
+
"tensor(0., device='cuda:0', grad_fn=<SumBackward1>)\n"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"name": "stderr",
|
308 |
+
"output_type": "stream",
|
309 |
+
"text": [
|
310 |
+
"\n"
|
311 |
+
]
|
312 |
+
}
|
313 |
+
],
|
314 |
+
"source": [
|
315 |
+
"for sample_batch, sample_label in tqdm(train_loader):\n",
|
316 |
+
" i = sample_batch[0]\n",
|
317 |
+
" j = sample_batch[1]\n",
|
318 |
+
" output_i = model(i)\n",
|
319 |
+
" output_j = model(j)\n",
|
320 |
+
" sim_i_j = similarity(output_i, output_j)\n",
|
321 |
+
" if sim_i_j < min:\n",
|
322 |
+
" min = sim_i_j\n",
|
323 |
+
" \n",
|
324 |
+
"print(min)"
|
325 |
+
]
|
326 |
+
},
|
327 |
+
{
|
328 |
+
"cell_type": "code",
|
329 |
+
"execution_count": 57,
|
330 |
+
"id": "a0d95f76",
|
331 |
+
"metadata": {},
|
332 |
+
"outputs": [
|
333 |
+
{
|
334 |
+
"name": "stderr",
|
335 |
+
"output_type": "stream",
|
336 |
+
"text": [
|
337 |
+
"100%|██████████████████████████████████████████████████████████████████████████████| 4351/4351 [01:04<00:00, 67.76it/s]"
|
338 |
+
]
|
339 |
+
},
|
340 |
+
{
|
341 |
+
"name": "stdout",
|
342 |
+
"output_type": "stream",
|
343 |
+
"text": [
|
344 |
+
"Accuracy: 0.2659109846852283\n"
|
345 |
+
]
|
346 |
+
},
|
347 |
+
{
|
348 |
+
"name": "stderr",
|
349 |
+
"output_type": "stream",
|
350 |
+
"text": [
|
351 |
+
"\n"
|
352 |
+
]
|
353 |
+
}
|
354 |
+
],
|
355 |
+
"source": [
|
356 |
+
"model.eval()\n",
|
357 |
+
"total_correct = 0\n",
|
358 |
+
"total_samples = 0\n",
|
359 |
+
"\n",
|
360 |
+
"with torch.no_grad():\n",
|
361 |
+
" for i, (sentences_batch, labels_batch) in enumerate(tqdm(train_loader)):\n",
|
362 |
+
" labels_batch = labels_batch.to(device)\n",
|
363 |
+
"# labels_batch = F.one_hot(labels_batch, num_classes = 102).float()\n",
|
364 |
+
" logits = model(sentences_batch).float()\n",
|
365 |
+
" predicted = torch.argmax(logits, dim = 1)\n",
|
366 |
+
" total_samples += labels_batch.size(0)\n",
|
367 |
+
" total_correct += (predicted == labels_batch).sum().item()\n",
|
368 |
+
"\n",
|
369 |
+
"accuracy = total_correct / total_samples\n",
|
370 |
+
"print(\"Accuracy:\", accuracy)"
|
371 |
+
]
|
372 |
+
},
|
373 |
+
{
|
374 |
+
"cell_type": "code",
|
375 |
+
"execution_count": 58,
|
376 |
+
"id": "da1763b7",
|
377 |
+
"metadata": {},
|
378 |
+
"outputs": [
|
379 |
+
{
|
380 |
+
"name": "stdout",
|
381 |
+
"output_type": "stream",
|
382 |
+
"text": [
|
383 |
+
"ENTER DESCRIPTION pink\n",
|
384 |
+
"tensor([9.6708e-15, 1.0179e-04, 4.2242e-08, 1.3063e-15, 8.8056e-03, 0.0000e+00,\n",
|
385 |
+
" 1.6553e-14, 8.9271e-33, 9.0644e-27, 5.9910e-19, 2.2721e-24, 7.7432e-03,\n",
|
386 |
+
" 3.9587e-36, 7.1618e-07, 2.7430e-08, 0.0000e+00, 0.0000e+00, 1.4562e-03,\n",
|
387 |
+
" 9.8114e-06, 9.2844e-24, 7.8520e-33, 2.9296e-22, 3.5067e-13, 1.3316e-05,\n",
|
388 |
+
" 7.7768e-11, 9.2201e-39, 5.0639e-22, 1.6904e-19, 3.2689e-35, 1.0034e-14,\n",
|
389 |
+
" 9.8686e-01, 4.1330e-05, 6.3048e-01, 9.5960e-23, 1.2662e-14, 2.4540e-22,\n",
|
390 |
+
" 1.4413e-08, 9.9928e-01, 2.8299e-02, 4.9763e-10, 2.7364e-04, 9.9878e-01,\n",
|
391 |
+
" 0.0000e+00, 9.9998e-01, 6.7328e-02, 2.9939e-13, 1.9145e-17, 0.0000e+00,\n",
|
392 |
+
" 0.0000e+00, 0.0000e+00, 0.0000e+00, 9.9998e-01, 1.1818e-30, 2.2513e-22,\n",
|
393 |
+
" 0.0000e+00, 1.0346e-32, 8.8656e-21, 9.9353e-01, 4.3037e-03, 8.6023e-39,\n",
|
394 |
+
" 3.6964e-10, 3.3164e-21, 1.9611e-15, 0.0000e+00, 3.7135e-38, 1.3163e-34,\n",
|
395 |
+
" 1.8906e-07, 7.0084e-30, 1.0882e-20, 2.6501e-33, 8.9597e-39, 5.0791e-37,\n",
|
396 |
+
" 1.0000e+00, 5.7929e-03, 1.3252e-03, 1.4498e-23, 1.3656e-02, 2.0226e-07,\n",
|
397 |
+
" 8.3005e-01, 8.4326e-14, 2.1941e-03, 3.8749e-28, 9.8803e-01, 9.9992e-01,\n",
|
398 |
+
" 4.3195e-11, 7.0360e-01, 1.0000e+00, 1.5408e-02, 9.9689e-01, 8.0569e-15,\n",
|
399 |
+
" 1.4282e-22, 9.6706e-03, 4.9712e-03, 4.8348e-05, 1.2486e-05, 9.9923e-01,\n",
|
400 |
+
" 6.3526e-06, 1.7522e-01, 8.8239e-01, 2.0713e-11, 2.2530e-20, 2.1032e-05],\n",
|
401 |
+
" device='cuda:0', grad_fn=<SigmoidBackward0>)\n",
|
402 |
+
"tensor(72, device='cuda:0')\n"
|
403 |
+
]
|
404 |
+
}
|
405 |
+
],
|
406 |
+
"source": [
|
407 |
+
"sentence = input(\"ENTER DESCRIPTION \")\n",
|
408 |
+
"output = model(sentence)\n",
|
409 |
+
"predicted = torch.argmax(output)\n",
|
410 |
+
"print(output)\n",
|
411 |
+
"print(predicted)"
|
412 |
+
]
|
413 |
+
},
|
414 |
+
{
|
415 |
+
"cell_type": "code",
|
416 |
+
"execution_count": 59,
|
417 |
+
"id": "bcf4856c",
|
418 |
+
"metadata": {},
|
419 |
+
"outputs": [],
|
420 |
+
"source": [
|
421 |
+
"torch.save(model.state_dict(), \"sentence_embedding.pth\")"
|
422 |
+
]
|
423 |
+
},
|
424 |
+
{
|
425 |
+
"cell_type": "code",
|
426 |
+
"execution_count": null,
|
427 |
+
"id": "5f6f22f4",
|
428 |
+
"metadata": {},
|
429 |
+
"outputs": [],
|
430 |
+
"source": []
|
431 |
+
}
|
432 |
+
],
|
433 |
+
"metadata": {
|
434 |
+
"kernelspec": {
|
435 |
+
"display_name": "Python 3 (ipykernel)",
|
436 |
+
"language": "python",
|
437 |
+
"name": "python3"
|
438 |
+
},
|
439 |
+
"language_info": {
|
440 |
+
"codemirror_mode": {
|
441 |
+
"name": "ipython",
|
442 |
+
"version": 3
|
443 |
+
},
|
444 |
+
"file_extension": ".py",
|
445 |
+
"mimetype": "text/x-python",
|
446 |
+
"name": "python",
|
447 |
+
"nbconvert_exporter": "python",
|
448 |
+
"pygments_lexer": "ipython3",
|
449 |
+
"version": "3.10.4"
|
450 |
+
}
|
451 |
+
},
|
452 |
+
"nbformat": 4,
|
453 |
+
"nbformat_minor": 5
|
454 |
+
}
|
Unconditional_Diffusion_64.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Untitled.ipynb
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "68fece49",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import gradio as gr\n",
|
11 |
+
"import torch\n",
|
12 |
+
"import torch.nn as nn\n",
|
13 |
+
"import torch.nn.functional as F\n",
|
14 |
+
"import matplotlib.pyplot as plt\n",
|
15 |
+
"\n",
|
16 |
+
"class DoubleConv(nn.Module):\n",
|
17 |
+
" def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):\n",
|
18 |
+
" super().__init__()\n",
|
19 |
+
" self.residual = residual\n",
|
20 |
+
" if not mid_channels:\n",
|
21 |
+
" mid_channels = out_channels\n",
|
22 |
+
" self.double_conv = nn.Sequential(\n",
|
23 |
+
" nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),\n",
|
24 |
+
" nn.GroupNorm(1, mid_channels),\n",
|
25 |
+
" nn.GELU(),\n",
|
26 |
+
" nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),\n",
|
27 |
+
" nn.GroupNorm(1, out_channels),\n",
|
28 |
+
" )\n",
|
29 |
+
"\n",
|
30 |
+
" def forward(self, x):\n",
|
31 |
+
" if self.residual:\n",
|
32 |
+
" return F.gelu(x + self.double_conv(x))\n",
|
33 |
+
" else:\n",
|
34 |
+
" return self.double_conv(x)\n",
|
35 |
+
"\n",
|
36 |
+
"class Down(nn.Module):\n",
|
37 |
+
" def __init__(self, in_channels, out_channels, emb_dim=256):\n",
|
38 |
+
" super().__init__()\n",
|
39 |
+
" self.maxpool_conv = nn.Sequential(\n",
|
40 |
+
" nn.MaxPool2d(2),\n",
|
41 |
+
" DoubleConv(in_channels, in_channels, residual=True),\n",
|
42 |
+
" DoubleConv(in_channels, out_channels),\n",
|
43 |
+
" )\n",
|
44 |
+
"\n",
|
45 |
+
" self.emb_layer = nn.Sequential(\n",
|
46 |
+
" nn.SiLU(),\n",
|
47 |
+
" nn.Linear(\n",
|
48 |
+
" emb_dim,\n",
|
49 |
+
" out_channels\n",
|
50 |
+
" ),\n",
|
51 |
+
" )\n",
|
52 |
+
"\n",
|
53 |
+
" def forward(self, x, t):\n",
|
54 |
+
" x = self.maxpool_conv(x)\n",
|
55 |
+
" emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])\n",
|
56 |
+
" return x + emb\n",
|
57 |
+
"\n",
|
58 |
+
"class Up(nn.Module):\n",
|
59 |
+
" def __init__(self, in_channels, out_channels, emb_dim=256):\n",
|
60 |
+
" super().__init__()\n",
|
61 |
+
"\n",
|
62 |
+
" self.up = nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=True)\n",
|
63 |
+
" self.conv = nn.Sequential(\n",
|
64 |
+
" DoubleConv(in_channels, in_channels, residual=True),\n",
|
65 |
+
" DoubleConv(in_channels, out_channels, in_channels // 2),\n",
|
66 |
+
" )\n",
|
67 |
+
"\n",
|
68 |
+
" self.emb_layer = nn.Sequential(\n",
|
69 |
+
" nn.SiLU(),\n",
|
70 |
+
" nn.Linear(\n",
|
71 |
+
" emb_dim,\n",
|
72 |
+
" out_channels\n",
|
73 |
+
" ),\n",
|
74 |
+
" )\n",
|
75 |
+
"\n",
|
76 |
+
" def forward(self, x, skip_x, t):\n",
|
77 |
+
" x = self.up(x)\n",
|
78 |
+
" x = torch.cat([skip_x, x], dim=1)\n",
|
79 |
+
" x = self.conv(x)\n",
|
80 |
+
" emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])\n",
|
81 |
+
" return x + emb\n",
|
82 |
+
"\n",
|
83 |
+
"class UNet(nn.Module):\n",
|
84 |
+
" def __init__(self, c_in=3, c_out=3, time_dim=256, device=\"cuda\"):\n",
|
85 |
+
" super().__init__()\n",
|
86 |
+
" self.device = device\n",
|
87 |
+
" self.time_dim = time_dim\n",
|
88 |
+
"\n",
|
89 |
+
" self.inc = DoubleConv(c_in, 64)\n",
|
90 |
+
" self.down1 = Down(64, 128)\n",
|
91 |
+
" self.down2 = Down(128, 256)\n",
|
92 |
+
" self.down3 = Down(256, 256)\n",
|
93 |
+
"\n",
|
94 |
+
" self.bot1 = DoubleConv(256, 512)\n",
|
95 |
+
" self.bot2 = DoubleConv(512, 512)\n",
|
96 |
+
" self.bot3 = DoubleConv(512, 256)\n",
|
97 |
+
"\n",
|
98 |
+
" self.up1 = Up(512, 128)\n",
|
99 |
+
" self.up2 = Up(256, 64)\n",
|
100 |
+
" self.up3 = Up(128, 64)\n",
|
101 |
+
" self.outc = nn.Conv2d(64, c_out, kernel_size=1)\n",
|
102 |
+
"\n",
|
103 |
+
" def positional_encoding(self, t, channels):\n",
|
104 |
+
" inv_freq = 1.0 / (\n",
|
105 |
+
" 10000\n",
|
106 |
+
" ** (torch.arange(0, channels, 2, device=self.device).float() / channels)\n",
|
107 |
+
" )\n",
|
108 |
+
" pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)\n",
|
109 |
+
" pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)\n",
|
110 |
+
" pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)\n",
|
111 |
+
" return pos_enc\n",
|
112 |
+
"\n",
|
113 |
+
" def forward(self, image, t):\n",
|
114 |
+
" t = t.unsqueeze(-1).type(torch.float)\n",
|
115 |
+
" t = self.positional_encoding(t, self.time_dim)\n",
|
116 |
+
"\n",
|
117 |
+
" x1 = self.inc(image)\n",
|
118 |
+
" x2 = self.down1(x1, t)\n",
|
119 |
+
" x3 = self.down2(x2, t)\n",
|
120 |
+
" x4 = self.down3(x3, t)\n",
|
121 |
+
"\n",
|
122 |
+
" x4 = self.bot1(x4)\n",
|
123 |
+
" # x4 = self.bot2(x4)\n",
|
124 |
+
" x4 = self.bot3(x4)\n",
|
125 |
+
"\n",
|
126 |
+
" x = self.up1(x4, x3, t)\n",
|
127 |
+
" x = self.up2(x, x2, t)\n",
|
128 |
+
" x = self.up3(x, x1, t)\n",
|
129 |
+
" output = self.outc(x)\n",
|
130 |
+
" return output\n",
|
131 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
132 |
+
"model = UNet(device = device).to(device)\n",
|
133 |
+
"model.load_state_dict(torch.load('Model_Saved_States/diffusion_64.pth'))\n",
|
134 |
+
"img_size = 64\n",
|
135 |
+
"class Diffusion():\n",
|
136 |
+
" def __init__(self, time_steps = 500, beta_start = 0.0001, beta_stop = 0.02, image_size = 64, device = device):\n",
|
137 |
+
" self.time_steps = time_steps\n",
|
138 |
+
" self.beta_start = beta_start\n",
|
139 |
+
" self.beta_stop = beta_stop\n",
|
140 |
+
" self.img_size = image_size\n",
|
141 |
+
" self.device = device\n",
|
142 |
+
"\n",
|
143 |
+
" self.beta = self.beta_schedule()\n",
|
144 |
+
" self.beta = self.beta.to(device)\n",
|
145 |
+
" self.alpha = 1 - self.beta\n",
|
146 |
+
" self.alpha = self.alpha.to(device)\n",
|
147 |
+
" self.alpha_hat = torch.cumprod(self.alpha, dim = 0).to(device)\n",
|
148 |
+
"\n",
|
149 |
+
"\n",
|
150 |
+
" def beta_schedule(self):\n",
|
151 |
+
" return torch.linspace(self.beta_start, self.beta_stop, self.time_steps)\n",
|
152 |
+
"\n",
|
153 |
+
" def noise_images(self, images, t):\n",
|
154 |
+
" sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None,]\n",
|
155 |
+
" sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None,]\n",
|
156 |
+
" noises = torch.randn_like(images)\n",
|
157 |
+
" noised_images = sqrt_alpha_hat * images + sqrt_one_minus_alpha_hat * noises\n",
|
158 |
+
" return noised_images, noises\n",
|
159 |
+
"\n",
|
160 |
+
" def random_timesteps(self, n):\n",
|
161 |
+
" return torch.randint(low=1, high=self.time_steps, size=(n,))\n",
|
162 |
+
"\n",
|
163 |
+
" def generate_samples(self, model, n):\n",
|
164 |
+
" with torch.no_grad():\n",
|
165 |
+
" x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)\n",
|
166 |
+
" for i in range(self.time_steps - 1, 1, -1):\n",
|
167 |
+
" t = (torch.ones(n) * i).long().to(self.device)\n",
|
168 |
+
" predicted_noise = model(x, t)\n",
|
169 |
+
" alpha = self.alpha[t][:, None, None, None]\n",
|
170 |
+
" alpha_hat = self.alpha_hat[t][:, None, None, None]\n",
|
171 |
+
" beta = self.beta[t][:, None, None, None]\n",
|
172 |
+
" if i > 1:\n",
|
173 |
+
" noise = torch.randn_like(x)\n",
|
174 |
+
" else:\n",
|
175 |
+
" noise = torch.zeros_like(x)\n",
|
176 |
+
" x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise\n",
|
177 |
+
"\n",
|
178 |
+
" return (x[0].cpu().numpy().transpose(1, 2, 0) / 255)\n",
|
179 |
+
" #show_images\n",
|
180 |
+
"\n",
|
181 |
+
"diffusion = Diffusion()\n"
|
182 |
+
]
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"cell_type": "code",
|
186 |
+
"execution_count": 26,
|
187 |
+
"id": "a80516cd",
|
188 |
+
"metadata": {},
|
189 |
+
"outputs": [
|
190 |
+
{
|
191 |
+
"name": "stdout",
|
192 |
+
"output_type": "stream",
|
193 |
+
"text": [
|
194 |
+
"Running on local URL: http://127.0.0.1:7867\n",
|
195 |
+
"Running on public URL: https://080248f8c7c14eec1e.gradio.live\n",
|
196 |
+
"\n",
|
197 |
+
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"data": {
|
202 |
+
"text/html": [
|
203 |
+
"<div><iframe src=\"https://080248f8c7c14eec1e.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
204 |
+
],
|
205 |
+
"text/plain": [
|
206 |
+
"<IPython.core.display.HTML object>"
|
207 |
+
]
|
208 |
+
},
|
209 |
+
"metadata": {},
|
210 |
+
"output_type": "display_data"
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"data": {
|
214 |
+
"text/plain": []
|
215 |
+
},
|
216 |
+
"execution_count": 26,
|
217 |
+
"metadata": {},
|
218 |
+
"output_type": "execute_result"
|
219 |
+
}
|
220 |
+
],
|
221 |
+
"source": [
|
222 |
+
"import numpy as np\n",
|
223 |
+
"def greet(n):\n",
|
224 |
+
" image = diffusion.generate_samples(model, n = 1)\n",
|
225 |
+
" image = (np.clip(image * 255, -1, 1) + 1) / 2\n",
|
226 |
+
" plt.imshow(image)\n",
|
227 |
+
" return image\n",
|
228 |
+
"\n",
|
229 |
+
"iface = gr.Interface(fn=greet, inputs=\"number\", outputs=\"image\")\n",
|
230 |
+
"iface.launch(share = True)"
|
231 |
+
]
|
232 |
+
},
|
233 |
+
{
|
234 |
+
"cell_type": "code",
|
235 |
+
"execution_count": null,
|
236 |
+
"id": "cc6f5064",
|
237 |
+
"metadata": {},
|
238 |
+
"outputs": [],
|
239 |
+
"source": []
|
240 |
+
}
|
241 |
+
],
|
242 |
+
"metadata": {
|
243 |
+
"kernelspec": {
|
244 |
+
"display_name": "Python 3 (ipykernel)",
|
245 |
+
"language": "python",
|
246 |
+
"name": "python3"
|
247 |
+
},
|
248 |
+
"language_info": {
|
249 |
+
"codemirror_mode": {
|
250 |
+
"name": "ipython",
|
251 |
+
"version": 3
|
252 |
+
},
|
253 |
+
"file_extension": ".py",
|
254 |
+
"mimetype": "text/x-python",
|
255 |
+
"name": "python",
|
256 |
+
"nbconvert_exporter": "python",
|
257 |
+
"pygments_lexer": "ipython3",
|
258 |
+
"version": "3.10.4"
|
259 |
+
}
|
260 |
+
},
|
261 |
+
"nbformat": 4,
|
262 |
+
"nbformat_minor": 5
|
263 |
+
}
|
VAE_64.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
app.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by a single-group GroupNorm.

    A GELU sits between the two conv/norm pairs. With ``residual=True`` the
    input is added to the result and a final GELU is applied.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the second convolution.
        mid_channels: Channel count between the two convolutions; a falsy
            value (``None`` or ``0``) falls back to ``out_channels``.
        residual: When true, ``forward`` returns ``gelu(x + block(x))``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        hidden = mid_channels or out_channels
        # The attribute name and layer order define the checkpoint keys
        # (``double_conv.0.weight`` ...), so they must stay as they are.
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden),
            nn.GELU(),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run the conv stack on ``x``, adding the skip path when residual."""
        out = self.double_conv(x)
        if not self.residual:
            return out
        return F.gelu(x + out)
|
26 |
+
|
27 |
+
class Down(nn.Module):
    """Downsampling stage: 2x max-pool, two DoubleConv blocks, plus an embedding.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count of the returned feature map.
        emb_dim: Width of the embedding vector passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )

        # Projects the embedding to one value per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        """Pool and convolve ``x``, then add the projected embedding ``t``.

        ``t`` is projected to ``out_channels`` values per sample and added to
        every spatial position of the feature map.
        """
        x = self.maxpool_conv(x)
        # Trailing singleton axes let broadcasting spread the per-channel bias
        # over height and width without materialising a repeated copy.
        emb = self.emb_layer(t)[:, :, None, None]
        return x + emb
|
48 |
+
|
49 |
+
class Up(nn.Module):
    """Upsampling stage: 2x bilinear upsample, skip concat, two DoubleConvs.

    Args:
        in_channels: Channel count after concatenating the skip tensor with
            the upsampled input.
        out_channels: Channel count of the returned feature map.
        emb_dim: Width of the embedding vector passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )

        # Projects the embedding to one value per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        """Upsample ``x``, merge it with ``skip_x`` and add the embedding ``t``.

        ``skip_x`` is concatenated in front of the upsampled ``x`` along the
        channel axis before the convolutions run.
        """
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        # Trailing singleton axes let broadcasting spread the per-channel bias
        # over height and width without materialising a repeated copy.
        emb = self.emb_layer(t)[:, :, None, None]
        return x + emb
|
73 |
+
|
74 |
+
class UNet(nn.Module):
    """U-Net that maps an image batch and a timestep batch to an image-shaped output.

    Args:
        c_in: Channel count of the input image.
        c_out: Channel count of the output.
        time_dim: Width of the sinusoidal timestep encoding.
        device: Stored on the instance for callers that read ``self.device``;
            the module itself is moved with ``.to(...)`` by the caller.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 256)

        self.bot1 = DoubleConv(256, 512)
        # bot2 is not called in forward(); it stays registered so that
        # checkpoints saved with it still load under strict key matching.
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        self.up1 = Up(512, 128)
        self.up2 = Up(256, 64)
        self.up3 = Up(128, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def positional_encoding(self, t, channels):
        """Return the sinusoidal encoding of ``t`` with ``channels`` features.

        ``t`` must have a trailing axis of size 1 (``forward`` passes a
        ``(batch, 1)`` float tensor). The result concatenates ``channels // 2``
        sine terms followed by ``channels // 2`` cosine terms.
        """
        # Build the frequencies on the input's device rather than on
        # self.device, so the encoding works wherever the timesteps live.
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        # (batch, 1) * (channels // 2,) broadcasts to (batch, channels // 2).
        angles = t * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, image, t):
        """Run the network on ``image`` conditioned on the timestep batch ``t``."""
        t = t.unsqueeze(-1).type(torch.float)
        t = self.positional_encoding(t, self.time_dim)

        # Encoder: keep each stage's output for the skip connections.
        x1 = self.inc(image)
        x2 = self.down1(x1, t)
        x3 = self.down2(x2, t)
        x4 = self.down3(x3, t)

        # Bottleneck (bot2 is deliberately not applied).
        x4 = self.bot1(x4)
        x4 = self.bot3(x4)

        # Decoder: each stage consumes the matching encoder output.
        x = self.up1(x4, x3, t)
        x = self.up2(x, x2, t)
        x = self.up3(x, x1, t)
        output = self.outc(x)
        return output
|
122 |
+
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = UNet(device = device).to(device)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written from a CUDA session still loads on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(
    torch.load('Model_Saved_States/diffusion_64.pth', map_location=device)
)
img_size = 64
|
126 |
+
class Diffusion():
    """Linear-beta noise schedule with forward noising and iterative sampling.

    Args:
        time_steps: Number of entries in the beta schedule.
        beta_start: First value of the linear beta schedule.
        beta_stop: Last value of the linear beta schedule.
        image_size: Side length of the square images that are sampled.
        device: Device for the schedule tensors and sampling; ``None`` picks
            ``'cuda'`` when available and ``'cpu'`` otherwise.
    """

    def __init__(self, time_steps = 500, beta_start = 0.0001, beta_stop = 0.02, image_size = 64, device = None):
        if device is None:
            # Resolved at call time instead of relying on a module global
            # that must exist when the class is defined.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.time_steps = time_steps
        self.beta_start = beta_start
        self.beta_stop = beta_stop
        self.img_size = image_size
        self.device = device

        self.beta = self.beta_schedule().to(device)
        self.alpha = (1 - self.beta).to(device)
        # Running product of alpha.
        self.alpha_hat = torch.cumprod(self.alpha, dim = 0).to(device)

    def beta_schedule(self):
        """Return ``time_steps`` betas evenly spaced from beta_start to beta_stop."""
        return torch.linspace(self.beta_start, self.beta_stop, self.time_steps)

    def noise_images(self, images, t):
        """Noise ``images`` to the timesteps ``t``; return ``(noised, noise)``.

        ``t`` indexes the schedule with one entry per image; the coefficients
        are reshaped to broadcast over the three trailing image axes.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        noises = torch.randn_like(images)
        noised_images = sqrt_alpha_hat * images + sqrt_one_minus_alpha_hat * noises
        return noised_images, noises

    def random_timesteps(self, n):
        """Return ``n`` integer timesteps drawn uniformly from [1, time_steps)."""
        return torch.randint(low=1, high=self.time_steps, size=(n,))

    def generate_samples(self, model, n):
        """Sample ``n`` images with ``model`` and return the first one.

        ``model(x, t)`` is called once per timestep and its output is used as
        the predicted noise. The return value is the first sample as a
        ``(height, width, 3)`` NumPy array divided by 255; the remaining
        ``n - 1`` samples are discarded.
        """
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            # Count down to t = 1 inclusive. Stopping at 2 would skip the last
            # denoising step and make the noise-free branch below unreachable.
            for i in range(self.time_steps - 1, 0, -1):
                t = torch.full((n,), i, dtype=torch.long, device=self.device)
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # No fresh noise is injected on the final step.
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise

        return (x[0].cpu().numpy().transpose(1, 2, 0) / 255)
|
170 |
+
#show_images
|
171 |
+
|
172 |
+
diffusion = Diffusion()

import numpy as np


def greet(n):
    """Gradio handler: sample one image and return it as floats in [0, 1].

    ``n`` is the number supplied by the UI; it is not used, and exactly one
    image is generated per call.
    """
    image = diffusion.generate_samples(model, n = 1)
    # generate_samples divides by 255, so multiply back, clamp to [-1, 1]
    # and rescale to [0, 1].
    image = (np.clip(image * 255, -1, 1) + 1) / 2
    # No matplotlib call here: drawing on the implicit global figure for each
    # request only accumulates artists and is never displayed.
    return image


iface = gr.Interface(fn=greet, inputs="number", outputs="image")
iface.launch(share = True)
|
flagged/log.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
n,output,flag,username,timestamp
|
2 |
+
6,,,,2023-07-25 15:54:19.847888
|
image_desc.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
jpg/flowers/image_00001.jpg
ADDED
jpg/flowers/image_00002.jpg
ADDED
jpg/flowers/image_00003.jpg
ADDED
jpg/flowers/image_00004.jpg
ADDED
jpg/flowers/image_00005.jpg
ADDED
jpg/flowers/image_00006.jpg
ADDED
jpg/flowers/image_00007.jpg
ADDED
jpg/flowers/image_00008.jpg
ADDED
jpg/flowers/image_00009.jpg
ADDED
jpg/flowers/image_00010.jpg
ADDED
jpg/flowers/image_00011.jpg
ADDED
jpg/flowers/image_00012.jpg
ADDED
jpg/flowers/image_00013.jpg
ADDED
jpg/flowers/image_00014.jpg
ADDED
jpg/flowers/image_00015.jpg
ADDED
jpg/flowers/image_00016.jpg
ADDED
jpg/flowers/image_00017.jpg
ADDED
jpg/flowers/image_00018.jpg
ADDED
jpg/flowers/image_00019.jpg
ADDED
jpg/flowers/image_00020.jpg
ADDED
jpg/flowers/image_00021.jpg
ADDED
jpg/flowers/image_00022.jpg
ADDED
jpg/flowers/image_00023.jpg
ADDED
jpg/flowers/image_00024.jpg
ADDED
jpg/flowers/image_00025.jpg
ADDED
jpg/flowers/image_00026.jpg
ADDED