############################################################################################################################# # Filename : app.py # Description: A Streamlit application to utilize five models back to back # Models used: # 1. Visual Question Answering (VQA). # 2. Fill-Mask. # 3. Text2text Generation. # 4. Text Generation. # 5. Topic. # Author : Georgios Ioannou # # Copyright © 2024 by Georgios Ioannou ############################################################################################################################# # Import libraries. import streamlit as st # Build the GUI of the application. import torch # Load Salesforce/blip model(s) on GPU. from bertopic import BERTopic # Topic model inference. from PIL import Image # Open and identify a given image file. from transformers import ( pipeline, BlipProcessor, BlipForQuestionAnswering, ) # VQA model inference. ############################################################################################################################# # Function to apply local CSS. def local_css(file_name): with open(file_name) as f: st.markdown(f"", unsafe_allow_html=True) ############################################################################################################################# # Model 1. # Model 1 gets input from the user. # User -> Model 1 # Load the Visual Question Answering (VQA) model directly. # Using transformers. @st.cache_resource def load_model_blip(): blip_processor_base = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") blip_model_base = BlipForQuestionAnswering.from_pretrained( "Salesforce/blip-vqa-base" ) # Backup model. # blip_processor_large = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large") # blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large") # return blip_processor_large, blip_model_large return blip_processor_base, blip_model_base # General function for any Salesforce/blip model(s). # VQA model. def generate_answer_blip(processor, model, image, question): # Prepare image + question. inputs = processor(images=image, text=question, return_tensors="pt") generated_ids = model.generate(**inputs, max_length=50) generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True) return generated_answer # Generate answer from the Salesforce/blip model(s). # VQA model. @st.cache_resource def generate_answer(image, question): answer_blip_base = generate_answer_blip( processor=blip_processor_base, model=blip_model_base, image=image, question=question, ) # answer_blip_large = generate_answer_blip(blip_processor_large, blip_model_large, image, question) # return answer_blip_large return answer_blip_base ############################################################################################################################# # Model 2. # Model 2 gets input from Model 1. # User -> Model 1 -> Model 2 @st.cache_resource def load_model_fill_mask(): return pipeline(task="fill-mask", model="bert-base-uncased") ############################################################################################################################# # Model 3. # Model 3 gets input from Model 2. # User -> Model 1 -> Model 2 -> Model 3 @st.cache_resource def load_model_text2text_generation(): return pipeline( task="text2text-generation", model="facebook/blenderbot-400M-distill" ) ############################################################################################################################# # Model 4. # Model 4 gets input from Model 3. # User -> Model 1 -> Model 2 -> Model 3 -> Model 4 @st.cache_resource def load_model_fill_text_generation(): return pipeline(task="text-generation", model="gpt2") ############################################################################################################################# # Model 5. # Model 5 gets input from Model 4. # User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5 @st.cache_resource def load_model_bertopic1(): return BERTopic.load(path="davanstrien/chat_topics") @st.cache_resource def load_model_bertopic2(): return BERTopic.load(path="MaartenGr/BERTopic_ArXiv") ############################################################################################################################# # Page title and favicon. st.set_page_config(page_title="Visual Question Answering", page_icon="❓") ############################################################################################################################# # Load the Salesforce/blip model directly. if torch.cuda.is_available(): device = torch.device("cuda") # elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): # device = torch.device("mps") else: device = torch.device("cpu") blip_processor_base, blip_model_base = load_model_blip() blip_model_base.to(device) ############################################################################################################################# # Main function to create the Streamlit web application. # # 5 MODEL INFERENCES. # User Input = Image + Question About The Image. # User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5 def main(): try: ##################################################################################################################### # Load CSS. local_css("styles/style.css") ##################################################################################################################### # Title. title = f"""
Check out our GitHub repository
""", unsafe_allow_html=True, ) ############################################################################################################################# if __name__ == "__main__": main()