Upyaya's picture
Removed use of variable "init_model_required" to init model.
d6e285e
raw
history blame
1.76 kB
from transformers import Blip2ForConditionalGeneration
from transformers import Blip2Processor
from peft import PeftModel
import streamlit as st
from PIL import Image
import torch
# Checkpoint used to preprocess the input image.
preprocess_ckp = "Salesforce/blip2-opt-2.7b"
# Base model checkpoint path.
base_model_ckp = "/model/blip2-opt-2.7b-fp16-sharded"
# PEFT model checkpoint path.
peft_model_ckp = "/model/blip2_peft"

# Module-level placeholders for the processor and the model; both start unset.
processor = model = None
def init_model():
    """Load the processor and the PEFT-wrapped model into the module globals.

    Rebinds the module-level ``processor`` and ``model`` names so that code
    running after this call (such as ``main()``) can use them. Takes no
    arguments and returns ``None``.
    """
    # Without this declaration the assignments below would only create
    # function locals and the module-level names would stay ``None``.
    global processor, model

    # NOTE(review): Streamlit is understood to re-run the script on each
    # interaction, which would repeat this load every time; consider a
    # Streamlit resource cache if the installed version provides one.

    # Processor used to preprocess the input image.
    processor = Blip2Processor.from_pretrained(preprocess_ckp)

    # Load the base model with 8-bit weights, then wrap it with the PEFT
    # checkpoint.
    base_model = Blip2ForConditionalGeneration.from_pretrained(
        base_model_ckp, load_in_8bit=True, device_map="auto"
    )
    model = PeftModel.from_pretrained(base_model, peft_model_ckp)
def main():
    """Render the page: accept an uploaded image and display its caption.

    Shows the title, calls ``init_model()``, and then waits for an upload.
    Once a file is provided, the image is displayed in one column and the
    generated caption in the other.
    """
    st.title("Fashion Image Caption using BLIP2")
    init_model()

    uploaded = st.file_uploader("Upload image")
    if uploaded is None:
        # Nothing has been uploaded yet, so there is nothing to caption.
        return

    picture_col, caption_col = st.columns(2)

    # Show the uploaded picture in the first column.
    picture_col.header("Image")
    picture = Image.open(uploaded)
    picture_col.image(picture, use_column_width=True)

    # Preprocess the image; the result is moved to 'cuda' as float16.
    encoded = processor(images=picture, return_tensors="pt").to('cuda', torch.float16)

    # Predict the caption for the image (at most 25 tokens long).
    token_ids = model.generate(pixel_values=encoded.pixel_values, max_length=25)
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    caption = decoded[0]

    # Output the predicted text in the second column.
    caption_col.header("Generated Caption")
    caption_col.text(caption)
# Start the app only when this file is the entry point, not when imported.
if __name__ == "__main__":
    main()