"""Gradio front-end for prompt-to-image generation with optional loss guidance.

The user enters a prompt and picks an augmentation from a dropdown. The
selected option is mapped to a loss callable from ``diffusion_loss`` (or to
``None``) and passed to ``prediction.predict`` together with the shared
diffusion pipeline.
"""

import torch
import gradio as gr

import prediction
import model
import diffusion_loss

# NOTE(review): ``device`` is not used anywhere in this module; kept in case
# other modules import it — confirm before removing.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The diffusion pipeline is built once at import time and reused for every request.
pipe = model.initialize_diffusion_model()

# Dropdown label -> name of the loss callable in ``diffusion_loss``.
# ``None`` means generation runs without a loss function. This table is the
# single source of truth for both the dropdown choices and the dispatch, and
# attributes are resolved only when the option is actually selected.
LOSS_FUNCTION_ATTRS = {
    "No Loss": None,
    "Blue Channel": "blue_channel",
    "Saturation": "saturation",
    "Elastic Deformation": "elastic_transform",
}


def generate(prompt, loss_function=None):
    """Generate an image for *prompt* with the shared diffusion pipeline.

    Args:
        prompt: Text prompt forwarded to ``prediction.predict``.
        loss_function: Optional callable forwarded to ``prediction.predict``;
            ``None`` disables loss guidance.

    Returns:
        Whatever ``prediction.predict`` returns. The UI shows it in a
        ``gr.Image(type="pil")`` output, so presumably a PIL image — confirm.
    """
    return prediction.predict(prompt=prompt, pipe=pipe, loss_function=loss_function)


def process_input(prompt, loss_function, button):
    """Gradio callback: map the dropdown choice to a loss and generate.

    Args:
        prompt: Text from the prompt textbox.
        loss_function: Selected dropdown label, or ``None`` when nothing is
            selected. Labels not in ``LOSS_FUNCTION_ATTRS`` fall back to no loss.
        button: Value supplied by the ``gr.Button`` input; a falsy value
            skips generation.

    Returns:
        The result of :func:`generate`, or ``None`` when *button* is falsy.
    """
    # NOTE(review): a gr.Button used as an Interface input appears to pass its
    # label ("Loss Function"), which is always truthy — confirm for the
    # installed Gradio version; if so, this guard never triggers.
    if not button:
        return None

    attr_name = LOSS_FUNCTION_ATTRS.get(loss_function)
    loss_fn = None if attr_name is None else getattr(diffusion_loss, attr_name)
    return generate(prompt, loss_function=loss_fn)


iface = gr.Interface(
    fn=process_input,
    inputs=[
        # ``placeholder`` shows hint text without pre-filling the box; passing
        # "prompt" positionally made it the initial *value*, so submitting
        # right away generated an image of the word "prompt".
        gr.Textbox(placeholder="prompt", label="Enter Prompt"),
        gr.Dropdown(list(LOSS_FUNCTION_ATTRS), label="Choose Augmentation"),
        gr.Button("Loss Function"),
    ],
    outputs=gr.Image(type="pil"),
)

if __name__ == "__main__":
    # share=True requests a public Gradio share link; show_api=False hides the API page.
    iface.launch(show_api=False, share=True)