File size: 1,357 Bytes
c01c8d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13227b3
c01c8d1
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import gradio as gr

import prediction
import model
import diffusion_loss

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced anywhere else in this file —
# confirm whether another module imports it before removing it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the diffusion pipeline once, at import time, so every request handled
# by `generate` reuses the same object.
pipe = model.initialize_diffusion_model()

def generate(prompt, loss_function=None):
    """Run ``prediction.predict`` for ``prompt`` using the shared module-level ``pipe``.

    Args:
        prompt: Text prompt forwarded unchanged.
        loss_function: Optional callable forwarded unchanged as
            ``loss_function``; ``None`` is passed through as-is.

    Returns:
        Whatever ``prediction.predict`` returns.
    """
    predict_kwargs = {"prompt": prompt, "pipe": pipe, "loss_function": loss_function}
    return prediction.predict(**predict_kwargs)

def process_input(prompt, loss_function, button):
    """Handle one Gradio request, applying the guidance loss chosen in the dropdown.

    Args:
        prompt: Text prompt forwarded unchanged to ``generate``.
        loss_function: Dropdown label selecting the loss. ``"Blue Channel"``,
            ``"Saturation"`` and ``"Elastic Deformation"`` select the matching
            ``diffusion_loss`` callable; ``None``, ``"No Loss"`` and any other
            label mean "generate without a loss function".
        button: Value supplied by the ``gr.Button`` input component.
            Generation only runs when it is truthy.

    Returns:
        The result of ``generate``, or ``None`` when ``button`` is falsy.
    """
    # NOTE(review): when a gr.Button is used as an Interface input, Gradio
    # presumably passes its label string ("Loss Function"), which is always
    # truthy — confirm whether this guard can ever short-circuit.
    if not button:
        return None

    # Dropdown label -> loss callable. Labels not listed here (including
    # None and "No Loss") resolve to None, i.e. no loss function.
    loss_functions = {
        "Blue Channel": diffusion_loss.blue_channel,
        "Saturation": diffusion_loss.saturation,
        "Elastic Deformation": diffusion_loss.elastic_transform,
    }
    return generate(prompt, loss_function=loss_functions.get(loss_function))

# Gradio UI wiring. The three inputs map positionally onto process_input's
# (prompt, loss_function, button) parameters, and the dropdown labels must
# match the labels process_input recognises.
iface = gr.Interface(
    fn=process_input,
    inputs=[
        # Textbox's first positional parameter is its initial *value*, so the
        # previous gr.Textbox("prompt", ...) pre-filled the box with the
        # literal word "prompt". Show it as a placeholder hint instead.
        gr.Textbox(placeholder="prompt", label="Enter Prompt"),
        # Start on "No Loss" so the dropdown shows an explicit selection;
        # process_input treats "No Loss" the same as no selection (None).
        gr.Dropdown(
            ["No Loss", "Blue Channel", "Saturation", "Elastic Deformation"],
            value="No Loss",
            label="Choose Augmentation",
        ),
        # NOTE(review): this Button is an *input* component whose value becomes
        # process_input's `button` argument; Interface presumably also renders
        # its own Submit button — confirm this extra button is still wanted.
        gr.Button("Loss Function"),
    ],
    outputs=gr.Image(type="pil"),
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    # NOTE(review): share=True asks Gradio for a publicly reachable share link —
    # confirm that exposing this app beyond localhost is intended.
    launch_options = dict(share=True, show_api=False)
    iface.launch(**launch_options)