|
import os |
|
import gradio as gr |
|
import numpy as np |
|
import tensorflow as tf |
|
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input |
|
from keras.models import Model |
|
import matplotlib.pyplot as plt |
|
import logging |
|
from skimage.transform import resize |
|
from PIL import Image, ImageEnhance, ImageFilter |
|
from tqdm import tqdm |
|
|
|
|
|
# Blank out CUDA_VISIBLE_DEVICES for this process
# (presumably to keep TensorFlow on the CPU -- TODO confirm).
os.environ.update({'CUDA_VISIBLE_DEVICES': ''})
|
|
|
class SwarmAgent:
    """One swarm member: a position, a velocity and two zeroed buffers.

    Attributes:
        position: Array-like state supplied by the caller.
        velocity: Array-like velocity supplied by the caller.
        m: Zero array with the shape and dtype of ``position``.
        v: A second, independent zero array shaped like ``position``.
    """

    def __init__(self, position, velocity):
        self.position, self.velocity = position, velocity
        # Two separate allocations so that m and v never alias each other.
        self.m, self.v = (np.zeros_like(position) for _ in range(2))
|
|
|
class SwarmNeuralNetwork:
    """Swarm of image-shaped agents that is iteratively pulled toward a target image.

    Every agent's ``position`` is an array of shape ``image_shape``.  Each
    update blends the position toward ``target_image`` according to a noise
    schedule, adds Gaussian noise and clips to [-1, 1].  The generated image
    is the mean of all agent positions, rescaled to [0, 1] and sharpened.
    """

    def __init__(self, num_agents, image_shape, target_image_path):
        """Create the swarm and load the target image.

        Args:
            num_agents: Number of agents to create (coerced to ``int`` so a
                float coming from a UI slider does not break ``range``).
            image_shape: ``(height, width, channels)`` of every agent.
            target_image_path: Path of the target image, or ``None`` to start
                without a target (e.g. when one is restored by ``load_model``).
        """
        self.image_shape = image_shape
        self.resized_shape = (256, 256, 3)
        self.agents = [
            SwarmAgent(self.random_position(), self.random_velocity())
            for _ in range(int(num_agents))
        ]
        self.target_image = self.load_target_image(target_image_path)
        self.generated_image = np.random.randn(*image_shape)
        self.mobilenet = self.load_mobilenet_model()
        self.current_epoch = 0
        # Noise level per timestep, decreasing linearly from 0.1 to 0.002.
        self.noise_schedule = np.linspace(0.1, 0.002, 1000)

    def random_position(self):
        """Return a standard-normal array of shape ``image_shape``."""
        return np.random.randn(*self.image_shape)

    def random_velocity(self):
        """Return a normal array of shape ``image_shape`` scaled by 0.01."""
        return np.random.randn(*self.image_shape) * 0.01

    def load_target_image(self, img_path):
        """Load, resize and normalise the target image to [-1, 1].

        The image is converted to RGB first so that grayscale or RGBA files
        still yield an ``(height, width, 3)`` array that broadcasts against
        the agent positions.  The loaded image is also shown with matplotlib.

        Args:
            img_path: Path of the image file, or ``None``.

        Returns:
            The normalised image array, or ``None`` when ``img_path`` is ``None``.
        """
        if img_path is None:
            return None

        with Image.open(img_path) as source:
            img = source.convert('RGB').resize((self.image_shape[1], self.image_shape[0]))
        img_array = np.array(img) / 127.5 - 1

        plt.imshow((img_array + 1) / 2)
        plt.title('Target Image')
        plt.show()
        return img_array

    def resize_image(self, image):
        """Resize ``image`` to ``resized_shape`` with anti-aliasing."""
        return resize(image, self.resized_shape, anti_aliasing=True)

    def load_mobilenet_model(self):
        """Build an ImageNet MobileNetV2 truncated at ``block_13_expand_relu``."""
        mobilenet = MobileNetV2(weights='imagenet', include_top=False, input_shape=self.resized_shape)
        return Model(inputs=mobilenet.input, outputs=mobilenet.get_layer('block_13_expand_relu').output)

    def add_positional_encoding(self, image):
        """Add a normalised (row, column) encoding to the first two channels.

        Channel 0 receives ``row / height`` and channel 1 receives
        ``column / width``; any further channels are left unchanged.

        Args:
            image: Array of shape ``(height, width, channels)`` with at least
                two channels.

        Returns:
            ``image`` plus the positional encoding.

        Raises:
            ValueError: If ``image`` has fewer than two channels.
        """
        h, w, c = image.shape
        if c < 2:
            raise ValueError("positional encoding needs at least two channels")

        pos_enc = np.zeros_like(image)
        # Broadcast a column of row fractions and a row of column fractions
        # instead of filling the array pixel by pixel in Python.
        pos_enc[..., 0] = (np.arange(h) / h)[:, np.newaxis]
        pos_enc[..., 1] = (np.arange(w) / w)[np.newaxis, :]
        return image + pos_enc

    def multi_head_attention(self, agent, num_heads=4):
        """Return per-pixel attention weights of ``agent`` relative to the target.

        The weights are a softmax over all pixels of the negative squared
        distance between ``agent.position`` and ``target_image``.

        Args:
            agent: Object exposing a ``position`` array.
            num_heads: Kept for interface compatibility.  Every head computed
                the identical score, so averaging the heads equals one head.

        Returns:
            Array of shape ``(height, width, 1)`` whose entries sum to 1.
        """
        sq_dist = np.sum((agent.position - self.target_image) ** 2, axis=-1)
        # Softmax is shift-invariant; subtracting the smallest distance keeps
        # at least one exponent at 0 so the sum cannot underflow to zero.
        weights = np.exp(-(sq_dist - np.min(sq_dist)))
        attention = weights / np.sum(weights)
        return np.expand_dims(attention, axis=-1)

    def multi_scale_perceptual_loss(self, agent_positions):
        """Score each agent image against the target in MobileNetV2 feature space.

        Only the single feature layer of ``self.mobilenet`` is compared.

        Args:
            agent_positions: Iterable of arrays in [-1, 1] shaped like the target.

        Returns:
            Array with one value ``1 / (1 + mse)`` per agent, where ``mse`` is
            the mean squared feature difference (higher means closer).

        Raises:
            ValueError: If no target image is set.
        """
        if self.target_image is None:
            raise ValueError("target_image is not set")

        def features(image):
            # Map [-1, 1] -> [0, 1], resize, then rescale to [0, 255] for preprocess_input.
            resized = self.resize_image((image + 1) / 2)
            batch = preprocess_input(resized[np.newaxis, ...] * 255)
            return self.mobilenet.predict(batch, verbose=0)

        target_features = features(self.target_image)
        losses = []
        for agent_position in agent_positions:
            loss = np.mean((target_features - features(agent_position)) ** 2)
            losses.append(1 / (1 + loss))
        return np.array(losses)

    def update_agents(self, timestep):
        """Move every agent one step toward the target and re-inject noise.

        Args:
            timestep: Index into ``noise_schedule``; values past the end use
                the last entry.

        Raises:
            ValueError: If no target image is set.
        """
        if self.target_image is None:
            raise ValueError("target_image is not set; load a target image or a saved model first")

        noise_level = self.noise_schedule[min(timestep, len(self.noise_schedule) - 1)]

        for agent in self.agents:
            # The offset from the target is treated as the noise to remove.
            predicted_noise = agent.position - self.target_image
            denoised = (agent.position - noise_level * predicted_noise) / (1 - noise_level)
            # Fresh Gaussian noise scaled by sqrt(noise_level), then clip to the valid range.
            noisy = denoised + np.random.randn(*self.image_shape) * np.sqrt(noise_level)
            agent.position = np.clip(noisy, -1, 1)

    def generate_image(self):
        """Average the agents into ``generated_image`` (float, range [0, 1])."""
        mean_position = np.mean([agent.position for agent in self.agents], axis=0)
        # Map [-1, 1] -> [0, 1].
        self.generated_image = np.clip((mean_position + 1) / 2, 0, 1)

        # Sharpen through PIL, which requires an 8-bit round trip.
        image_pil = Image.fromarray((self.generated_image * 255).astype(np.uint8))
        image_pil = image_pil.filter(ImageFilter.SHARPEN)
        self.generated_image = np.array(image_pil) / 255.0

    def train(self, epochs):
        """Run ``epochs`` update steps, logging the MSE against the target.

        The MSE is written to ``training.log`` every epoch; every fifth epoch
        it is also printed and the current image is displayed.

        Args:
            epochs: Number of steps (coerced to ``int``).
        """
        logging.basicConfig(filename='training.log', level=logging.INFO)

        for epoch in tqdm(range(int(epochs)), desc="Training Epochs"):
            self.update_agents(epoch)
            self.generate_image()

            # generated_image is in [0, 1]; map back to [-1, 1] to compare with the target.
            mse = np.mean(((self.generated_image * 2 - 1) - self.target_image) ** 2)
            logging.info("Epoch %s, MSE: %s", epoch, mse)

            if epoch % 5 == 0:
                print(f"Epoch {epoch}, MSE: {mse}")
                self.display_image(self.generated_image, title=f'Epoch {epoch}')
            self.current_epoch += 1

    def display_image(self, image, title=''):
        """Show ``image`` with matplotlib, without axes."""
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
        plt.show()

    def display_agent_positions(self, epoch):
        """Overlay a red scatter of agent values on the generated image."""
        fig, ax = plt.subplots()
        positions = np.array([agent.position for agent in self.agents])
        ax.imshow(self.generated_image, extent=[0, self.image_shape[1], 0, self.image_shape[0]])
        # NOTE(review): positions is (num_agents, H, W, C), so [:, :, 0] selects
        # image column 0 rather than channel 0 -- confirm this is intended.
        ax.scatter(positions[:, :, 0].flatten(), positions[:, :, 1].flatten(), s=1, c='red')
        plt.title(f'Agent Positions at Epoch {epoch}')
        plt.show()

    def save_model(self, filename):
        """Pickle the swarm state to ``filename`` via ``np.save``.

        The target image is stored as well so that a reloaded model can be
        sampled without the original image file.  ``np.save`` appends
        ``.npy`` when a string filename lacks that suffix.
        """
        model_state = {
            'agents': self.agents,
            'generated_image': self.generated_image,
            'current_epoch': self.current_epoch,
            'target_image': self.target_image,
        }
        np.save(filename, model_state)

    def load_model(self, filename):
        """Restore the state written by ``save_model``.

        Files saved without a ``'target_image'`` entry leave the current
        target image untouched.
        """
        # SECURITY: allow_pickle=True executes arbitrary pickled code; only
        # load files that this application wrote itself.
        model_state = np.load(filename, allow_pickle=True).item()
        self.agents = model_state['agents']
        self.generated_image = model_state['generated_image']
        self.current_epoch = model_state['current_epoch']
        saved_target = model_state.get('target_image')
        if saved_target is not None:
            self.target_image = saved_target

    def generate_new_image(self, num_steps=500):
        """Re-randomise all agents and run the schedule from noisy to clean.

        Args:
            num_steps: Number of update steps; timesteps count down from
                ``num_steps - 1`` to 0.

        Returns:
            The resulting ``generated_image``.

        Raises:
            ValueError: If no target image is set.
        """
        if self.target_image is None:
            raise ValueError("target_image is not set; load a target image or a saved model first")

        for agent in self.agents:
            agent.position = np.random.randn(*self.image_shape)

        for step in tqdm(range(num_steps), desc="Generating Image"):
            self.update_agents(num_steps - step - 1)

        self.generate_image()
        return self.generated_image

    def adjust_limbs(self, arm_position, leg_position):
        """Offset two fixed horizontal bands of every agent in place.

        Rows 50-99 (columns 50-199) are shifted by ``arm_position / 100 * 0.2``
        and rows 150-199 (columns 50-199) by ``leg_position / 100 * 0.2``.
        """
        arm_shift = arm_position / 100.0 * 0.2
        leg_shift = leg_position / 100.0 * 0.2

        for agent in self.agents:
            agent.position[50:100, 50:200, :] += arm_shift
            agent.position[150:200, 50:200, :] += leg_shift
|
|
|
|
|
def train_snn(image_path, num_agents, epochs, arm_position, leg_position, brightness, contrast, color):
    """Train a swarm on an uploaded image and return the generated image.

    The brightness, contrast and colour factors are applied to the uploaded
    image, and the enhanced result is what the swarm is trained toward.
    The trained state is saved to ``snn_model.npy``.

    Args:
        image_path: Path of the target image file.
        num_agents: Number of swarm agents (coerced to ``int``).
        epochs: Number of training steps (coerced to ``int``).
        arm_position: Value passed to ``adjust_limbs`` for the upper band.
        leg_position: Value passed to ``adjust_limbs`` for the lower band.
        brightness: ``ImageEnhance.Brightness`` factor.
        contrast: ``ImageEnhance.Contrast`` factor.
        color: ``ImageEnhance.Color`` factor.

    Returns:
        The network's ``generated_image`` after training.
    """
    # Slider widgets can deliver floats; list sizing and range() need ints.
    num_agents = int(num_agents)
    epochs = int(epochs)

    snn = SwarmNeuralNetwork(num_agents=num_agents, image_shape=(256, 256, 3), target_image_path=image_path)

    # Force RGB so the target always has three channels like the agents.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    image = ImageEnhance.Brightness(image).enhance(brightness)
    image = ImageEnhance.Contrast(image).enhance(contrast)
    image = ImageEnhance.Color(image).enhance(color)

    # Use the enhanced image as the target.  Reloading from image_path here
    # would silently discard the three adjustments above.
    height, width = snn.image_shape[0], snn.image_shape[1]
    image = image.resize((width, height))
    snn.target_image = np.array(image) / 127.5 - 1

    snn.adjust_limbs(arm_position, leg_position)

    snn.train(epochs=epochs)
    snn.save_model('snn_model.npy')
    return snn.generated_image
|
|
|
def generate_new_image(model_path='snn_model.npy', num_steps=500):
    """Sample a fresh image from a previously saved swarm.

    Args:
        model_path: File written by ``SwarmNeuralNetwork.save_model``.
        num_steps: Number of update steps to run while sampling.

    Returns:
        The image produced by ``SwarmNeuralNetwork.generate_new_image``.
    """
    # load_model() replaces the agent list, so build an empty swarm instead of
    # allocating thousands of random agents that are thrown away immediately.
    # NOTE(review): no target image is passed here, yet sampling reads
    # snn.target_image -- confirm the class accepts target_image_path=None and
    # that the saved model provides the target.
    snn = SwarmNeuralNetwork(num_agents=0, image_shape=(256, 256, 3), target_image_path=None)
    snn.load_model(model_path)
    return snn.generate_new_image(num_steps=num_steps)
|
|
|
# Gradio UI: one image upload plus seven sliders, wired to train_snn.
interface = gr.Interface(
    fn=train_snn,
    inputs=[
        gr.Image(type="filepath", label="Upload Target Image"),
        # step=1 keeps the two counts integral; they end up in range() and
        # list sizing, which reject fractional values.
        gr.Slider(minimum=500, maximum=2000, value=1000, step=1, label="Number of Agents"),
        gr.Slider(minimum=10, maximum=100, value=50, step=1, label="Number of Epochs"),
        gr.Slider(minimum=-100, maximum=100, value=0, label="Arm Position"),
        gr.Slider(minimum=-100, maximum=100, value=0, label="Leg Position"),
        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Brightness"),
        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Contrast"),
        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Color Balance"),
    ],
    outputs=gr.Image(type="numpy", label="Generated Image"),
    title="Swarm Neural Network Image Generation",
    description="Upload an image and set the number of agents and epochs to train the Swarm Neural Network to generate a new image. Adjust arm and leg positions, brightness, contrast, and color balance for personalization.",
)

# Start the server only when run as a script, so importing this module
# (for tests or reuse) does not launch the web app.
if __name__ == "__main__":
    interface.launch()