# File size: 2,326 bytes (snapshot 77a9008)

import onnx
import onnxruntime as ort
import numpy as np

# --- Model location ---------------------------------------------------------
# The ONNX file is expected to sit next to this script.
MODEL_PATH = r"./"
model_name = "animalImageGAN_full.onnx"
ONNX_MODEL_PATH = f"{MODEL_PATH}{model_name}"

# Load the serialized graph and run the ONNX checker on it so that a malformed
# model fails at start-up rather than on the first inference call.
onnx_model = onnx.load(ONNX_MODEL_PATH)
onnx.checker.check_model(onnx_model)

# --- Latent-vector settings -------------------------------------------------
# Number of latent values fed to the generator, and the statistics each
# sampled latent vector is rescaled to before inference.
generator_input_size = 50
desired_mean = 0
desired_variance = 1

# Random generator used for per-request sampling.
rng = np.random.default_rng()

# One-off float32 sample of shape (generator_input_size, 1, 1); it is not
# referenced by the other code visible in this file.
latent_space_samples = np.random.rand(generator_input_size, 1, 1).astype(np.float32)

# Inference session shared by every request.
ort_sess = ort.InferenceSession(ONNX_MODEL_PATH)

import gradio as gr


def generateImage():
    """Generate one image from a freshly sampled latent vector.

    Draws ``generator_input_size`` uniform samples in [0, 1), standardizes
    them (subtract the sample mean, divide by the sample standard deviation),
    rescales them to the module-level ``desired_mean`` / ``desired_variance``,
    and feeds the result to the ONNX session under the input name ``'input'``.

    Returns:
        The first model output mapped through ``x * 0.5 + 0.5``, clipped to
        [0, 1], with its axes reordered by ``transpose(1, 2, 0)``
        (presumably channels-first to channels-last -- confirm against the
        exported model).
    """
    latent = rng.random((generator_input_size, 1, 1), dtype=np.float32)

    # Standardize the draw, then rescale to the requested statistics.
    standardized = (latent - np.mean(latent)) / np.sqrt(np.var(latent))
    latent = standardized * np.sqrt(desired_variance) + desired_mean

    # np.sqrt() of a Python int returns a float64 NumPy scalar. Under the
    # NumPy >= 2 promotion rules that scalar upcasts the whole array to
    # float64, while the sampling code above explicitly requests float32
    # (presumably what the exported model expects). Cast back so the dtype
    # handed to the session is float32 on every NumPy version.
    latent = latent.astype(np.float32, copy=False)

    outputs = ort_sess.run(None, {'input': latent})

    # Shift/scale the raw output and clamp it to the [0, 1] range.
    image = np.clip(outputs[0] * 0.5 + 0.5, 0, 1)
    return image.transpose(1, 2, 0)

# HTML rendered above the image widget: a centred title, a one-line summary,
# and links to the accompanying article and notebook.
DESCRIPTION = "".join(
    (
        "<div style='text-align:center'>",
        "<h1 style='justify-content: center'>Animal Portrait Generator</h1>",
        "<p>This is a model trained by using DCGAN</p>",
        "<p>More details:</p>",
        "<ul>",
        "<li><a href='https://medium.com/@jiachiewloh/dcgan-animal-image-generator-85e466fb6254'>Article</a></li>",
        "<li><a href='https://www.kaggle.com/code/jclohjc/animal-image-generator-dcgan'>Code</a></li>",
        "</ul>",
        "</div>",
    )
)

# Build the UI: description on top, the generated image below it, and a row
# with "Generate" and clear buttons.
#
# CSS notes:
#   * `#img_window` targets the gr.Image created below via its elem_id.
#   * `!important` is only valid directly after a property value; written as a
#     separate declaration it is discarded by the CSS parser. It is attached
#     to the height/width values here so the 250x250 sizing is actually
#     forced.
with gr.Blocks(
    css=(
        "#img_window {text-align:center; justify-content: center;} "
        ".image-container {margin: auto; height: 250px !important; "
        "width: 250px !important;}"
    )
) as demo:
    gr.Markdown(DESCRIPTION)
    img_window = gr.Image(interactive=False, elem_id="img_window")
    with gr.Row():
        # Each click samples a new latent vector and shows the resulting image.
        gr.Button("Generate").click(fn=generateImage, outputs=img_window)
        # Resets the image widget.
        gr.ClearButton().add(img_window)

# Start the server only when executed as a script, so importing this module
# (e.g. to reuse `demo`) does not block on a running server.
if __name__ == "__main__":
    demo.launch()