palitrajarshi commited on
Commit
6d07f80
β€’
1 Parent(s): 96072ff

Update pages/Captionize.py

Browse files
Files changed (1) hide show
  1. pages/Captionize.py +12 -19
pages/Captionize.py CHANGED
@@ -5,6 +5,11 @@ import requests
5
  import streamlit as st
6
  from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
7
 
 
 
 
 
 
8
 
9
  st.set_page_config(page_title="Captionize")
10
 
@@ -26,30 +31,18 @@ div.stButton > button:hover {
26
  }
27
  </style>""", unsafe_allow_html=True)
28
 
29
-
30
- device='cpu'
31
- encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
32
- decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
33
- model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
34
- feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
35
- tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
36
- model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
37
-
38
- def predict(image,max_length=64, num_beams=4):
39
- #image = image.convert('RGB')
40
- image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
41
- image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
42
- clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
43
- caption_ids = model.generate(image, max_length = max_length)[0]
44
- caption_text = clean_text(tokenizer.decode(caption_ids))
45
- return caption_text
46
-
47
  pic = st.file_uploader(label="Please upload any Image here 😎",type=['png', 'jpeg', 'jpg'], help="Only 'png', 'jpeg' or 'jpg' formats allowed")
48
 
 
 
 
 
49
 
50
  button = st.button("Generate Caption")
 
51
 
52
  if button:
 
53
  # Get Response
54
- caption = predict(pic)
55
  st.write(caption)
 
5
  import streamlit as st
6
  from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
7
 
8
+ from PIL import Image
9
+ import requests
10
+ from langchain.indexes import VectorstoreIndexCreator
11
+ from langchain.document_loaders import ImageCaptionLoader
12
+
13
 
14
  st.set_page_config(page_title="Captionize")
15
 
 
31
  }
32
  </style>""", unsafe_allow_html=True)
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  pic = st.file_uploader(label="Please upload any Image here 😎",type=['png', 'jpeg', 'jpg'], help="Only 'png', 'jpeg' or 'jpg' formats allowed")
35
 
36
+ #Image.open(requests.get(pic, stream=True).raw).convert("RGB")
37
+ loader = ImageCaptionLoader(path_images=pic)
38
+ list_docs = loader.load()
39
+ index = VectorstoreIndexCreator().from_loaders([loader])
40
 
41
  button = st.button("Generate Caption")
42
+ query = st.text_area("Enter your query πŸ”")
43
 
44
  if button:
45
+ Image.open(requests.get(pic, stream=True).raw).convert("RGB")
46
  # Get Response
47
+ caption = index.query(query)
48
  st.write(caption)