juancopi81 commited on
Commit
b31854b
1 Parent(s): 5301a08

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stable_diffusion_tf.stable_diffusion import StableDiffusion as StableDiffusionPy
2
+ import keras_cv
3
+ import gradio as gr
4
+ from tensorflow import keras
5
+
6
+ keras.mixed_precision.set_global_policy("mixed_float16")
7
+ # load keras model
8
+ resolution=512
9
+ sd_dreambooth_model_1=StableDiffusionPy(resolution, resolution, download_weights=False, jit_compile=True)
10
+
11
+ diffusion_model_pytorch_weights = keras.utils.get_file(
12
+ origin="https://huggingface.co/riffusion/riffusion-model-v1/resolve/main/riffusion-model-v1.ckpt",
13
+ file_hash="99a6eb51c18e16a6121180f3daa69344e571618b195533f67ae94be4eb135a57",
14
+ )
15
+
16
+ sd_dreambooth_model_1.load_weights_from_pytorch_ckpt(diffusion_model_pytorch_weights)
17
+
18
+ sd_dreambooth_model_1.diffusion_model.load_weights("/dreambooth_riffusion_model_currulao_v1")
19
+
20
+
21
+ def generate_images(prompt: str, num_steps: int, unconditional_guidance_scale: int, temperature: int):
22
+ generated_img = sd_dreambooth_model_1.generate(
23
+ prompt,
24
+ num_steps=num_steps,
25
+ unconditional_guidance_scale=unconditional_guidance_scale,
26
+ temperature=temperature,
27
+ batch_size=1,
28
+ )
29
+
30
+ return generated_img
31
+
32
+
33
+ # pass function, input type for prompt, the output for multiple images
34
+ gr.Interface(
35
+ title="Keras Dreambooth Riffusion-Currulao",
36
+ description="""This SD model has been fine-tuned from Riffusion to generate Currulao spectrograms.
37
+ To generate the concept, use the phrase 'a $currulao song' in your prompt.
38
+ """,
39
+ fn=generate_images,
40
+ inputs=[
41
+ gr.Textbox(label="Prompt", value="a $currulao song, lo-fi"),
42
+ gr.Slider(label="Inference steps", value=50),
43
+ gr.Slider(label="Guidance scale", value=7.5, maximum=15, minimum=0, step=0.5),
44
+ gr.Slider(label='Temperature', value=1, maximum=1.5, minimum=0, step=0.1),
45
+ ],
46
+ outputs=[
47
+ gr.Gallery(show_label=False).style(grid=(1,2)),
48
+ ],
49
+ examples=[["a $currulao song", "low quality, deformed, dark", 2, 50, 7.5]],
50
+ ).queue().launch(debug=True)