Lwasinam commited on
Commit
98fb2b8
1 Parent(s): 1a2538f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -157,24 +157,27 @@ def start():
157
  def main():
158
  st.title("Image Captioning with Transformer Models")
159
  image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
 
160
  if image is not None:
161
-
162
  # image_bytes = uploaded_file.getvalue()
163
  # image = image_base64(image_bytes)
164
  # image = get_image(uploaded_file)
165
-
166
- accelerator = Accelerator()
167
- device = accelerator.device
168
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
169
- config = get_config()
170
- tokenizer = get_or_build_tokenizer(config)
171
- model = get_model(config, len(tokenizer))
172
- model = accelerator.prepare(model)
173
- accelerator.load_state('models/')
174
- # model = get_model(config, len(tokenizer))
175
- # model.to(device)
176
-
177
- text_output = process(model, image, tokenizer, device)
 
 
178
  st.write(text_output)
179
 
180
  if __name__ == "__main__":
 
157
  def main():
158
  st.title("Image Captioning with Transformer Models")
159
  image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
160
+
161
  if image is not None:
162
+ st.image(image, use_column_width=True)
163
  # image_bytes = uploaded_file.getvalue()
164
  # image = image_base64(image_bytes)
165
  # image = get_image(uploaded_file)
166
+ with st.empty():
167
+ st.write("Processing the image... Please wait.")
168
+ accelerator = Accelerator()
169
+ device = accelerator.device
170
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
171
+ config = get_config()
172
+ tokenizer = get_or_build_tokenizer(config)
173
+ model = get_model(config, len(tokenizer))
174
+ model = accelerator.prepare(model)
175
+ accelerator.load_state('models/')
176
+ # model = get_model(config, len(tokenizer))
177
+ # model.to(device)
178
+
179
+
180
+ text_output = process(model, image, tokenizer, device)
181
  st.write(text_output)
182
 
183
  if __name__ == "__main__":