Rahul Dubey commited on
Commit
463297f
·
1 Parent(s): 817c993

new file: actionFunctions.py

Browse files

new file: app.py
new file: components.py
new file: models/__init__.py
new file: models/base.py
new file: models/realistic_vision_v6b1.py
new file: models/sdxl.py
new file: requirements.txt

actionFunctions.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.base import Model
2
+ import gradio as gr
3
+
4
+ model = None
5
+ predict = None
6
+
7
+ def load_model(model_name):
8
+
9
+ global model
10
+ model = Model(model_name)
11
+ global predict
12
+ predict = model.predict
13
+
14
+ return model.getModelState()
15
+
16
+ def fn_gen_text_to_text(x):
17
+ text = predict
18
+ return text
19
+
20
+ def fn_gen_text_to_image(x):
21
+ image = predict(x)
22
+ return image
23
+
24
+ def fn_gen_image_to_text(x):
25
+ text = predict(x)
26
+ return text
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ from components import *
4
+ from actionFunctions import *
5
+
6
+
7
+ def create_interface():
8
+
9
+ with gr.Blocks(
10
+ theme=gr.themes.Monochrome(),
11
+ css="footer{display:none !important};"
12
+ ) as demo:
13
+
14
+ gr.Markdown("Welcome to tangibleAI")
15
+
16
+
17
+ #TEXT-TO-TEXT GENERATOR INTERFACE
18
+ ttt_interface()
19
+
20
+ #TEXT-TO-IMAGE GENERATOR INTERFACE
21
+ tti_interface()
22
+
23
+ #IMAGE-TO-TEXT GENERATOR INTERFACE
24
+ itt_interface()
25
+
26
+ return demo
27
+
28
+ if __name__ == '__main__':
29
+ demo = create_interface()
30
+ demo.launch()
components.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from actionFunctions import *
3
+
4
+ def ttt_interface():
5
+
6
+ with gr.Tab("Text-to-Text"):
7
+
8
+ with gr.Row():
9
+ ttt_text_input = gr.TextArea(label="Input Text")
10
+ ttt_text_output = gr.TextArea(label="Generated Text")
11
+
12
+ with gr.Column():
13
+ ttt_models = gr.Dropdown(["Model 1","Model 2", "Model 3"], label='Select a Model')
14
+ ttt_models.change(load_model, ttt_models, None)
15
+
16
+ with gr.Row():
17
+ text_to_text_clear_button = gr.Button("Clear")
18
+ text_to_text_gen_button = gr.Button("Generate")
19
+
20
+ text_to_text_clear_button.click(fn_gen_text_to_text, None, ttt_text_output)
21
+ text_to_text_gen_button.click(fn_gen_text_to_text, ttt_text_input, ttt_text_output)
22
+
23
+ def tti_interface():
24
+ with gr.Tab("Text-to-Image"):
25
+
26
+ with gr.Column():
27
+ tti_image_output = gr.Image(label="Generated Image")
28
+ tti_text_input = gr.Textbox(label="Input Text")
29
+
30
+ with gr.Column():
31
+ tti_models = gr.Dropdown(["SDXL","REALV6B1"], label='Select a Model')
32
+ output = gr.Textbox(label='Progress Output')
33
+ tti_models.change(load_model, tti_models, output)
34
+
35
+ text_to_image_gen_button = gr.Button("Generate")
36
+ text_to_image_gen_button.click(fn_gen_text_to_image, tti_text_input, tti_image_output)
37
+
38
+ def itt_interface():
39
+ with gr.Tab("Image-to-text"):
40
+
41
+ with gr.Row():
42
+ itt_image_input = gr.Image(label="Input Image")
43
+ itt_text_output = gr.Textbox(label="Generated Text")
44
+
45
+ with gr.Column():
46
+ itt_models = gr.Dropdown(["Model 1","Model 2", "Model 3"], label='Select a Model')
47
+ itt_models.change(load_model, itt_models, None)
48
+
49
+ image_to_text_gen_button = gr.Button("Generate")
50
+ image_to_text_gen_button.click(fn_gen_image_to_text, itt_image_input, itt_text_output)
models/__init__.py ADDED
File without changes
models/base.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .realistic_vision_v6b1 import RealV6B1
2
+ from .sdxl import SDXL
3
+
4
+ class Model:
5
+
6
+ def __init__(self, modelName):
7
+ self.modelName = modelName
8
+
9
+ if self.modelName == 'SDXL':
10
+ self.modelObj = SDXL()
11
+ self.model = self.modelObj.load_model()
12
+
13
+ elif self.modelName == 'REALV6B1':
14
+ self.modelObj = RealV6B1()
15
+ self.model = self.modelObj.load_model()
16
+
17
+ else:
18
+ self.modelObj = None
19
+ self.model = None
20
+
21
+ def getModelState(self):
22
+ if self.modelObj is None:
23
+ return "Model Not Loaded"
24
+
25
+ else:
26
+ return "Model Loaded"
27
+
28
+ def predict(self, prompt):
29
+ return self.modelObj.predict(self.model, prompt)
models/realistic_vision_v6b1.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline
2
+
3
+ class RealV6B1:
4
+
5
+ def __init__(self):
6
+ pass
7
+
8
+ def load_model(self):
9
+ pipeline = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE",)
10
+ pipeline.safety_checker = None
11
+
12
+ return pipeline
13
+
14
+ def predict(self, pipeline, prompt):
15
+ images = pipeline(prompt=prompt,).images[0]
16
+
17
+ return images
models/sdxl.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline
2
+
3
+
4
+ class SDXL:
5
+
6
+ def __init__(self):
7
+ pass
8
+
9
+ def load_model(self):
10
+ pipeline = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE",)
11
+ pipeline.safety_checker = None
12
+
13
+ return pipeline
14
+
15
+ def predict(self, pipeline, prompt):
16
+ images = pipeline(prompt=prompt,).images[0]
17
+
18
+ return images
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers
4
+ safetensors
5
+ diffusers
6
+ accelerate