m7mdal7aj commited on
Commit
9347b1e
1 Parent(s): c55f56b

Update my_model/KBVQA.py

Browse files
Files changed (1) hide show
  1. my_model/KBVQA.py +7 -4
my_model/KBVQA.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  import torch
 
3
  import os
4
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
5
  from typing import Optional
@@ -141,10 +142,12 @@ class KBVQA():
141
 
142
 
143
  def generate_answer(self, question, image):
144
- st.write('image being dected')
145
- st.image(image)
146
- caption = self.get_caption(image)
147
- image_with_boxes, detected_objects_str = self.detect_objects(image)
 
 
148
  prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
149
  num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
150
  if num_tokens > self.max_context_window:
 
1
  import streamlit as st
2
  import torch
3
+ import copy
4
  import os
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
6
  from typing import Optional
 
142
 
143
 
144
  def generate_answer(self, question, image):
145
+ img = copy.deepcopy(image)
146
+ st.write('image being detcted')
147
+ st.image(img)
148
+ caption = self.get_caption(img)
149
+ image_with_boxes, detected_objects_str = self.detect_objects(img)
150
+ st.write(detected_objects_str)
151
  prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
152
  num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
153
  if num_tokens > self.max_context_window: