File size: 4,480 Bytes
125214f 1812270 eff41fa 125214f e00ca5f 125214f 1948116 d4b85b8 1948116 d4b85b8 1948116 7c8c861 1948116 d4b85b8 7c8c861 1948116 7c8c861 1948116 d37daa7 1948116 7c8c861 1948116 7c8c861 1948116 7c8c861 1948116 7c8c861 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import streamlit as st
import torch
import bitsandbytes
import accelerate
import scipy
import copy
from PIL import Image
import torch.nn as nn
import pandas as pd
from my_model.object_detection import detect_and_draw_objects
from my_model.captioner.image_captioning import get_caption
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.state_manager import StateManager
# Module-level StateManager instance, created when this module is imported.
# NOTE(review): not referenced by the visible code in this file — confirm whether
# other modules import it before removing it.
state_manager = StateManager()
class InferenceRunner(StateManager):
    """Streamlit page that loads the model and runs question answering on images.

    Relies on the ``StateManager`` base class (defined elsewhere) for the layout
    columns (``col1``/``col2``), model loading/reloading and per-image session
    state (``process_new_image``, ``get_images_data``, ``update_image_data``,
    ``add_to_qa_history``, ``analyze_image``).
    """

    def __init__(self):
        """Initialise the base state manager and the list of bundled sample images."""
        super().__init__()
        # Relative paths: resolved against the process working directory.
        self.sample_images = [
            "Files/sample1.jpg", "Files/sample2.jpg", "Files/sample3.jpg",
            "Files/sample4.jpg", "Files/sample5.jpg", "Files/sample6.jpg",
            "Files/sample7.jpg",
        ]

    def answer_question(self, caption, detected_objects_str, question, model):
        """Answer ``question`` about an image from its caption and detected objects.

        Args:
            caption: Caption produced for the image during analysis.
            detected_objects_str: The image's detected objects, as a string.
            question: The user's question.
            model: Object exposing ``generate_answer(question, caption, detected_objects_str)``.

        Returns:
            Whatever ``model.generate_answer`` returns.
        """
        free_gpu_resources()
        try:
            return model.generate_answer(question, caption, detected_objects_str)
        finally:
            # Run the cleanup even when generation raises, not only on success.
            free_gpu_resources()

    def image_qa_app(self, kbvqa):
        """Render image selection/upload and, per image, analysis and Q&A controls.

        Args:
            kbvqa: The loaded model; forwarded to image processing/analysis and
                used to answer questions.
        """
        self._render_sample_images(kbvqa)
        self._render_uploader(kbvqa)
        for image_key, image_data in self.get_images_data().items():
            self._render_image_panel(image_key, image_data, kbvqa)

    def _render_sample_images(self, kbvqa):
        """Show bundled sample images as thumbnails, each with a 'Select' button."""
        if not self.sample_images:
            # st.columns() rejects a column count of 0.
            return
        self.col1.write("Choose from sample images:")
        cols = self.col1.columns(len(self.sample_images))
        for idx, (col, sample_image_path) in enumerate(zip(cols, self.sample_images)):
            with col:
                image = Image.open(sample_image_path)
                st.image(image, use_column_width=True)
                if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
                    self.process_new_image(sample_image_path, image, kbvqa)

    def _render_uploader(self, kbvqa):
        """Let the user upload their own image and register it for analysis."""
        uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
        if uploaded_image is not None:
            self.process_new_image(uploaded_image.name, Image.open(uploaded_image), kbvqa)

    def _render_image_panel(self, image_key, image_data, kbvqa):
        """Show one image with its 'Analyze Image' / question controls and Q&A history."""
        # Labels show only the last 11 characters of the image key.
        short_key = image_key[-11:]
        self.col2.image(image_data['image'], caption=f'Uploaded Image: {short_key}', use_column_width=True)

        if not image_data['analysis_done']:
            self.col2.text("Cool image, please click 'Analyze Image'..")
            if self.col2.button('Analyze Image', key=f'analyze_{image_key}'):
                # The third value (image with drawn boxes) is not used on this page.
                caption, detected_objects_str, _ = self.analyze_image(image_data['image'], kbvqa)
                self.update_image_data(image_key, caption, detected_objects_str, True)

        # NOTE(review): image_data was read before any update in this run; whether a
        # freshly added answer shows immediately depends on add_to_qa_history
        # mutating this same list — confirm in StateManager.
        qa_history = image_data.get('qa_history', [])

        if image_data['analysis_done']:
            question = self.col2.text_input(f"Ask a question about this image ({short_key}):", key=f'question_{image_key}')
            if self.col2.button('Get Answer', key=f'answer_{image_key}'):
                if not question or not question.strip():
                    # Don't spend a model call on an empty question.
                    self.col2.warning("Please type a question before clicking 'Get Answer'.")
                elif all(asked != question for asked, _ in qa_history):
                    answer = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question, kbvqa)
                    self.add_to_qa_history(image_key, question, answer)

        # Render history in the same column as its image (not the main container).
        for q, a in qa_history:
            self.col2.text(f"Q: {q}\nA: {a}\n")

    def run_inference(self):
        """Render the inference page: settings, model (re)load button and the Q&A app."""
        st.title("Run Inference")
        self.initialize_state()
        self.set_up_widgets()

        # Evaluate once so the warning and the button label use the same value.
        # NOTE(review): previously the label read self.settings_changed; assumed to
        # mean the same thing as has_state_changed() — confirm in StateManager.
        settings_changed = self.has_state_changed()
        st.session_state['settings_changed'] = settings_changed
        if settings_changed:
            self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")

        st.session_state.button_label = "Reload Model" if self.is_model_loaded() and settings_changed else "Load Model"

        if st.session_state.method != "Fine-Tuned Model":
            self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
            return

        if self.col1.button(st.session_state.button_label):
            if st.session_state.button_label != "Load Model":
                self.reload_detection_model()
            elif self.is_model_loaded():
                self.col1.text("Model already loaded and no settings were changed:)")
            else:
                self.load_model()

        if self.is_model_loaded() and st.session_state.kbvqa.all_models_loaded:
            self.image_qa_app(self.get_model())