File size: 841 Bytes
d58b45f
 
c771993
 
d58b45f
c771993
bd0c1e9
c771993
 
 
 
252d51a
9f382a2
c771993
 
9f382a2
 
c771993
 
 
369c270
9f382a2
 
 
c771993
9f382a2
 
252d51a
c771993
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import jax
import jax.numpy as jnp
import gradio as gr
import numpy as np

# Dummy diffusion sampling function
def generate_image(seed, height=64, width=64):
    """Return a random RGB image derived deterministically from ``seed``.

    This is a placeholder for a real diffusion sampler: it only draws
    uniform noise with JAX; no denoising steps are run.

    Args:
        seed: PRNG seed. UI widgets such as sliders may deliver numbers as
            floats, but ``jax.random.PRNGKey`` rejects non-integer seeds, so
            the value is coerced with ``int()`` (fractional parts truncate).
        height: Image height in pixels (default 64).
        width: Image width in pixels (default 64).

    Returns:
        A NumPy array of shape ``(height, width, 3)`` with values drawn
        uniformly from ``[0, 1)``.
    """
    key = jax.random.PRNGKey(int(seed))
    img = jax.random.uniform(key, (height, width, 3))
    # Copy to a host-side NumPy array for display; Gradio's numpy image
    # type works with NumPy arrays, not JAX device arrays.
    return np.array(img)

with gr.Blocks() as demo:
    gr.Markdown("# 🌀 JAX Diffusion Demo")
    gr.Markdown("Generate random images using JAX diffusion (dummy example).")

    with gr.Row():
        seed_slider = gr.Slider(
            minimum=0, maximum=10000, step=1, value=42, label="Seed"
        )
        generate_button = gr.Button("Generate Image")

    # No explicit shape/size is set on the output component.
    output_image = gr.Image(type="numpy", label="Generated Image")

    # On click, the slider value is passed to generate_image as its single
    # positional argument (the seed); the returned array fills output_image.
    generate_button.click(
        fn=generate_image,
        inputs=seed_slider,
        outputs=output_image,
    )

# Start the server only when executed as a script, so importing this module
# does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()