Spaces:
Sleeping
Sleeping
File size: 1,357 Bytes
c01c8d1 13227b3 c01c8d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import torch
import gradio as gr
import prediction
import model
import diffusion_loss
# Select the compute device: "cuda" when torch reports CUDA as available,
# "cpu" otherwise. `device` is not read anywhere else in the visible part of
# this file.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the diffusion pipeline once, at import time; `generate` hands this same
# object to every prediction call.
pipe = model.initialize_diffusion_model()
def generate(prompt, loss_function=None):
    """Run one prediction for ``prompt`` with the module-level ``pipe``.

    Args:
        prompt: Value forwarded unchanged to ``prediction.predict`` as
            ``prompt``.
        loss_function: Value forwarded unchanged to ``prediction.predict`` as
            ``loss_function``; defaults to ``None``.

    Returns:
        Whatever ``prediction.predict`` returns.
    """
    predict_kwargs = {
        "prompt": prompt,
        "pipe": pipe,
        "loss_function": loss_function,
    }
    return prediction.predict(**predict_kwargs)
def process_input(prompt, loss_function, button):
    """Map the UI selection to a loss callable and run ``generate``.

    Args:
        prompt: Prompt text, passed straight through to ``generate``.
        loss_function: Selected label. ``"Blue Channel"``, ``"Saturation"``
            and ``"Elastic Deformation"`` pick the matching attribute of
            ``diffusion_loss``; ``None``, ``"No Loss"`` and any other value
            run without a loss function.
        button: Gate value; nothing is generated when it is falsy.

    Returns:
        The result of ``generate``, or ``None`` when ``button`` is falsy.
    """
    # Guard clause: without a truthy button value there is nothing to do.
    if not button:
        return None

    # Label -> attribute name on `diffusion_loss`. The attribute is only
    # looked up for the label that actually matches.
    label_to_attr = (
        ("Blue Channel", "blue_channel"),
        ("Saturation", "saturation"),
        ("Elastic Deformation", "elastic_transform"),
    )
    for label, attr_name in label_to_attr:
        if loss_function == label:
            selected = getattr(diffusion_loss, attr_name)
            return generate(prompt, loss_function=selected)

    # None, "No Loss", and unrecognised labels all mean "no loss function".
    return generate(prompt, loss_function=None)
# Gradio UI definition. The three input components line up, in order, with the
# `prompt`, `loss_function` and `button` parameters of `process_input`, and the
# dropdown labels are the ones `process_input` compares against.
# NOTE(review): a Button is used as an input component here; presumably its
# value is what arrives as `button` -- confirm against the Gradio version in use.
iface = gr.Interface(
    fn=process_input,
    inputs=[
        # "prompt" is the textbox's initial value, not a placeholder.
        gr.Textbox(value="prompt", label="Enter Prompt"),
        gr.Dropdown(
            choices=[
                "No Loss",
                "Blue Channel",
                "Saturation",
                "Elastic Deformation",
            ],
            label="Choose Augmentation",
        ),
        gr.Button(value="Loss Function"),
    ],
    outputs=gr.Image(type="pil"),
)
if __name__ == "__main__":
    # Start the Gradio app only when this file is run as a script, not when it
    # is imported. The keyword arguments are passed to `launch` as written.
    iface.launch(show_api=False, share=True)