TuringsSolutions
commited on
Commit
•
ae98e87
1
Parent(s):
c5a6c3c
Update app.py
Browse files
app.py
CHANGED
@@ -1,15 +1,9 @@
|
|
1 |
import os
|
2 |
-
import spaces
|
3 |
import gradio as gr
|
4 |
import numpy as np
|
5 |
-
|
6 |
-
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
|
7 |
-
from keras.models import Model
|
8 |
-
import matplotlib.pyplot as plt
|
9 |
-
import logging
|
10 |
-
from skimage.transform import resize
|
11 |
-
from PIL import Image, ImageEnhance, ImageFilter
|
12 |
from tqdm import tqdm
|
|
|
13 |
|
14 |
# Disable GPU usage by default
|
15 |
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
@@ -42,7 +36,6 @@ class SwarmNeuralNetwork:
|
|
42 |
self.image_shape = image_shape
|
43 |
self.agents = [SwarmAgent(self.random_position(), self.random_velocity()) for _ in range(num_agents)]
|
44 |
self.target_image = self.load_target_image(target_image_path)
|
45 |
-
self.mobilenet = self.load_mobilenet_model()
|
46 |
|
47 |
def random_position(self):
|
48 |
return np.random.randn(*self.image_shape)
|
@@ -51,13 +44,9 @@ class SwarmNeuralNetwork:
|
|
51 |
return np.random.randn(*self.image_shape) * 0.01
|
52 |
|
53 |
def load_target_image(self, img_path):
|
54 |
-
img = Image.open(img_path).resize((self.image_shape[1], self.image_shape[0]))
|
55 |
return np.array(img) / 127.5 - 1
|
56 |
|
57 |
-
def load_mobilenet_model(self):
|
58 |
-
mobilenet = MobileNetV2(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
|
59 |
-
return Model(inputs=mobilenet.input, outputs=mobilenet.get_layer('block_13_expand_relu').output)
|
60 |
-
|
61 |
def update_agents(self, timestep):
|
62 |
for agent in self.agents:
|
63 |
# Convert agent's position and target image into HDC space
|
@@ -82,23 +71,14 @@ class SwarmNeuralNetwork:
|
|
82 |
def train(self, epochs):
|
83 |
for epoch in tqdm(range(epochs), desc="Training"):
|
84 |
self.update_agents(epoch)
|
85 |
-
generated_image = self.generate_image()
|
86 |
-
|
87 |
-
# Display the generated image
|
88 |
-
self.display_image(generated_image, title=f'Epoch {epoch}')
|
89 |
|
90 |
-
|
91 |
-
plt.imshow(image)
|
92 |
-
plt.title(title)
|
93 |
-
plt.axis('off')
|
94 |
-
plt.show()
|
95 |
|
96 |
# Gradio Interface
|
97 |
def train_snn(image_path, num_agents, epochs):
|
98 |
snn = SwarmNeuralNetwork(num_agents=num_agents, image_shape=(128, 128, 3), target_image_path=image_path)
|
99 |
-
snn.train(epochs=epochs)
|
100 |
-
generated_image
|
101 |
-
return generated_image
|
102 |
|
103 |
interface = gr.Interface(
|
104 |
fn=train_snn,
|
|
|
1 |
import os
|
|
|
2 |
import gradio as gr
|
3 |
import numpy as np
|
4 |
+
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from tqdm import tqdm
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
|
8 |
# Disable GPU usage by default
|
9 |
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
|
|
36 |
self.image_shape = image_shape
|
37 |
self.agents = [SwarmAgent(self.random_position(), self.random_velocity()) for _ in range(num_agents)]
|
38 |
self.target_image = self.load_target_image(target_image_path)
|
|
|
39 |
|
40 |
def random_position(self):
|
41 |
return np.random.randn(*self.image_shape)
|
|
|
44 |
return np.random.randn(*self.image_shape) * 0.01
|
45 |
|
46 |
def load_target_image(self, img_path):
|
47 |
+
img = Image.open(img_path).convert('RGB').resize((self.image_shape[1], self.image_shape[0]))
|
48 |
return np.array(img) / 127.5 - 1
|
49 |
|
|
|
|
|
|
|
|
|
50 |
def update_agents(self, timestep):
|
51 |
for agent in self.agents:
|
52 |
# Convert agent's position and target image into HDC space
|
|
|
71 |
def train(self, epochs):
|
72 |
for epoch in tqdm(range(epochs), desc="Training"):
|
73 |
self.update_agents(epoch)
|
|
|
|
|
|
|
|
|
74 |
|
75 |
+
return self.generate_image()
|
|
|
|
|
|
|
|
|
76 |
|
77 |
# Gradio Interface
|
78 |
def train_snn(image_path, num_agents, epochs):
|
79 |
snn = SwarmNeuralNetwork(num_agents=num_agents, image_shape=(128, 128, 3), target_image_path=image_path)
|
80 |
+
generated_image = snn.train(epochs=epochs)
|
81 |
+
return (generated_image * 255).astype(np.uint8)
|
|
|
82 |
|
83 |
interface = gr.Interface(
|
84 |
fn=train_snn,
|