#############################################################################################################################
# Filename : app.py
# Description: A Streamlit application that chains five models back to back.
# Models used:
# 1. Visual Question Answering (VQA).
# 2. Fill-Mask.
# 3. Text2text Generation.
# 4. Text Generation.
# 5. Topic Modeling.
# Author : Georgios Ioannou
#
# Copyright © 2024 by Georgios Ioannou
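#
# Usage (assumed environment, inferred from the imports below):
#   pip install streamlit torch transformers bertopic pillow
#   streamlit run app.py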
#############################################################################################################################
# Import libraries.
import streamlit as st # Build the GUI of the application.
import torch # Load Salesforce/blip model(s) on GPU.
from bertopic import BERTopic # Topic model inference.
from PIL import Image # Open and identify a given image file.
from transformers import (
pipeline,
BlipProcessor,
BlipForQuestionAnswering,
) # VQA model inference.
#############################################################################################################################
# Function to apply local CSS.
def local_css(file_name):
with open(file_name) as f:
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
#############################################################################################################################
# Model 1.
# Model 1 gets input from the user.
# User -> Model 1
# Load the Visual Question Answering (VQA) model directly.
# Using transformers.
@st.cache_resource
def load_model_blip():
blip_processor_base = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model_base = BlipForQuestionAnswering.from_pretrained(
"Salesforce/blip-vqa-base"
)
# Backup model.
# blip_processor_large = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
# blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
# return blip_processor_large, blip_model_large
return blip_processor_base, blip_model_base
# General function for any Salesforce/blip model(s).
# VQA model.
def generate_answer_blip(processor, model, image, question):
    # Prepare image + question and move the tensors to the model's device,
    # otherwise generation fails when the model runs on GPU.
    inputs = processor(images=image, text=question, return_tensors="pt").to(
        model.device
    )
generated_ids = model.generate(**inputs, max_length=50)
generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)
return generated_answer
# Generate answer from the Salesforce/blip model(s).
# VQA model.
# Note: not cached; st.cache_resource is for loading shared resources such as
# models, not for per-input inference on user-supplied images.
def generate_answer(image, question):
answer_blip_base = generate_answer_blip(
processor=blip_processor_base,
model=blip_model_base,
image=image,
question=question,
)
# answer_blip_large = generate_answer_blip(blip_processor_large, blip_model_large, image, question)
# return answer_blip_large
return answer_blip_base
#############################################################################################################################
# Model 2.
# Model 2 gets input from Model 1.
# User -> Model 1 -> Model 2
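# bert-base-uncased predicts the most likely words for the [MASK] token in the
# prompt built from the VQA answer.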
@st.cache_resource
def load_model_fill_mask():
return pipeline(task="fill-mask", model="bert-base-uncased")
#############################################################################################################################
# Model 3.
# Model 3 gets input from Model 2.
# User -> Model 1 -> Model 2 -> Model 3
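# facebook/blenderbot-400M-distill is a conversational model; it treats the
# filled sentence as a chat utterance and generates a reply.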
@st.cache_resource
def load_model_text2text_generation():
return pipeline(
task="text2text-generation", model="facebook/blenderbot-400M-distill"
)
#############################################################################################################################
# Model 4.
# Model 4 gets input from Model 3.
# User -> Model 1 -> Model 2 -> Model 3 -> Model 4
@st.cache_resource
def load_model_text_generation():
return pipeline(task="text-generation", model="gpt2")
#############################################################################################################################
# Model 5.
# Model 5 gets input from Model 4.
# User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5
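# Two pretrained BERTopic models label the final text with topics from two
# different corpora (chat datasets and arXiv papers).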
@st.cache_resource
def load_model_bertopic1():
return BERTopic.load(path="davanstrien/chat_topics")
@st.cache_resource
def load_model_bertopic2():
return BERTopic.load(path="MaartenGr/BERTopic_ArXiv")
#############################################################################################################################
# Page title and favicon.
st.set_page_config(page_title="Visual Question Answering", page_icon="❓")
#############################################################################################################################
# Select the compute device and load the Salesforce/blip model onto it.
if torch.cuda.is_available():
device = torch.device("cuda")
# elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
# device = torch.device("mps")
else:
device = torch.device("cpu")
blip_processor_base, blip_model_base = load_model_blip()
blip_model_base.to(device)
#############################################################################################################################
# Main function to create the Streamlit web application.
#
# 5 MODEL INFERENCES.
# User Input = Image + Question About The Image.
# User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5
def main():
try:
#####################################################################################################################
# Load CSS.
local_css("styles/style.css")
#####################################################################################################################
# Title.
        title = """<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
        Georgios Ioannou's Visual Question Answering</h1>"""
st.markdown(title, unsafe_allow_html=True)
# st.title("ChefBot - Automated Recipe Assistant")
#####################################################################################################################
# Subtitle.
        subtitle = """<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
        CUNY Tech Prep Tutorial 4</h2>"""
st.markdown(subtitle, unsafe_allow_html=True)
#####################################################################################################################
        # Logo image (renamed so it is not shadowed by the uploaded image below).
        logo_path = "./ctp.png"
        left_co, cent_co, last_co = st.columns(3)
        with cent_co:
            st.image(image=logo_path)
#####################################################################################################################
# User input (Image).
image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if image is not None:
            st.image(image, caption="Uploaded Image.", use_column_width=True)
            # Open the uploaded file directly; no need to write it to disk first.
            raw_image = Image.open(image).convert("RGB")
# User input (Question).
question = st.text_input("What's your question?")
#############################################################################################################
if question != "":
# Model 1.
with st.spinner(
text="VQA inference..."
): # Spinner to keep the application interactive.
# Model inference.
answer = generate_answer(raw_image, question)[0]
st.success(f"VQA: {answer}")
bbu_pipeline = load_model_fill_mask()
text = (
"I love " + answer + " and I would like to know how to [MASK]."
)
#########################################################################################################
# Model 2.
with st.spinner(
text="Fill-Mask inference..."
): # Spinner to keep the application interactive.
# Model inference.
bbu_pipeline_output = bbu_pipeline(text)
bbu_output = bbu_pipeline_output[0]["sequence"]
st.success(f"Fill-Mask: {bbu_output}")
facebook_pipeline = load_model_text2text_generation()
utterance = bbu_output
#########################################################################################################
# Model 3.
with st.spinner(
text="Text2text Generation inference..."
): # Spinner to keep the application interactive.
# Model inference.
facebook_pipeline_output = facebook_pipeline(utterance)
facebook_output = facebook_pipeline_output[0]["generated_text"]
st.success(f"Text2text Generation: {facebook_output}")
                    gpt2_pipeline = load_model_text_generation()
#########################################################################################################
# Model 4.
with st.spinner(
text="Fill Text Generation inference..."
): # Spinner to keep the application interactive.
# Model inference.
gpt2_pipeline_output = gpt2_pipeline(facebook_output)
gpt2_output = gpt2_pipeline_output[0]["generated_text"]
st.success(f"Fill Text Generation: {gpt2_output}")
#########################################################################################################
                # Model 5.
                with st.spinner(
                    text="Topic inference..."
                ):  # Spinner to keep the application interactive.
                    # Model inference.
                    # BERTopic expects a list of documents, so wrap the generated text.
                    topic_model_1 = load_model_bertopic1()
                    topic, prob = topic_model_1.transform([gpt2_output])
                    topic_model_1_output = topic_model_1.get_topic_info(topic[0])[
                        "Representation"
                    ][0]
                    st.success(
                        f"Topic(s) from davanstrien/chat_topics: {topic_model_1_output}"
                    )
                    topic_model_2 = load_model_bertopic2()
                    topic, prob = topic_model_2.transform([gpt2_output])
                    topic_model_2_output = topic_model_2.get_topic_info(topic[0])[
                        "Representation"
                    ][0]
                    st.success(
                        f"Topic(s) from MaartenGr/BERTopic_ArXiv: {topic_model_2_output}"
                    )
except Exception as e:
# General exception/error handling.
st.error(e)
# GitHub repository of author.
st.markdown(
f"""
<p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
<a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
</p>
""",
unsafe_allow_html=True,
)
#############################################################################################################################
if __name__ == "__main__":
main()