import io

import streamlit as st
import torch
from PIL import Image
from transformers import pipeline
from transformers import DetrImageProcessor, DetrForObjectDetection
from transformers import CLIPProcessor, CLIPModel
from transformers import BlipProcessor, BlipForQuestionAnswering
#from transformers import YolosImageProcessor, YolosForObjectDetection

from functions import *


# Load all models once and cache them across Streamlit reruns
@st.cache_resource
def load_models():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # DETR for object detection
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")

    # CLIP for image-text matching
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # BLIP for visual question answering (half precision when a GPU is available)
    sales_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
    sales_model = BlipForQuestionAnswering.from_pretrained(
        "Salesforce/blip-vqa-base",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32
    ).to(device)

    return {
        "detector": model,
        "processor": processor,
        "clip": clip_model,
        "clip process": clip_processor,
        "story_teller": pipeline("text-generation", model="nickypro/tinyllama-15M"),
        "sales process": sales_processor,
        "sales model": sales_model,
        "device": device
    }



def main():
    st.header("📱 Nano AI Image Analyzer")

    uploaded_file = st.file_uploader("Upload an image")  # type=["png", "jpg", "jpeg"] would restrict formats
    models = load_models()
    st.write("Models loaded")

    # Button UI
    if uploaded_file is not None:
        st.write("Filename:", uploaded_file.name)
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", width=200)

        col1, col2, col3 = st.columns(3)

        with col1:
            detect = st.button("🔍 Detect Objects", key="btn1")
        with col2:
            describe = st.button("📝 Describe Image", key="btn2")
        with col3:
            story = st.button("📖 Generate Story", key="btn3",
                              help="The story is generated from the image caption")


        if detect:
            with st.spinner("Detecting objects..."):
                try:
                    detections = detect_objects(image.copy(), models)
                    annotated_image = draw_bounding_boxes(image, detections)
                    st.image(annotated_image, caption="Detected Objects", use_container_width=True)
                    show_detection_table(detections)
                except Exception as e:
                    st.error(f"Detection failed ({e}). Try another image.")

        elif describe:
            with st.spinner("Generating a description..."):
                description = get_image_description(image.copy(), models)
                st.write(description)

        elif story:
            with st.spinner("Writing a story..."):
                description = get_image_description(image.copy(), models)
                story_text = generate_story(description, models)
                st.write(story_text)

        # Chat interface
        if "messages" not in st.session_state:
            st.session_state.messages = []

        chat_container = st.container(height=400)
        with chat_container:
            # Replay the conversation history on every rerun
            for message in st.session_state.messages:
                with st.chat_message(message["role"]):
                    st.markdown(message["content"])

            if prompt := st.chat_input("Ask about the image"):
                st.session_state.messages.append({"role": "user", "content": prompt})
                with st.chat_message("user"):
                    st.markdown(prompt)

                with st.chat_message("assistant"):
                    with st.spinner("Thinking..."):
                        response = answer_question(image,
                                                   prompt,
                                                   models["sales process"],
                                                   models["sales model"],
                                                   models["device"])
                    st.markdown(response)
                st.session_state.messages.append({"role": "assistant", "content": response})
    

if __name__ == "__main__":
    main()
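

# --- Illustrative sketch (not used by the app) -------------------------------
# The real detect_objects / draw_bounding_boxes helpers live in functions.py,
# which is not shown here. The sketch below is only an assumption of how they
# might be implemented with the DETR model loaded above; the function names,
# the 0.9 confidence threshold, and the drawing style are illustrative.
from PIL import ImageDraw


def _sketch_detect_objects(image, models, threshold=0.9):
    """Run DETR on a PIL image and return a list of (label, score, box) tuples."""
    inputs = models["processor"](images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = models["detector"](**inputs)
    # Convert raw predictions to absolute pixel coordinates on the input image
    target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
    results = models["processor"].post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
    detections = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        name = models["detector"].config.id2label[label.item()]
        detections.append((name, round(score.item(), 3), [round(v, 1) for v in box.tolist()]))
    return detections


def _sketch_draw_bounding_boxes(image, detections):
    """Draw labelled boxes on a copy of the image and return it."""
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    for name, score, box in detections:
        draw.rectangle(box, outline="red", width=3)
        draw.text((box[0], box[1]), f"{name} {score:.2f}", fill="red")
    return annotated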
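
# answer_question is also imported from functions.py. A minimal version of
# BLIP visual question answering, given the processor/model/device that main()
# passes in, might look like this; the name and generation settings are
# assumptions for illustration only.
def _sketch_answer_question(image, question, blip_processor, blip_model, device):
    """Answer a free-form question about a PIL image with BLIP-VQA."""
    dtype = torch.float16 if device == "cuda" else torch.float32
    inputs = blip_processor(image, question, return_tensors="pt").to(device, dtype)
    with torch.no_grad():
        output_ids = blip_model.generate(**inputs, max_new_tokens=20)
    return blip_processor.decode(output_ids[0], skip_special_tokens=True)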