#############################################################################################################################
# Filename   : app.py
# Description: A Streamlit application that chains five models back to back.
# Models used:
#   1. Visual Question Answering (VQA).
#   2. Fill-Mask.
#   3. Text2text Generation.
#   4. Text Generation.
#   5. Topic Modeling.
# Author     : Georgios Ioannou
#
# Copyright © 2024 by Georgios Ioannou
#############################################################################################################################

# Import libraries.
import streamlit as st  # Build the GUI of the application.
import torch  # Select the device (GPU/CPU) for the Salesforce/blip model(s).

from bertopic import BERTopic  # Topic model inference.
from PIL import Image  # Open and identify a given image file.
from transformers import (
    pipeline,
    BlipProcessor,
    BlipForQuestionAnswering,
)  # VQA model inference.
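
# Third-party packages assumed to be installed: streamlit, torch, transformers, bertopic, and Pillow.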


#############################################################################################################################
# Function to apply local CSS.
def local_css(file_name):
    with open(file_name) as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)


#############################################################################################################################
# Model 1.
# Model 1 gets input from the user.
# User -> Model 1
# Load the Visual Question Answering (VQA) model directly.
# Using transformers.
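# Note: model weights are downloaded from the Hugging Face Hub the first time the loaders below run. Wrapping
# the loader functions with Streamlit's @st.cache_resource (assuming a recent Streamlit release) would avoid
# reloading every model on each rerun of the script.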
def load_model_blip():
    blip_processor_base = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
    blip_model_base = BlipForQuestionAnswering.from_pretrained(
        "Salesforce/blip-vqa-base"
    )

    # Backup model.
    # blip_processor_large = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
    # blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
    # return blip_processor_large, blip_model_large

    return blip_processor_base, blip_model_base


# General function for any Salesforce/blip model(s).
# VQA model.
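# The processor converts the image into pixel values and tokenizes the question; model.generate() produces
# answer token ids, and processor.batch_decode() turns them into a list of strings (one answer per input),
# which is why callers take element [0] of the returned list.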
def generate_answer_blip(processor, model, image, question):
    # Prepare image + question and move the tensors to the same device as the model.
    inputs = processor(images=image, text=question, return_tensors="pt").to(device)

    generated_ids = model.generate(**inputs, max_length=50)
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)

    return generated_answer


# Generate answer from the Salesforce/blip model(s).
# VQA model.
def generate_answer(image, question):
    answer_blip_base = generate_answer_blip(
        processor=blip_processor_base,
        model=blip_model_base,
        image=image,
        question=question,
    )

    # answer_blip_large = generate_answer_blip(blip_processor_large, blip_model_large, image, question)
    # return answer_blip_large

    return answer_blip_base


#############################################################################################################################
# Model 2.
# Model 2 gets input from Model 1.
# User -> Model 1 -> Model 2
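# bert-base-uncased predicts the token hidden behind the [MASK] placeholder; the fill-mask pipeline returns
# candidate completions ranked by score, and each candidate's "sequence" field is the full sentence with the
# mask filled in.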
def load_model_fill_mask():
    return pipeline(task="fill-mask", model="bert-base-uncased")


#############################################################################################################################
# Model 3.
# Model 3 gets input from Model 2.
# User -> Model 1 -> Model 2 -> Model 3
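# facebook/blenderbot-400M-distill is a conversational sequence-to-sequence model; the text2text-generation
# pipeline returns a list of dicts whose "generated_text" field holds the model's reply to the input utterance.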
def load_model_text2text_generation():
    return pipeline(
        task="text2text-generation", model="facebook/blenderbot-400M-distill"
    )


#############################################################################################################################
# Model 4.
# Model 4 gets input from Model 3.
# User -> Model 1 -> Model 2 -> Model 3 -> Model 4
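# gpt2 continues the prompt it is given; the text-generation pipeline returns a list of dicts whose
# "generated_text" field contains the prompt followed by the generated continuation.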
def load_model_text_generation():
    return pipeline(task="text-generation", model="gpt2")


#############################################################################################################################
# Model 5.
# Model 5 gets input from Model 4.
# User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5
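# BERTopic.load() fetches a pretrained topic model from the Hugging Face Hub; .transform() expects a list of
# documents and returns the assigned topic ids together with their probabilities.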
def load_model_bertopic1():
    return BERTopic.load(path="davanstrien/chat_topics")


def load_model_bertopic2():
    return BERTopic.load(path="MaartenGr/BERTopic_ArXiv")


#############################################################################################################################
# Page title and favicon.
st.set_page_config(page_title="Visual Question Answering", page_icon="❓")


#############################################################################################################################
# Load the Salesforce/blip model directly.
if torch.cuda.is_available():
    device = torch.device("cuda")
# elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
#     device = torch.device("mps")
else:
    device = torch.device("cpu")

blip_processor_base, blip_model_base = load_model_blip()
blip_model_base.to(device)
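# The BLIP model runs on the GPU when one is available; the tensors returned by the processor are moved to the
# same device inside generate_answer_blip() before generation.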


#############################################################################################################################
# Main function to create the Streamlit web application.
#
# 5 MODEL INFERENCES.
# User Input = Image + Question About The Image.
# User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5
def main():
    try:
        #####################################################################################################################
        # Load CSS.
        local_css("styles/style.css")

        #####################################################################################################################
        # Title.
        title = f"""<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
                    Georgios Ioannou's Visual Question Answering</h1>"""
        st.markdown(title, unsafe_allow_html=True)
        # st.title("ChefBot - Automated Recipe Assistant")

        #####################################################################################################################
        # Subtitle.
        subtitle = f"""<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
                       CUNY Tech Prep Tutorial 4</h2>"""
        st.markdown(subtitle, unsafe_allow_html=True)

        #####################################################################################################################
        # Image.
        image = "./ctp.png"
        left_co, cent_co, last_co = st.columns(3)
        with cent_co:
            st.image(image=image)

        #####################################################################################################################
        # User input (Image).
        image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

        if image is not None:
            bytes_data = image.getvalue()
            with open(image.name, "wb") as file:
                file.write(bytes_data)

            st.image(image, caption="Uploaded Image.", use_column_width=True)
            raw_image = Image.open(image.name).convert("RGB")

            # User input (Question).
            question = st.text_input("What's your question?")

            #############################################################################################################
            if question != "":
                # Model 1.
                with st.spinner(
                    text="VQA inference..."
                ):  # Spinner to keep the application interactive.
                    # Model inference.
                    answer = generate_answer(raw_image, question)[0]
                    st.success(f"VQA: {answer}")

                bbu_pipeline = load_model_fill_mask()
                text = (
                    "I love " + answer + " and I would like to know how to [MASK]."
                )

                #########################################################################################################
                # Model 2.
                with st.spinner(
                    text="Fill-Mask inference..."
                ):  # Spinner to keep the application interactive.
                    # Model inference.
                    bbu_pipeline_output = bbu_pipeline(text)
                    bbu_output = bbu_pipeline_output[0]["sequence"]
                    st.success(f"Fill-Mask: {bbu_output}")

                facebook_pipeline = load_model_text2text_generation()
                utterance = bbu_output

                #########################################################################################################
                # Model 3.
                with st.spinner(
                    text="Text2text Generation inference..."
                ):  # Spinner to keep the application interactive.
                    # Model inference.
                    facebook_pipeline_output = facebook_pipeline(utterance)
                    facebook_output = facebook_pipeline_output[0]["generated_text"]
                    st.success(f"Text2text Generation: {facebook_output}")

                gpt2_pipeline = load_model_text_generation()

                #########################################################################################################
                # Model 4.
                with st.spinner(
                    text="Text Generation inference..."
                ):  # Spinner to keep the application interactive.
                    # Model inference.
                    gpt2_pipeline_output = gpt2_pipeline(facebook_output)
                    gpt2_output = gpt2_pipeline_output[0]["generated_text"]
                    st.success(f"Text Generation: {gpt2_output}")

                #########################################################################################################
                # Model 5.
                topic_model_1 = load_model_bertopic1()
                # BERTopic expects a list of documents, so wrap the generated text in a list.
                topic, prob = topic_model_1.transform([gpt2_output])
                topic_model_1_output = topic_model_1.get_topic_info(topic[0])[
                    "Representation"
                ][0]
                st.success(
                    f"Topic(s) from davanstrien/chat_topics: {topic_model_1_output}"
                )

                topic_model_2 = load_model_bertopic2()
                topic, prob = topic_model_2.transform([gpt2_output])
                topic_model_2_output = topic_model_2.get_topic_info(topic[0])[
                    "Representation"
                ][0]
                st.success(
                    f"Topic(s) from MaartenGr/BERTopic_ArXiv: {topic_model_2_output}"
                )

    except Exception as e:
        # General exception/error handling.
        st.error(e)

    # GitHub repository of author.
    st.markdown(
        f"""
        <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
        <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
        </p>
        """,
        unsafe_allow_html=True,
    )


#############################################################################################################################
if __name__ == "__main__":
    main()