############################################################################################################################# # Filename : app.py # Description: A Streamlit application to utilize five models back to back # Models used: # 1. Visual Question Answering (VQA). # 2. Fill-Mask. # 3. Text2text Generation. # 4. Text Generation. # 5. Topic. # Author : Georgios Ioannou # # Copyright © 2024 by Georgios Ioannou ############################################################################################################################# # Import libraries. import streamlit as st # Build the GUI of the application. import torch # Load Salesforce/blip model(s) on GPU. from bertopic import BERTopic # Topic model inference. from PIL import Image # Open and identify a given image file. from transformers import ( pipeline, BlipProcessor, BlipForQuestionAnswering, ) # VQA model inference. ############################################################################################################################# # Function to apply local CSS. def local_css(file_name): with open(file_name) as f: st.markdown(f"", unsafe_allow_html=True) ############################################################################################################################# # Model 1. # Model 1 gets input from the user. # User -> Model 1 # Load the Visual Question Answering (VQA) model directly. # Using transformers. @st.cache_resource def load_model_blip(): blip_processor_base = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") blip_model_base = BlipForQuestionAnswering.from_pretrained( "Salesforce/blip-vqa-base" ) # Backup model. # blip_processor_large = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large") # blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large") # return blip_processor_large, blip_model_large return blip_processor_base, blip_model_base # General function for any Salesforce/blip model(s). # VQA model. def generate_answer_blip(processor, model, image, question): # Prepare image + question. inputs = processor(images=image, text=question, return_tensors="pt") generated_ids = model.generate(**inputs, max_length=50) generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True) return generated_answer # Generate answer from the Salesforce/blip model(s). # VQA model. @st.cache_resource def generate_answer(image, question): answer_blip_base = generate_answer_blip( processor=blip_processor_base, model=blip_model_base, image=image, question=question, ) # answer_blip_large = generate_answer_blip(blip_processor_large, blip_model_large, image, question) # return answer_blip_large return answer_blip_base ############################################################################################################################# # Model 2. # Model 2 gets input from Model 1. # User -> Model 1 -> Model 2 @st.cache_resource def load_model_fill_mask(): return pipeline(task="fill-mask", model="bert-base-uncased") ############################################################################################################################# # Model 3. # Model 3 gets input from Model 2. # User -> Model 1 -> Model 2 -> Model 3 @st.cache_resource def load_model_text2text_generation(): return pipeline( task="text2text-generation", model="facebook/blenderbot-400M-distill" ) ############################################################################################################################# # Model 4. # Model 4 gets input from Model 3. # User -> Model 1 -> Model 2 -> Model 3 -> Model 4 @st.cache_resource def load_model_fill_text_generation(): return pipeline(task="text-generation", model="gpt2") ############################################################################################################################# # Model 5. # Model 5 gets input from Model 4. # User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5 @st.cache_resource def load_model_bertopic1(): return BERTopic.load(path="davanstrien/chat_topics") @st.cache_resource def load_model_bertopic2(): return BERTopic.load(path="MaartenGr/BERTopic_ArXiv") ############################################################################################################################# # Page title and favicon. st.set_page_config(page_title="Visual Question Answering", page_icon="❓") ############################################################################################################################# # Load the Salesforce/blip model directly. if torch.cuda.is_available(): device = torch.device("cuda") # elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): # device = torch.device("mps") else: device = torch.device("cpu") blip_processor_base, blip_model_base = load_model_blip() blip_model_base.to(device) ############################################################################################################################# # Main function to create the Streamlit web application. # # 5 MODEL INFERENCES. # User Input = Image + Question About The Image. # User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5 def main(): try: ##################################################################################################################### # Load CSS. local_css("styles/style.css") ##################################################################################################################### # Title. title = f"""

Georgios Ioannou's Visual Question Answering

""" st.markdown(title, unsafe_allow_html=True) # st.title("ChefBot - Automated Recipe Assistant") ##################################################################################################################### # Subtitle. subtitle = f"""

CUNY Tech Prep Tutorial 4

""" st.markdown(subtitle, unsafe_allow_html=True) ##################################################################################################################### # Image. image = "./ctp.png" left_co, cent_co, last_co = st.columns(3) with cent_co: st.image(image=image) ##################################################################################################################### # User input (Image). image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) if image is not None: bytes_data = image.getvalue() with open(image.name, "wb") as file: file.write(bytes_data) st.image(image, caption="Uploaded Image.", use_column_width=True) raw_image = Image.open(image.name).convert("RGB") # User input (Question). question = st.text_input("What's your question?") ############################################################################################################# if question != "": # Model 1. with st.spinner( text="VQA inference..." ): # Spinner to keep the application interactive. # Model inference. answer = generate_answer(raw_image, question)[0] st.success(f"VQA: {answer}") bbu_pipeline = load_model_fill_mask() text = ( "I love " + answer + " and I would like to know how to [MASK]." ) ######################################################################################################### # Model 2. with st.spinner( text="Fill-Mask inference..." ): # Spinner to keep the application interactive. # Model inference. bbu_pipeline_output = bbu_pipeline(text) bbu_output = bbu_pipeline_output[0]["sequence"] st.success(f"Fill-Mask: {bbu_output}") facebook_pipeline = load_model_text2text_generation() utterance = bbu_output ######################################################################################################### # Model 3. with st.spinner( text="Text2text Generation inference..." ): # Spinner to keep the application interactive. # Model inference. facebook_pipeline_output = facebook_pipeline(utterance) facebook_output = facebook_pipeline_output[0]["generated_text"] st.success(f"Text2text Generation: {facebook_output}") gpt2_pipeline = load_model_fill_text_generation() ######################################################################################################### # Model 4. with st.spinner( text="Fill Text Generation inference..." ): # Spinner to keep the application interactive. # Model inference. gpt2_pipeline_output = gpt2_pipeline(facebook_output) gpt2_output = gpt2_pipeline_output[0]["generated_text"] st.success(f"Fill Text Generation: {gpt2_output}") ######################################################################################################### # Model 5. topic_model_1 = load_model_bertopic1() topic, prob = topic_model_1.transform(gpt2_pipeline_output) topic_model_1_output = topic_model_1.get_topic_info(topic[0])[ "Representation" ][0] st.success( f"Topic(s) from davanstrien/chat_topics: {topic_model_1_output}" ) topic_model_2 = load_model_bertopic2() topic, prob = topic_model_2.transform(gpt2_pipeline_output) topic_model_2_output = topic_model_2.get_topic_info(topic[0])[ "Representation" ][0] st.success( f"Topic(s) from MaartenGr/BERTopic_ArXiv: {topic_model_1_output}" ) except Exception as e: # General exception/error handling. st.error(e) # GitHub repository of author. st.markdown( f"""

Check out our GitHub repository

""", unsafe_allow_html=True, ) ############################################################################################################################# if __name__ == "__main__": main()