Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,8 @@ from lavis.models import load_model_and_preprocess
|
|
11 |
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
|
12 |
import gradio as gr
|
13 |
import torch, gc
|
|
|
|
|
14 |
|
15 |
def prepare_data(image, question):
|
16 |
gc.collect()
|
@@ -20,6 +22,20 @@ def prepare_data(image, question):
|
|
20 |
samples = {"image": image, "text_input": [question]}
|
21 |
return samples
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def gradcam_attention(image, question):
|
24 |
dst_w = 720
|
25 |
samples = prepare_data(image, question)
|
@@ -36,11 +52,11 @@ def gradcam_attention(image, question):
|
|
36 |
return (avg_gradcam * 255).astype(np.uint8)
|
37 |
|
38 |
def generate_cap(image, question, cap_number):
|
|
|
39 |
samples = prepare_data(image, question)
|
40 |
samples = model.forward_itm(samples=samples)
|
41 |
samples = model.forward_cap(samples=samples, num_captions=cap_number, num_patches=5)
|
42 |
-
|
43 |
-
return pd.DataFrame({'Caption': samples['captions'][0][:cap_number]})
|
44 |
|
45 |
def postprocess(text):
|
46 |
for i, ans in enumerate(text):
|
@@ -51,6 +67,7 @@ def postprocess(text):
|
|
51 |
return ans
|
52 |
|
53 |
def generate_answer(image, question):
|
|
|
54 |
samples = prepare_data(image, question)
|
55 |
samples = model.forward_itm(samples=samples)
|
56 |
samples = model.forward_cap(samples=samples, num_captions=5, num_patches=5)
|
@@ -67,7 +84,7 @@ def generate_answer(image, question):
|
|
67 |
pred_answer = tokenizer.batch_decode(outputs.sequences[:, len(Img2Prompt_input.input_ids[0]):])
|
68 |
pred_answer = postprocess(pred_answer)
|
69 |
print(pred_answer, type(pred_answer))
|
70 |
-
return pred_answer
|
71 |
|
72 |
# setup device to use
|
73 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -95,6 +112,7 @@ text_output = gr.Textbox(label="Output Answer")
|
|
95 |
demo = gr.Blocks(title=title)
|
96 |
demo.encrypt = False
|
97 |
cap_df = gr.DataFrame(value=df_init, label="Caption dataframe", row_count=(0, "dynamic"), max_rows = 20, wrap=True, overflow_row_behaviour='paginate')
|
|
|
98 |
|
99 |
with demo:
|
100 |
with gr.Row():
|
@@ -124,10 +142,10 @@ with demo:
|
|
124 |
with gr.Row():
|
125 |
with gr.Column():
|
126 |
cap_btn = gr.Button("Generate caption")
|
127 |
-
cap_btn.click(generate_cap, [raw_image, question, number_cap], [cap_df])
|
128 |
with gr.Column():
|
129 |
anws_btn = gr.Button("Answer")
|
130 |
-
anws_btn.click(generate_answer, [raw_image, question], outputs=text_output)
|
131 |
with gr.Row():
|
132 |
with gr.Column():
|
133 |
# gradcam_btn = gr.Button("Generate Gradcam")
|
@@ -135,5 +153,6 @@ with demo:
|
|
135 |
cap_df.render()
|
136 |
with gr.Column():
|
137 |
text_output.render()
|
|
|
138 |
|
139 |
demo.launch(debug=True)
|
|
|
11 |
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
|
12 |
import gradio as gr
|
13 |
import torch, gc
|
14 |
+
from gpuinfo import GPUInfo
|
15 |
+
import time
|
16 |
|
17 |
def prepare_data(image, question):
|
18 |
gc.collect()
|
|
|
22 |
samples = {"image": image, "text_input": [question]}
|
23 |
return samples
|
24 |
|
25 |
+
def running_inf(time_start):
|
26 |
+
time_end = time.time()
|
27 |
+
time_diff = time_end - time_start
|
28 |
+
memory = psutil.virtual_memory()
|
29 |
+
gpu_utilization, gpu_memory = GPUInfo.gpu_usage()
|
30 |
+
gpu_utilization = gpu_utilization[0] if len(gpu_utilization) > 0 else 0
|
31 |
+
gpu_memory = gpu_memory[0] if len(gpu_memory) > 0 else 0
|
32 |
+
system_info = f"""
|
33 |
+
*Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB.*
|
34 |
+
*Processing time: {time_diff:.5} seconds.*
|
35 |
+
*GPU Utilization: {gpu_utilization}%, GPU Memory: {gpu_memory}MiB.*
|
36 |
+
"""
|
37 |
+
return system_info
|
38 |
+
|
39 |
def gradcam_attention(image, question):
|
40 |
dst_w = 720
|
41 |
samples = prepare_data(image, question)
|
|
|
52 |
return (avg_gradcam * 255).astype(np.uint8)
|
53 |
|
54 |
def generate_cap(image, question, cap_number):
|
55 |
+
time_start = time.time()
|
56 |
samples = prepare_data(image, question)
|
57 |
samples = model.forward_itm(samples=samples)
|
58 |
samples = model.forward_cap(samples=samples, num_captions=cap_number, num_patches=5)
|
59 |
+
return pd.DataFrame({'Caption': samples['captions'][0][:cap_number]}), running_inf(time_start)
|
|
|
60 |
|
61 |
def postprocess(text):
|
62 |
for i, ans in enumerate(text):
|
|
|
67 |
return ans
|
68 |
|
69 |
def generate_answer(image, question):
|
70 |
+
time_start = time.time()
|
71 |
samples = prepare_data(image, question)
|
72 |
samples = model.forward_itm(samples=samples)
|
73 |
samples = model.forward_cap(samples=samples, num_captions=5, num_patches=5)
|
|
|
84 |
pred_answer = tokenizer.batch_decode(outputs.sequences[:, len(Img2Prompt_input.input_ids[0]):])
|
85 |
pred_answer = postprocess(pred_answer)
|
86 |
print(pred_answer, type(pred_answer))
|
87 |
+
return pred_answer, running_inf(time_start)
|
88 |
|
89 |
# setup device to use
|
90 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
112 |
demo = gr.Blocks(title=title)
|
113 |
demo.encrypt = False
|
114 |
cap_df = gr.DataFrame(value=df_init, label="Caption dataframe", row_count=(0, "dynamic"), max_rows = 20, wrap=True, overflow_row_behaviour='paginate')
|
115 |
+
system_info = gr.Markdown(f"*Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB*")
|
116 |
|
117 |
with demo:
|
118 |
with gr.Row():
|
|
|
142 |
with gr.Row():
|
143 |
with gr.Column():
|
144 |
cap_btn = gr.Button("Generate caption")
|
145 |
+
cap_btn.click(generate_cap, [raw_image, question, number_cap], [cap_df, system_info])
|
146 |
with gr.Column():
|
147 |
anws_btn = gr.Button("Answer")
|
148 |
+
anws_btn.click(generate_answer, [raw_image, question], outputs=[text_output, system_info])
|
149 |
with gr.Row():
|
150 |
with gr.Column():
|
151 |
# gradcam_btn = gr.Button("Generate Gradcam")
|
|
|
153 |
cap_df.render()
|
154 |
with gr.Column():
|
155 |
text_output.render()
|
156 |
+
system_info.render()
|
157 |
|
158 |
demo.launch(debug=True)
|