bala1802's picture
Update app.py
13227b3 verified
raw
history blame contribute delete
No virus
1.36 kB
import torch
import gradio as gr
import prediction
import model
import diffusion_loss
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The diffusion pipeline is built once at import time and reused by every
# call to ``generate`` below.
# NOTE(review): ``device`` is not passed to anything in this file; presumably
# ``model.initialize_diffusion_model`` handles device placement itself —
# confirm, or pass ``device`` explicitly.
pipe = model.initialize_diffusion_model()
def generate(prompt, loss_function=None):
    """Run the shared diffusion pipeline on *prompt*.

    Args:
        prompt: Text prompt forwarded to ``prediction.predict``.
        loss_function: Optional callable forwarded unchanged as the
            ``loss_function`` argument; ``None`` means no extra loss.

    Returns:
        Whatever ``prediction.predict`` returns (presumably a PIL image,
        given the ``gr.Image(type="pil")`` output — TODO confirm).
    """
    predict_kwargs = {
        "prompt": prompt,
        "pipe": pipe,
        "loss_function": loss_function,
    }
    return prediction.predict(**predict_kwargs)
def process_input(prompt, loss_function, button):
    """Generate an image for *prompt* using the loss chosen in the dropdown.

    Args:
        prompt: Text prompt forwarded to ``generate``.
        loss_function: Dropdown label selecting the loss. ``None``,
            ``"No Loss"`` and any unrecognised label all run with no loss.
        button: Value supplied by the button input; when falsy, nothing is
            generated.

    Returns:
        The result of ``generate`` for the selected loss, or ``None`` when
        ``button`` is falsy.
    """
    # Guard first so the falsy-button path touches nothing else.
    if not button:
        return None

    # Label -> loss callable. Labels must match the dropdown choices.
    losses = {
        "Blue Channel": diffusion_loss.blue_channel,
        "Saturation": diffusion_loss.saturation,
        "Elastic Deformation": diffusion_loss.elastic_transform,
    }
    # "No Loss", None and unknown labels all map to "no loss" via .get().
    return generate(prompt, loss_function=losses.get(loss_function))
# Gradio front end. The three inputs map positionally onto
# ``process_input(prompt, loss_function, button)``; the output is a PIL image.
iface = gr.Interface(
    fn=process_input,
    inputs=[
        # NOTE(review): ``value="prompt"`` pre-fills the box with the literal
        # word "prompt"; a ``placeholder`` may have been intended — confirm.
        gr.Textbox(value="prompt", label="Enter Prompt"),
        gr.Dropdown(
            choices=["No Loss", "Blue Channel", "Saturation", "Elastic Deformation"],
            label="Choose Augmentation",
        ),
        # NOTE(review): as an Interface input the button presumably yields
        # its label, which is always truthy — confirm it is actually needed.
        gr.Button(value="Loss Function"),
    ],
    outputs=gr.Image(type="pil"),
)
if __name__ == "__main__":
    # Start the web UI only when executed directly, not when imported.
    iface.launch(
        share=True,
        show_api=False,
    )