File size: 2,022 Bytes
c2e4746
259d0f3
 
c2e4746
 
259d0f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2e4746
 
 
 
 
 
 
 
 
 
 
 
259d0f3
 
 
c2e4746
 
259d0f3
 
 
 
c2e4746
 
 
259d0f3
 
 
 
c2e4746
 
 
 
 
 
 
 
259d0f3
 
 
c2e4746
259d0f3
 
 
c2e4746
259d0f3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from flask import Flask, render_template, request, jsonify
from tensorflow.keras.models import load_model
from numpy.random import randn
import matplotlib.pyplot as plt
import numpy as np
import base64
from io import BytesIO

# WSGI application object; the @app.route decorators below register on it.
app = Flask(__name__)

# Load the GAN model from the H5 file once, at import time, so every request
# handled by this process reuses the same model object. The relative path
# 'gan.h5' is resolved against the process's current working directory, not
# against this file's location.
# NOTE(review): generate() feeds latent vectors to model.predict and treats the
# output as images, so this is presumably the generator half of the GAN —
# confirm.
model = load_model('gan.h5')

def generate_latent_points(latent_dim, n_samples):
    """Draw a batch of latent vectors from a standard normal distribution.

    Args:
        latent_dim: Length of each latent vector.
        n_samples: Number of vectors to draw.

    Returns:
        An array of shape ``(n_samples, latent_dim)`` holding N(0, 1)
        samples taken from NumPy's global random state.
    """
    # Requesting the 2-D shape directly consumes the same
    # latent_dim * n_samples draws, in the same row-major order, as drawing
    # a flat vector and reshaping it afterwards.
    return randn(n_samples, latent_dim)

def generate_images(model, latent_points):
    """Decode a batch of latent points into images with the given model.

    Args:
        model: Any object exposing ``predict(batch)``; in this app it is the
            Keras model loaded at import time.
        latent_points: The batch of latent vectors to decode.

    Returns:
        Exactly what ``model.predict`` returns for the batch.
    """
    return model.predict(latent_points)

def plot_generated(examples, n_rows, n_cols, image_size=(80, 80)):
    """Render images into an ``n_rows`` x ``n_cols`` grid and return it as base64 PNG.

    Args:
        examples: Array of images indexed along its first axis; each
            ``examples[k]`` is passed to ``imshow``. A trailing channel axis
            of size 1 is dropped so single-channel images can be displayed.
        n_rows: Number of grid rows (any value >= 1, including 1).
        n_cols: Number of grid columns (any value >= 1, including 1). Grid
            cells beyond ``len(examples)`` are left blank.
        image_size: Unused; kept so existing callers passing it still work.

    Returns:
        The rendered figure as PNG bytes, base64-encoded and decoded to a
        UTF-8 ``str`` (e.g. for embedding in a ``data:image/png`` URI).
    """
    # Use the object-oriented Figure API instead of pyplot: pyplot keeps
    # global figure state and may select a GUI backend, neither of which is
    # safe inside a web-server worker. A bare Figure is garbage-collected
    # normally, so no explicit close is needed.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(15, 10))
    # squeeze=False keeps ``axes`` 2-D even for 1x1 or 1xN grids. With the
    # default squeezing, indexing ``axes[i, j]`` failed whenever only 1-3
    # samples were requested.
    axes = fig.subplots(n_rows, n_cols, squeeze=False)

    # ``axes.flat`` walks the grid row-major, i.e. index == i * n_cols + j.
    for index, ax in enumerate(axes.flat):
        ax.axis('off')
        if index < len(examples):
            image = examples[index]
            # imshow rejects (H, W, 1); show single-channel images as 2-D.
            if image.ndim == 3 and image.shape[-1] == 1:
                image = image[..., 0]
            ax.imshow(image)

    buf = BytesIO()
    fig.savefig(buf, format='png')
    return base64.b64encode(buf.getvalue()).decode('utf-8')

@app.route('/')
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page

import math

@app.route('/generate', methods=['POST'])
def generate():
    """Generate a grid of GAN samples and return it as JSON.

    Reads the optional form field ``n_samples`` (default 4), draws that many
    latent vectors, runs them through ``model`` and renders the outputs as a
    single PNG grid.

    Returns:
        On success: JSON ``{'success': True, 'generated_image': <base64 PNG>}``.
        If ``n_samples`` is not an integer: JSON
        ``{'success': False, 'error': ...}`` with HTTP status 400.
    """
    # Presumably must equal the generator's input dimension — confirm
    # against the model.
    latent_dim = 100
    # n_samples comes straight from the client; without an upper bound a
    # single request could ask for millions of latent vectors and exhaust
    # memory/CPU. 64 is an arbitrary cap — adjust to what the UI offers.
    max_samples = 64

    try:
        requested = int(request.form.get('n_samples', 4))
    except (TypeError, ValueError):
        # Reject malformed input with a 400 instead of letting int() raise,
        # which Flask would turn into a 500 Internal Server Error.
        return jsonify({'success': False, 'error': 'n_samples must be an integer'}), 400
    # Clamp into [1, max_samples] (the lower bound of 1 is unchanged).
    n_samples = min(max(requested, 1), max_samples)

    # Near-square grid: floor(sqrt(n)) rows, then just enough columns
    # (ceiling division) to hold all n images.
    n_rows = max(int(math.sqrt(n_samples)), 1)
    n_cols = (n_samples + n_rows - 1) // n_rows

    latent_points = generate_latent_points(latent_dim, n_samples)
    generated_images = generate_images(model, latent_points)
    # Map [-1, 1] to [0, 1] for display; assumes the generator ends in a tanh
    # activation — TODO confirm against the model definition.
    generated_images = (generated_images + 1) / 2.0
    img_data = plot_generated(generated_images, n_rows, n_cols)

    return jsonify({'success': True, 'generated_image': img_data})


if __name__ == '__main__':
    # Runs Flask's built-in development server when executed as a script.
    # debug=True enables the interactive Werkzeug debugger and reloader; the
    # debugger allows executing arbitrary Python from the browser, so never
    # expose this to untrusted networks — use a production WSGI server instead.
    app.run(debug=True)