lovelyai999 commited on
Commit
55c9f00
·
verified ·
1 Parent(s): 53c5b5f

Upload 2 files

Browse files
Files changed (2) hide show
  1. imageAI.py +319 -0
  2. myImage.py +34 -0
imageAI.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ import google.colab
3
+ IN_COLAB = True
4
+ from google.colab import drive,files
5
+ from google.colab import output
6
+ drive.mount('/gdrive')
7
+ Gbase="/gdrive/MyDrive/generate/"
8
+ cache_dir="/gdrive/MyDrive/hf/"
9
+ import sys
10
+ sys.path.append(Gbase)
11
+ except:
12
+ IN_COLAB = False
13
+ Gbase="./"
14
+ cache_dir="./hf/"
15
+
16
+
17
+ import cv2,os
18
+ import numpy as np
19
+ import random,string
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import torch.optim as optim
24
+ from torch.utils.data import Dataset, DataLoader
25
+
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ print(f"Using device: {device}")
28
+
29
+ IMAGE_SIZE = 64
30
+ NUM_SAMPLES = 1000
31
+ BATCH_SIZE = 4
32
+ EPOCHS = 500
33
+ LEARNING_RATE = 0.001
34
+
35
+
36
+ def generate_sample(num_shapes=1):
37
+ image = np.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=np.uint8)
38
+ instructions = []
39
+
40
+ #num_shapes = random.randint(1, 3)
41
+ for _ in range(num_shapes):
42
+ shape = random.choice(['line', 'rectangle', 'circle', 'ellipse', 'polygon'])
43
+ color = random.randint(0, 255)
44
+ thickness = random.randint(1, 3)
45
+
46
+ if shape == 'line':
47
+ start_point = (random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE))
48
+ end_point = (random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE))
49
+ cv2.line(image, start_point, end_point, color, thickness)
50
+ instructions.append(f"cv2.line(image, {start_point}, {end_point}, {color}, {thickness})")
51
+
52
+ elif shape == 'rectangle':
53
+ start_point = (random.randint(0, IMAGE_SIZE - 10), random.randint(0, IMAGE_SIZE - 10))
54
+ end_point = (start_point[0] + random.randint(10, IMAGE_SIZE - start_point[0]),
55
+ start_point[1] + random.randint(10, IMAGE_SIZE - start_point[1]))
56
+ cv2.rectangle(image, start_point, end_point, color, thickness)
57
+ instructions.append(f"cv2.rectangle(image, {start_point}, {end_point}, {color}, {thickness})")
58
+
59
+ elif shape == 'circle':
60
+ center = (random.randint(10, IMAGE_SIZE - 10), random.randint(10, IMAGE_SIZE - 10))
61
+ radius = random.randint(5, min(center[0], center[1], IMAGE_SIZE - center[0], IMAGE_SIZE - center[1]))
62
+ cv2.circle(image, center, radius, color, thickness)
63
+ instructions.append(f"cv2.circle(image, {center}, {radius}, {color}, {thickness})")
64
+
65
+ elif shape == 'ellipse':
66
+ center = (random.randint(10, IMAGE_SIZE - 10), random.randint(10, IMAGE_SIZE - 10))
67
+ axes = (random.randint(5, 30), random.randint(5, 30))
68
+ angle = random.randint(0, 360)
69
+ cv2.ellipse(image, center, axes, angle, 0, 360, color, thickness)
70
+ instructions.append(f"cv2.ellipse(image, {center}, {axes}, {angle}, 0, 360, {color}, {thickness})")
71
+
72
+ elif shape == 'polygon':
73
+ num_points = random.randint(3, 6)
74
+ points = np.array([(random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE)) for _ in range(num_points)], np.int32)
75
+ points = points.reshape((-1, 1, 2))
76
+ cv2.polylines(image, [points], True, color, thickness)
77
+ instructions.append(f"cv2.polylines(image, [{points.tolist()}], True, {color}, {thickness})")
78
+
79
+ return {'image': image, 'instructions': instructions}
80
+
81
+ def generate_dataset(NUM_SAMPLES=NUM_SAMPLES,maxNumShape=3):
82
+ dataset = []
83
+ for _ in range(NUM_SAMPLES):
84
+ num_shapes = random.randint(1, maxNumShape)
85
+ sample = generate_sample(num_shapes=num_shapes)
86
+ dataset.append(sample)
87
+ return dataset
88
+
89
+ class ImageDataset(Dataset):
90
+ def __init__(self, dataset):
91
+ self.dataset = dataset
92
+
93
+ def __len__(self):
94
+ return len(self.dataset)
95
+
96
+ def __getitem__(self, idx):
97
+ sample = self.dataset[idx]
98
+ image = torch.FloatTensor(sample['image']).unsqueeze(0) / 255.0
99
+ return image, len(sample['instructions'])
100
+
101
+ class SimpleModel(nn.Module):
102
+ def __init__(self, path=None):
103
+ super(SimpleModel, self).__init__()
104
+ self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
105
+ self.bn1 = nn.BatchNorm2d(32)
106
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
107
+ self.bn2 = nn.BatchNorm2d(64)
108
+ self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
109
+ self.bn3 = nn.BatchNorm2d(128)
110
+ self.pool = nn.MaxPool2d(2, 2)
111
+ self.fc1 = nn.Linear(128 * 8 * 8, 512)
112
+ self.fc2 = nn.Linear(512, 128)
113
+ self.fc3 = nn.Linear(128, 1)
114
+ self.dropout = nn.Dropout(0.5)
115
+
116
+ if path and os.path.exists(path):
117
+ self.load_state_dict(torch.load(path, map_location=device))
118
+
119
+ def forward(self, x):
120
+ x = self.pool(F.leaky_relu(self.bn1(self.conv1(x))))
121
+ x = self.pool(F.leaky_relu(self.bn2(self.conv2(x))))
122
+ x = self.pool(F.leaky_relu(self.bn3(self.conv3(x))))
123
+ x = x.view(-1, 128 * 8 * 8)
124
+ x = F.leaky_relu(self.fc1(x))
125
+ x = self.dropout(x)
126
+ x = F.leaky_relu(self.fc2(x))
127
+ x = self.dropout(x)
128
+ x = self.fc3(x)
129
+ return x
130
+
131
+ def predict(self, image):
132
+ self.eval()
133
+ with torch.no_grad():
134
+ if isinstance(image, str) and os.path.isfile(image):
135
+ # 如果輸入是圖片文件路徑
136
+ img = cv2.imread(image, cv2.IMREAD_GRAYSCALE)
137
+ img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
138
+ elif isinstance(image, np.ndarray):
139
+ # 如果輸入是 numpy 數組
140
+ if image.ndim == 3:
141
+ img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
142
+ else:
143
+ img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
144
+ img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
145
+ else:
146
+ raise ValueError("Input should be an image file path or a numpy array")
147
+
148
+ img_tensor = torch.FloatTensor(img).unsqueeze(0).unsqueeze(0) / 255.0
149
+ img_tensor = img_tensor.to(device)
150
+ output = self(img_tensor).item()
151
+
152
+ # 將輸出四捨五入到最接近的整數
153
+ num_instructions = round(output)
154
+
155
+ # 生成相應數量的繪圖指令
156
+ instructions = []
157
+ for _ in range(num_instructions):
158
+ shape = random.choice(['line', 'rectangle', 'circle', 'ellipse', 'polygon'])
159
+ if shape == 'line':
160
+ instructions.append(f"cv2.line(image, {(random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE))}, {(random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE))}, {random.randint(0, 255)}, {random.randint(1, 3)})")
161
+ elif shape == 'rectangle':
162
+ instructions.append(f"cv2.rectangle(image, {(random.randint(0, IMAGE_SIZE-10), random.randint(0, IMAGE_SIZE-10))}, {(random.randint(10, IMAGE_SIZE), random.randint(10, IMAGE_SIZE))}, {random.randint(0, 255)}, {random.randint(1, 3)})")
163
+ elif shape == 'circle':
164
+ instructions.append(f"cv2.circle(image, {(random.randint(10, IMAGE_SIZE-10), random.randint(10, IMAGE_SIZE-10))}, {random.randint(5, 30)}, {random.randint(0, 255)}, {random.randint(1, 3)})")
165
+ elif shape == 'ellipse':
166
+ instructions.append(f"cv2.ellipse(image, {(random.randint(10, IMAGE_SIZE-10), random.randint(10, IMAGE_SIZE-10))}, {(random.randint(5, 30), random.randint(5, 30))}, {random.randint(0, 360)}, 0, 360, {random.randint(0, 255)}, {random.randint(1, 3)})")
167
+ elif shape == 'polygon':
168
+ num_points = random.randint(3, 6)
169
+ points = [(random.randint(0, IMAGE_SIZE), random.randint(0, IMAGE_SIZE)) for _ in range(num_points)]
170
+ instructions.append(f"cv2.polylines(image, [np.array({points})], True, {random.randint(0, 255)}, {random.randint(1, 3)})")
171
+
172
+
173
+ return instructions
174
+
175
+ def train(model, train_loader, optimizer, criterion):
176
+ model.train()
177
+ total_loss = 0
178
+ for batch_idx, (data, target) in enumerate(train_loader):
179
+ data, target = data.to(device), target.float().to(device)
180
+ optimizer.zero_grad()
181
+ output = model(data).squeeze()
182
+ loss = criterion(output, target)
183
+ loss.backward()
184
+ optimizer.step()
185
+ total_loss += loss.item()
186
+ if batch_idx % 100 == 0:
187
+ print(f'Train Batch {batch_idx}/{len(train_loader)} Loss: {loss.item():.6f}')
188
+ return total_loss / len(train_loader)
189
+
190
+ def test(model, test_loader, criterion, print_predictions=False):
191
+ model.eval()
192
+ test_loss = 0
193
+ all_predictions = []
194
+ all_targets = []
195
+ with torch.no_grad():
196
+ for data, target in test_loader:
197
+ data, target = data.to(device), target.float().to(device)
198
+ output = model(data).squeeze()
199
+ test_loss += criterion(output, target).item()
200
+ all_predictions.extend(output.cpu().numpy())
201
+ all_targets.extend(target.cpu().numpy())
202
+
203
+ test_loss /= len(test_loader)
204
+ print(f'Test set: Average loss: {test_loss:.4f}')
205
+
206
+ if print_predictions:
207
+ print("Sample predictions:")
208
+ for pred, targ in zip(all_predictions[:10], all_targets[:10]):
209
+ print(f"Prediction: {pred:.2f}, Target: {targ:.2f}")
210
+
211
+ return test_loss, all_predictions, all_targets
212
+
213
+ def train1(NUM_SAMPLES=NUM_SAMPLES, maxNumShape=1, EPOCHS=EPOCHS):
214
+ model = SimpleModel(path=os.path.join(Gbase, 'best_model.pth')).to(device)
215
+ optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
216
+
217
+ optimizer_path = os.path.join(Gbase, 'optimizer.pth')
218
+ if os.path.exists(optimizer_path):
219
+ print("Loading optimizer state...")
220
+ optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))
221
+
222
+ criterion = nn.MSELoss()
223
+
224
+ seed = 618 * 382 * 33
225
+ random.seed(seed)
226
+ np.random.seed(seed)
227
+ torch.manual_seed(seed)
228
+ if torch.cuda.is_available():
229
+ torch.cuda.manual_seed(seed)
230
+
231
+ dataset = generate_dataset(NUM_SAMPLES=NUM_SAMPLES, maxNumShape=maxNumShape)
232
+ train_size = int(0.8 * len(dataset))
233
+ train_dataset = ImageDataset(dataset[:train_size])
234
+ test_dataset = ImageDataset(dataset[train_size:])
235
+
236
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
237
+ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
238
+
239
+ best_loss = float('inf')
240
+
241
+ for epoch in range(EPOCHS):
242
+ print(f'Epoch {epoch+1}/{EPOCHS}')
243
+ train_loss = train(model, train_loader, optimizer, criterion)
244
+ test_loss, predictions, targets = test(model, test_loader, criterion, print_predictions=True)
245
+
246
+ if test_loss < best_loss:
247
+ best_loss = test_loss
248
+ torch.save(model.state_dict(), os.path.join(Gbase, 'best_model.pth'))
249
+ torch.save(optimizer.state_dict(), os.path.join(Gbase, 'optimizer.pth'))
250
+ print(f"New best model saved with test loss: {best_loss:.4f}")
251
+
252
+
253
+
254
+
255
+ def main():
256
+ # Set random seed
257
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
258
+ model = SimpleModel(path=Gbase+ 'best_model.pth').to(device)
259
+ optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
260
+ if os.path.exists(Gbase+'optimizer.pth'):
261
+ print("Loading optimizer state...")
262
+ optimizer.load_state_dict(torch.load('optimizer.pth'))
263
+ criterion = nn.MSELoss()
264
+ test_image =Gbase+"image.jpg"
265
+ # np.random.randint(0, 256, (IMAGE_SIZE, IMAGE_SIZE), dtype=np.uint8)
266
+ instructions = model.predict(test_image)
267
+ print("Generated instructions:")
268
+ for instruction in instructions:
269
+ print(instruction)
270
+ # 檢查是否存在已保存的優化器狀態
271
+
272
+ #return
273
+ seed = 618 * 382 * 33
274
+ random.seed(seed)
275
+ np.random.seed(seed)
276
+ torch.manual_seed(seed)
277
+
278
+ # Generate dataset
279
+ dataset = generate_dataset()
280
+
281
+ # Split dataset into train and test
282
+ train_size = int(0.8 * len(dataset))
283
+ train_dataset = ImageDataset(dataset[:train_size])
284
+ test_dataset = ImageDataset(dataset[train_size:])
285
+
286
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
287
+ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
288
+
289
+
290
+
291
+ best_loss = float('inf')
292
+
293
+ for epoch in range(EPOCHS):
294
+ print(f'Epoch {epoch+1}/{EPOCHS}')
295
+ train_loss = train(model, train_loader, optimizer, criterion, device)
296
+ test_loss, predictions, targets = test(model, test_loader, criterion, device, print_predictions=True)
297
+
298
+ if test_loss < best_loss:
299
+ best_loss = test_loss
300
+ torch.save(model.state_dict(),Gbase+ 'best_model.pth')
301
+ torch.save(optimizer.state_dict(),Gbase+ 'optimizer.pth')
302
+ print(f"New best model saved with test loss: {best_loss:.4f}")
303
+
304
+ # 測試 predict 方法
305
+
306
+
307
+ if __name__ == "__main__":
308
+ train1(NUM_SAMPLES=1000 ,maxNumShape=1, EPOCHS=100)
309
+ train1(NUM_SAMPLES=1000 ,maxNumShape=1, EPOCHS=100)
310
+ train1(NUM_SAMPLES=1000 ,maxNumShape=1, EPOCHS=100)
311
+ train1(NUM_SAMPLES=10000 ,maxNumShape=2, EPOCHS=10)
312
+ train1(NUM_SAMPLES=10000 ,maxNumShape=3, EPOCHS=10)
313
+ train1(NUM_SAMPLES=100000 ,maxNumShape=5, EPOCHS=10)
314
+ train1(NUM_SAMPLES=100000 ,maxNumShape=5, EPOCHS=10)
315
+ train1(NUM_SAMPLES=100000 ,maxNumShape=5, EPOCHS=10)
316
+ while True:
317
+ train1(NUM_SAMPLES=100000 ,maxNumShape=8, EPOCHS=10)
318
+
319
+
myImage.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import os,cv2
3
+ import numpy as np
4
+ def listImages(d):
5
+ images = []
6
+ for f in os.scandir(d):
7
+ if f.is_file() and f.name.split(".")[-1].lower() in (
8
+ "jpg",
9
+ "jpeg", # 添加 "jpeg" 格式
10
+ "png",
11
+ "bmp",
12
+ "svg",
13
+ "webp",
14
+ ):
15
+ images.append(f.path)
16
+ return images
17
+
18
+ def ImageToCV(img):
19
+ return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
20
+
21
+ def ImageFromBytes (content):
22
+ return Image.open(BytesIO(content))
23
+
24
+ def ImageToBytes (img,format="JPEG"):
25
+ return img.save(BytesIO(), format=format).getvalue()
26
+
27
+ def CVtoImage(img):
28
+ return Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))
29
+
30
+ def CVfromBytes(img_bytes):
31
+ return cv2.imdecode(np.frombuffer(img_bytes, dtype=np.uint8) , 1)
32
+
33
+ def CVtoBytes (img,format=".jpg"):
34
+ return cv2.imencode(format,img)[1].tobytes()