File size: 4,698 Bytes
bc7d231
 
c7c92f9
 
dfda773
58e3cb5
63fc765
d5a60de
7f2e710
bc7d231
85f811b
dc81fd5
 
ca90c3f
 
63fc765
 
 
dc81fd5
 
 
ca90c3f
63fc765
 
 
ca90c3f
bc7d231
 
 
 
 
 
 
 
eedbfb7
bc7d231
 
 
8e2f248
 
fcca3a5
8e2f248
 
fcca3a5
bc7d231
 
 
 
 
 
 
 
 
 
 
4e1ae0e
 
 
 
 
 
 
 
 
 
 
 
7f2e710
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f92821
7f2e710
 
4f92821
c55a1ed
4f92821
 
 
 
c55a1ed
4f92821
 
7f2e710
 
4f92821
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import streamlit as st
import torch
import bitsandbytes
import accelerate
import scipy
from PIL import Image
import torch.nn as nn
from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
from my_model.object_detection import ObjectDetector

def load_caption_model(blip2=False, instructblip=True):
    """Load a vision-language model and its processor for image question answering.

    Args:
        blip2: Load ``Salesforce/blip2-opt-2.7b``. Only used when
            ``instructblip`` is False.
        instructblip: Load ``Salesforce/instructblip-vicuna-7b``. Takes
            precedence over ``blip2`` when both are True. Previously the
            InstructBLIP load simply overwrote an already-loaded BLIP-2 model.

    Returns:
        tuple: ``(model, processor)`` for the selected checkpoint. The model is
        loaded with ``load_in_8bit=True`` and ``torch_dtype=torch.float16``.

    Raises:
        ValueError: If neither ``blip2`` nor ``instructblip`` is True.
    """
    if instructblip:
        checkpoint = "Salesforce/instructblip-vicuna-7b"
        model_cls = InstructBlipForConditionalGeneration
        processor_cls = InstructBlipProcessor
    elif blip2:
        checkpoint = "Salesforce/blip2-opt-2.7b"
        model_cls = Blip2ForConditionalGeneration
        processor_cls = Blip2Processor
    else:
        # Without this guard the function reached `return` with nothing bound
        # and failed with UnboundLocalError.
        raise ValueError("At least one of `blip2` or `instructblip` must be True.")

    model = model_cls.from_pretrained(checkpoint, load_in_8bit=True, torch_dtype=torch.float16)
    # No nn.DataParallel wrapper: DataParallel only parallelises forward(), and
    # answer_question() calls generate() on the unwrapped module, so the
    # wrapper never did any work.
    # Processors only preprocess/tokenize; model-loading kwargs such as
    # load_in_8bit / torch_dtype do not apply to them.
    processor = processor_cls.from_pretrained(checkpoint)
    return model, processor

    

def answer_question(image, question, model, processor, max_length=100, min_length=20):
    """Answer a free-text question about an image with a BLIP-style model.

    Args:
        image: Anything accepted by ``PIL.Image.open`` (a path or a file-like
            object such as a Streamlit upload).
        question: Prompt text passed to the processor together with the image.
        model: Generation model, optionally wrapped in ``torch.nn.DataParallel``.
        processor: Matching processor used to build inputs and decode output.
        max_length: ``max_length`` forwarded to ``generate`` (default 100).
        min_length: ``min_length`` forwarded to ``generate`` (default 20).

    Returns:
        str: The decoded answer with special tokens removed and surrounding
        whitespace stripped.
    """
    pil_image = Image.open(image)

    # Move inputs to CUDA and cast floating-point tensors to fp16; the models
    # from load_caption_model are loaded with torch_dtype=torch.float16.
    inputs = processor(pil_image, question, return_tensors="pt").to("cuda", torch.float16)

    # DataParallel does not expose generate(); call it on the wrapped module.
    generator = model.module if isinstance(model, torch.nn.DataParallel) else model
    output_ids = generator.generate(**inputs, max_length=max_length, min_length=min_length)

    return processor.decode(output_ids[0], skip_special_tokens=True).strip()

st.title("Image Question Answering")

# File uploader for the image
image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Text input for the question
question = st.text_input("Enter your question about the image:")


@st.cache_resource
def _get_caption_model():
    """Return the (model, processor) pair, loading it only on first use.

    Streamlit re-executes this script on every interaction. Without caching,
    every click on "Get Answer" reloaded the 7B-parameter checkpoint from
    scratch.
    """
    return load_caption_model()


if st.button("Get Answer"):
    if image is not None and question:
        # Display the image
        st.image(image, use_column_width=True)
        # Get and display the answer
        with st.spinner("Generating answer..."):
            model, processor = _get_caption_model()
            answer = answer_question(image, question, model, processor)
        st.write(answer)
    else:
        st.write("Please upload an image and enter a question.")






# ---------------------------------------------------------------------------
# Object detection controls, rendered in the sidebar
# ---------------------------------------------------------------------------
_sidebar = st.sidebar
_sidebar.title("Object Detection")

# Detector choice.
detect_model = _sidebar.selectbox(
    "Choose a model for object detection:",
    ["detic", "yolov5"],
)

# Initial slider value depends on the chosen detector: 0.2 for yolov5, 0.4 otherwise.
if detect_model == "yolov5":
    _default_threshold = 0.2
else:
    _default_threshold = 0.4
threshold = _sidebar.slider("Select Detection Threshold", 0.1, 0.9, _default_threshold)

# Button that triggers the detection handler below.
detect_button = _sidebar.button("Detect Objects")


def perform_object_detection(image, model_name, threshold):
    """
    Run object detection on an image with the requested model.

    A new ObjectDetector is constructed and ``model_name`` is loaded into it
    on every call.

    Args:
    image (PIL.Image): Image to run detection on.
    model_name (str): Model name handed to ``ObjectDetector.load_model``.
    threshold (float): Threshold handed to ``ObjectDetector.detect_objects``.

    Returns:
    tuple: ``(processed_image, detected_objects)`` as produced by
    ``detect_objects`` (the annotated image and a description of the
    detections).
    """
    detector = ObjectDetector()
    detector.load_model(model_name)
    # Unpack explicitly so a result that is not a 2-item sequence fails here.
    annotated, found = detector.detect_objects(image, threshold)
    return annotated, found

# Handle a click on the sidebar "Detect Objects" button.
if detect_button and image is None:
    st.write("Please upload an image for object detection.")
elif detect_button:
    # Decode the upload into a PIL image (rebinding `image`) and show it as-is.
    image = Image.open(image)
    st.image(image, use_column_width=True, caption="Original Image")

    processed_image, detected_objects = perform_object_detection(image, detect_model, threshold)

    # Only a PIL image can be rendered as the annotated result.
    if not isinstance(processed_image, Image.Image):
        st.error("Failed to process image for object detection.")
    else:
        st.image(processed_image, use_column_width=True, caption="Image with Detected Objects")

    # The detection summary is written whether or not the image check passed.
    st.write(detected_objects)