bala1802 commited on
Commit
c01c8d1
1 Parent(s): fa387ad

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+
4
+ import prediction
5
+ import model
6
+ import diffusion_loss
7
+
8
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
+
10
+ pipe = model.initialize_diffusion_model()
11
+
12
+ def generate(prompt, loss_function=None):
13
+ return prediction.predict(prompt=prompt, pipe=pipe, loss_function=loss_function)
14
+
15
+ def process_input(prompt, loss_function, button):
16
+ if button:
17
+ if loss_function is None or loss_function == "No Loss":
18
+ return generate(prompt, loss_function=None)
19
+ elif loss_function == "Blue Channel":
20
+ return generate(prompt, loss_function=diffusion_loss.blue_channel)
21
+ elif loss_function == "Saturation":
22
+ return generate(prompt, loss_function=diffusion_loss.saturation)
23
+ elif loss_function == "Elastic Deformation":
24
+ return generate(prompt, loss_function=diffusion_loss.elastic_transform)
25
+ else:
26
+ return generate(prompt, loss_function=None)
27
+ else:
28
+ return None
29
+
30
+ iface = gr.Interface(
31
+ fn=process_input,
32
+ inputs=[
33
+ gr.Textbox("prompt", label="Enter Prompt"),
34
+ gr.Dropdown(["No Loss", "Blue Channel", "Saturation", 'Elastic Deformation'], label='Choose Augmentation'),
35
+ gr.Button("Loss Function")
36
+ ],
37
+
38
+ outputs = gr.Image(type="pil")
39
+ )
40
+
41
+ if __name__ == "__main__":
42
+ iface.launch(show_api=False, share=True)