Spaces:
Runtime error
Runtime error
juancopi81
commited on
Commit
•
b31854b
1
Parent(s):
5301a08
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from stable_diffusion_tf.stable_diffusion import StableDiffusion as StableDiffusionPy
|
2 |
+
import keras_cv
|
3 |
+
import gradio as gr
|
4 |
+
from tensorflow import keras
|
5 |
+
|
6 |
+
keras.mixed_precision.set_global_policy("mixed_float16")
|
7 |
+
# load keras model
|
8 |
+
resolution=512
|
9 |
+
sd_dreambooth_model_1=StableDiffusionPy(resolution, resolution, download_weights=False, jit_compile=True)
|
10 |
+
|
11 |
+
diffusion_model_pytorch_weights = keras.utils.get_file(
|
12 |
+
origin="https://huggingface.co/riffusion/riffusion-model-v1/resolve/main/riffusion-model-v1.ckpt",
|
13 |
+
file_hash="99a6eb51c18e16a6121180f3daa69344e571618b195533f67ae94be4eb135a57",
|
14 |
+
)
|
15 |
+
|
16 |
+
sd_dreambooth_model_1.load_weights_from_pytorch_ckpt(diffusion_model_pytorch_weights)
|
17 |
+
|
18 |
+
sd_dreambooth_model_1.diffusion_model.load_weights("/dreambooth_riffusion_model_currulao_v1")
|
19 |
+
|
20 |
+
|
21 |
+
def generate_images(prompt: str, num_steps: int, unconditional_guidance_scale: int, temperature: int):
|
22 |
+
generated_img = sd_dreambooth_model_1.generate(
|
23 |
+
prompt,
|
24 |
+
num_steps=num_steps,
|
25 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
26 |
+
temperature=temperature,
|
27 |
+
batch_size=1,
|
28 |
+
)
|
29 |
+
|
30 |
+
return generated_img
|
31 |
+
|
32 |
+
|
33 |
+
# pass function, input type for prompt, the output for multiple images
|
34 |
+
gr.Interface(
|
35 |
+
title="Keras Dreambooth Riffusion-Currulao",
|
36 |
+
description="""This SD model has been fine-tuned from Riffusion to generate Currulao spectrograms.
|
37 |
+
To generate the concept, use the phrase 'a $currulao song' in your prompt.
|
38 |
+
""",
|
39 |
+
fn=generate_images,
|
40 |
+
inputs=[
|
41 |
+
gr.Textbox(label="Prompt", value="a $currulao song, lo-fi"),
|
42 |
+
gr.Slider(label="Inference steps", value=50),
|
43 |
+
gr.Slider(label="Guidance scale", value=7.5, maximum=15, minimum=0, step=0.5),
|
44 |
+
gr.Slider(label='Temperature', value=1, maximum=1.5, minimum=0, step=0.1),
|
45 |
+
],
|
46 |
+
outputs=[
|
47 |
+
gr.Gallery(show_label=False).style(grid=(1,2)),
|
48 |
+
],
|
49 |
+
examples=[["a $currulao song", "low quality, deformed, dark", 2, 50, 7.5]],
|
50 |
+
).queue().launch(debug=True)
|