tanthinhdt committed
Commit 3a61959
1 Parent(s): f5403ef

Update app.py

Files changed (1)
  1. app.py +190 -190
app.py CHANGED
@@ -1,190 +1,190 @@
- import torch
- import urllib.request
- import streamlit as st
- from io import BytesIO
- from time import time
- from PIL import Image
- from transformers import AutoModelForVision2Seq, AutoProcessor
-
-
- def scale_image(image: Image.Image, target_height: int = 500) -> Image.Image:
-     """
-     Scale an image to a target height while maintaining the aspect ratio.
-
-     Parameters
-     ----------
-     image : Image.Image
-         The image to scale.
-     target_height : int, optional (default=500)
-         The target height of the image.
-
-     Returns
-     -------
-     Image.Image
-         The scaled image.
-     """
-     width, height = image.size
-     aspect_ratio = width / height
-     target_width = int(aspect_ratio * target_height)
-     return image.resize((target_width, target_height))
-
-
- def upload_image() -> None:
-     """
-     Load the uploaded image file into the session state.
-     """
-     if st.session_state.file_uploader is not None:
-         st.session_state.image = Image.open(st.session_state.file_uploader)
-
-
- def read_image_from_url() -> None:
-     """
-     Read an image from a URL into the session state.
-     """
-     if st.session_state.image_url:
-         with urllib.request.urlopen(st.session_state.image_url) as response:
-             st.session_state.image = Image.open(BytesIO(response.read()))
-
-
- def inference() -> None:
-     """
-     Perform inference on an image and generate a caption.
-     """
-     start_time = time()
-     inputs = st.session_state.processor(
-         images=st.session_state.image,
-         return_tensors="pt",
-     )
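-     # Move the processed inputs and the model to the selected device.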
-     inputs = {k: v.to(st.session_state.device.lower()) for k, v in inputs.items()}
-     st.session_state.model.to(st.session_state.device.lower())
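-     # Generate caption token IDs with beam search, then decode them to text.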
-     output_ids = st.session_state.model.generate(
-         **inputs,
-         max_length=st.session_state.max_length,
-         num_beams=st.session_state.num_beams,
-     )
-     caption = st.session_state.processor.decode(
-         output_ids[0], skip_special_tokens=True
-     )
-     end_time = time()
-
-     st.session_state.inference_time = round(end_time - start_time, 2)
-     st.session_state.caption = caption
-
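-     # Move the model back to CPU and release any cached GPU memory.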
-     st.session_state.model.to("cpu")
-     torch.cuda.empty_cache()
-
-
- def main() -> None:
-     """
-     Main function for the Streamlit app.
-     """
-     if "model" not in st.session_state:
-         st.session_state.model = AutoModelForVision2Seq.from_pretrained(
-             "tanthinhdt/blip-base_with-pretrained_flickr30k",
-             cache_dir="models/huggingface",
-         )
-         st.session_state.model.eval()
-     if "processor" not in st.session_state:
-         st.session_state.processor = AutoProcessor.from_pretrained(
-             "Salesforce/blip-image-captioning-base",
-             cache_dir="models/huggingface",
-         )
-     if "image" not in st.session_state:
-         st.session_state.image = None
-     if "caption" not in st.session_state:
-         st.session_state.caption = None
-     if "inference_time" not in st.session_state:
-         st.session_state.inference_time = 0.0
-
-     # Set page configuration
-     st.set_page_config(
-         page_title="Image Captioning App",
-         page_icon="📸",
-         initial_sidebar_state="expanded",
-     )
-
-     # Set sidebar layout
-     st.sidebar.header("Workspace")
-     st.sidebar.file_uploader(
-         "Upload an image",
-         type=["jpg", "jpeg", "png"],
-         accept_multiple_files=False,
-         on_change=upload_image,
-         key="file_uploader",
-         help="Upload an image to generate a caption.",
-     )
-     st.sidebar.text_input(
-         "Image URL",
-         on_change=read_image_from_url,
-         key="image_url",
-         help="Enter the URL of an image to generate a caption.",
-     )
-     st.sidebar.divider()
-     st.sidebar.header("Settings")
-     st.sidebar.selectbox(
-         label="Device",
-         options=["CPU", "CUDA"],
-         index=["CPU", "CUDA"].index(st.session_state.device) if "device" in st.session_state else 1,
-         key="device",
-         help="The device to use for inference.",
-     )
-     st.sidebar.number_input(
-         label="Max length",
-         min_value=32,
-         max_value=128,
-         value=64,
-         step=1,
-         key="max_length",
-         help="The maximum length of the generated caption.",
-     )
-     st.sidebar.number_input(
-         label="Number of beams",
-         min_value=1,
-         max_value=10,
-         value=4,
-         step=1,
-         key="num_beams",
-         help="The number of beams to use during decoding.",
-     )
-
-     # Set main layout
-     st.markdown(
-         """
-         <h1 style='text-align: center;'>
-             Image Captioning
-         </h1>
-         """,
-         unsafe_allow_html=True,
-     )
-     st.divider()
-     image_container = st.container(height=450)
-     st.divider()
-     col_1, col_2, col_3 = st.columns([1, 1, 2])
-     resolution_display = col_1.empty()
-     runtime_display = col_2.empty()
-     caption_display = col_3.empty()
-
-     # Display the image and generate a caption
-     if st.session_state.image is not None:
-         image_container.image(scale_image(st.session_state.image, target_height=400))
-
-         resolution_display.metric(
-             label="Image Resolution",
-             value=f"{st.session_state.image.width}x{st.session_state.image.height}",
-         )
-
-         with st.spinner("Generating caption..."):
-             inference()
-
-         caption_display.text_area(
-             label="Caption",
-             value=st.session_state.caption,
-         )
-         runtime_display.metric(
-             label="Inference Time",
-             value=f"{st.session_state.inference_time}s",
-         )
-
-
- if __name__ == "__main__":
-     main()
 
+ import torch
+ import urllib.request
+ import streamlit as st
+ from io import BytesIO
+ from time import time
+ from PIL import Image
+ from transformers import AutoModelForVision2Seq, AutoProcessor
+
+
+ def scale_image(image: Image.Image, target_height: int = 500) -> Image.Image:
+     """
+     Scale an image to a target height while maintaining the aspect ratio.
+
+     Parameters
+     ----------
+     image : Image.Image
+         The image to scale.
+     target_height : int, optional (default=500)
+         The target height of the image.
+
+     Returns
+     -------
+     Image.Image
+         The scaled image.
+     """
+     width, height = image.size
+     aspect_ratio = width / height
+     target_width = int(aspect_ratio * target_height)
+     return image.resize((target_width, target_height))
+
+
+ def upload_image() -> None:
+     """
+     Load the uploaded image file into the session state.
+     """
+     if st.session_state.file_uploader is not None:
+         st.session_state.image = Image.open(st.session_state.file_uploader)
+
+
+ def read_image_from_url() -> None:
+     """
+     Read an image from a URL into the session state.
+     """
+     if st.session_state.image_url:
+         with urllib.request.urlopen(st.session_state.image_url) as response:
+             st.session_state.image = Image.open(BytesIO(response.read()))
+
+
+ def inference() -> None:
+     """
+     Perform inference on an image and generate a caption.
+     """
+     start_time = time()
+     inputs = st.session_state.processor(
+         images=st.session_state.image,
+         return_tensors="pt",
+     )
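+     # Move the processed inputs and the model to the selected device.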
+     inputs = {k: v.to(st.session_state.device.lower()) for k, v in inputs.items()}
+     st.session_state.model.to(st.session_state.device.lower())
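+     # Generate caption token IDs with beam search, then decode them to text.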
+     output_ids = st.session_state.model.generate(
+         **inputs,
+         max_length=st.session_state.max_length,
+         num_beams=st.session_state.num_beams,
+     )
+     caption = st.session_state.processor.decode(
+         output_ids[0], skip_special_tokens=True
+     )
+     end_time = time()
+
+     st.session_state.inference_time = round(end_time - start_time, 2)
+     st.session_state.caption = caption
+
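+     # Move the model back to CPU and release any cached GPU memory.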
+     st.session_state.model.to("cpu")
+     torch.cuda.empty_cache()
+
+
+ def main() -> None:
+     """
+     Main function for the Streamlit app.
+     """
+     if "model" not in st.session_state:
+         st.session_state.model = AutoModelForVision2Seq.from_pretrained(
+             "tanthinhdt/blip-base_with-pretrained_flickr30k",
+             cache_dir="models/huggingface",
+         )
+         st.session_state.model.eval()
+     if "processor" not in st.session_state:
+         st.session_state.processor = AutoProcessor.from_pretrained(
+             "Salesforce/blip-image-captioning-base",
+             cache_dir="models/huggingface",
+         )
+     if "image" not in st.session_state:
+         st.session_state.image = None
+     if "caption" not in st.session_state:
+         st.session_state.caption = None
+     if "inference_time" not in st.session_state:
+         st.session_state.inference_time = 0.0
+
+     # Set page configuration
+     st.set_page_config(
+         page_title="Image Captioning App",
+         page_icon="📸",
+         initial_sidebar_state="expanded",
+     )
+
+     # Set sidebar layout
+     st.sidebar.header("Workspace")
+     st.sidebar.file_uploader(
+         "Upload an image",
+         type=["jpg", "jpeg", "png"],
+         accept_multiple_files=False,
+         on_change=upload_image,
+         key="file_uploader",
+         help="Upload an image to generate a caption.",
+     )
+     st.sidebar.text_input(
+         "Image URL",
+         on_change=read_image_from_url,
+         key="image_url",
+         help="Enter the URL of an image to generate a caption.",
+     )
+     st.sidebar.divider()
+     st.sidebar.header("Settings")
+     st.sidebar.selectbox(
+         label="Device",
+         options=["CPU", "CUDA"],
+         index=1 if torch.cuda.is_available() else 0,
+         key="device",
+         help="The device to use for inference.",
+     )
+     st.sidebar.number_input(
+         label="Max length",
+         min_value=32,
+         max_value=128,
+         value=64,
+         step=1,
+         key="max_length",
+         help="The maximum length of the generated caption.",
+     )
+     st.sidebar.number_input(
+         label="Number of beams",
+         min_value=1,
+         max_value=10,
+         value=4,
+         step=1,
+         key="num_beams",
+         help="The number of beams to use during decoding.",
+     )
+
+     # Set main layout
+     st.markdown(
+         """
+         <h1 style='text-align: center;'>
+             Image Captioning
+         </h1>
+         """,
+         unsafe_allow_html=True,
+     )
+     st.divider()
+     image_container = st.container(height=450)
+     st.divider()
+     col_1, col_2, col_3 = st.columns([1, 1, 2])
+     resolution_display = col_1.empty()
+     runtime_display = col_2.empty()
+     caption_display = col_3.empty()
+
+     # Display the image and generate a caption
+     if st.session_state.image is not None:
+         image_container.image(scale_image(st.session_state.image, target_height=400))
+
+         resolution_display.metric(
+             label="Image Resolution",
+             value=f"{st.session_state.image.width}x{st.session_state.image.height}",
+         )
+
+         with st.spinner("Generating caption..."):
+             inference()
+
+         caption_display.text_area(
+             label="Caption",
+             value=st.session_state.caption,
+         )
+         runtime_display.metric(
+             label="Inference Time",
+             value=f"{st.session_state.inference_time}s",
+         )
+
+
+ if __name__ == "__main__":
+     main()