"""Flask front end that samples images from a pre-trained GAN generator."""
import base64
import math
from io import BytesIO

import numpy as np
from flask import Flask, jsonify, render_template, request
from matplotlib.figure import Figure
from numpy.random import randn
from tensorflow.keras.models import load_model

app = Flask(__name__)

# Load the GAN model from the H5 file once, at import time, so every request
# reuses the same instance.
model = load_model('gan.h5')

# Length of each latent noise vector fed to the model; presumably must match
# the model's input size — TODO confirm against the trained generator.
LATENT_DIM = 100
# Number of images generated when the client does not send ``n_samples``.
DEFAULT_SAMPLES = 4
# ``n_samples`` comes straight from the client; without an upper bound a single
# request could ask for an arbitrarily large batch and exhaust CPU/memory.
MAX_SAMPLES = 100


def generate_latent_points(latent_dim, n_samples):
    """Draw standard-normal latent vectors.

    Args:
        latent_dim: length of each latent vector.
        n_samples: number of vectors to draw.

    Returns:
        A ``(n_samples, latent_dim)`` float array.
    """
    x_input = randn(latent_dim * n_samples)
    z_input = x_input.reshape(n_samples, latent_dim)
    return z_input


def generate_images(model, latent_points):
    """Run *latent_points* through *model* and return its predictions."""
    generated_images = model.predict(latent_points)
    return generated_images


def plot_generated(examples, n_rows, n_cols, image_size=(80, 80)):
    """Tile *examples* into an ``n_rows`` x ``n_cols`` grid; return it as base64 PNG.

    Args:
        examples: array of images indexed along its first axis; each
            ``examples[k]`` is passed to ``imshow`` (assumed (H, W) or
            (H, W, 3) with values in [0, 1] — TODO confirm for this model).
        n_rows: number of grid rows (>= 1).
        n_cols: number of grid columns (>= 1). Cells beyond ``len(examples)``
            are left blank.
        image_size: unused; kept so existing callers keep working.

    Returns:
        The rendered PNG, base64-encoded, as a ``str``.
    """
    # Use the object-oriented Figure API instead of pyplot: pyplot keeps global
    # figure state that is not safe to share across concurrent web requests,
    # and a Figure created this way needs no explicit close.
    fig = Figure(figsize=(15, 10))
    # squeeze=False guarantees a 2-D array of Axes. With the default, a single
    # row yields a bare Axes (1x1) or a 1-D array (1xN), which broke
    # ``axes[i, j]`` for 1-3 samples.
    axes = fig.subplots(n_rows, n_cols, squeeze=False)
    # axes.flat walks row-major, i.e. flat index == i * n_cols + j.
    for index, ax in enumerate(axes.flat):
        ax.axis('off')
        if index < len(examples):
            ax.imshow(examples[index])
    buf = BytesIO()
    fig.savefig(buf, format='png')
    return base64.b64encode(buf.getvalue()).decode('utf-8')


@app.route('/')
def index():
    """Serve the landing page."""
    return render_template('index.html')


def _parse_n_samples(raw):
    """Return *raw* as an int clamped to [1, MAX_SAMPLES], or None if not an integer."""
    try:
        value = int(raw)
    except (TypeError, ValueError):
        return None
    return min(max(value, 1), MAX_SAMPLES)


def _grid_shape(n_samples):
    """Return ``(n_rows, n_cols)`` for a near-square grid with >= n_samples cells."""
    # Rows from the square root of n_samples; columns by integer ceil-division
    # so every sample gets a cell.
    n_rows = max(int(math.sqrt(n_samples)), 1)
    n_cols = (n_samples + n_rows - 1) // n_rows
    return n_rows, n_cols


@app.route('/generate', methods=['POST'])
def generate():
    """Generate a grid of GAN samples and return it as base64 PNG inside JSON.

    Reads the optional form field ``n_samples`` (default ``DEFAULT_SAMPLES``),
    clamped to [1, MAX_SAMPLES].

    Returns:
        ``{'success': True, 'generated_image': <base64 PNG>}`` on success, or
        ``{'success': False, 'error': ...}`` with HTTP 400 when ``n_samples``
        is not an integer.
    """
    n_samples = _parse_n_samples(request.form.get('n_samples', DEFAULT_SAMPLES))
    if n_samples is None:
        return jsonify({'success': False,
                        'error': 'n_samples must be an integer'}), 400

    n_rows, n_cols = _grid_shape(n_samples)

    latent_points = generate_latent_points(LATENT_DIM, n_samples)
    generated_images = generate_images(model, latent_points)
    # Rescale from [-1, 1] to [0, 1] for imshow; assumes a tanh output layer
    # — TODO confirm against the generator architecture.
    generated_images = (generated_images + 1) / 2.0

    img_data = plot_generated(generated_images, n_rows, n_cols)
    return jsonify({'success': True, 'generated_image': img_data})


if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug interactive debugger, which
    # permits arbitrary code execution from the browser. Use only for local
    # development; never on a host reachable by others.
    app.run(debug=True)