from transformers import Blip2ForConditionalGeneration
from transformers import Blip2Processor
from peft import PeftModel
import streamlit as st
from PIL import Image
#import torch
import os

preprocess_ckp = "Salesforce/blip2-opt-2.7b"            #Checkpoint used to preprocess the image
base_model_ckp = "./model/blip2-opt-2.7b-fp16-sharded"  #Base model checkpoint path
peft_model_ckp = "./model/blip2_peft"                   #PEFT model checkpoint path
sample_img_path = "./sample_images"                     #Directory holding the sample images
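
#Note: the base and PEFT checkpoints are local paths, so they are assumed to live
#inside the Space repo under ./model; only preprocess_ckp is resolved from the Hub.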

map_sampleid_name = {
    'dress': '00fe223d-9d1f-4bd3-a556-7ece9d28e6fb.jpeg',
    'earrings': '0b3862ae-f89e-419c-bc1e-57418abd4180.jpeg',
    'sweater': '0c21ba7b-ceb6-4136-94a4-1d4394499986.jpeg',
    'sunglasses': '0e44ec10-e53b-473a-a77f-ac8828bb5e01.jpeg',
    'shoe': '4cd37d6d-e7ea-4c6e-aab2-af700e480bc1.jpeg',
    'hat': '69aeb517-c66c-47b8-af7d-bdf1fde57ed0.jpeg',
    'heels': '447abc42-6ac7-4458-a514-bdcd570b1cd1.jpeg',
    'socks': 'd188836c-b734-4031-98e5-423d5ff1239d.jpeg',
    'tee': 'e2d8637a-5478-429d-a2a8-3d5859dbc64d.jpeg',
    'bracelet': 'e78518ac-0f54-4483-a233-fad6511f0b86.jpeg'
}

def init_model(init_model_required):
    #Note: callers are expected to pass init_model_required = True; otherwise
    #processor and model would be unbound at the return below.
    if init_model_required:
        #Preprocess input
        processor = Blip2Processor.from_pretrained(preprocess_ckp)

        #Model
        #Inference on a GPU device. Raises an error on a CPU-only system, because
        #"load_in_8bit" is a bitsandbytes setting that only works on GPU:
        #model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp, load_in_8bit = True, device_map = "auto")

        #Inference on a CPU device
        model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp)
        model = PeftModel.from_pretrained(model, peft_model_ckp)

        init_model_required = False

    return processor, model, init_model_required
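
#A leaner alternative (sketch, not what this app uses): Streamlit's st.cache_resource
#can replace the manual session-state bookkeeping below; assumes Streamlit >= 1.18,
#where the decorator was introduced.
#@st.cache_resource
#def load_model():
#    processor = Blip2Processor.from_pretrained(preprocess_ckp)
#    model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp)
#    model = PeftModel.from_pretrained(model, peft_model_ckp)
#    return processor, model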

#def main():

#Offer a few sample images covering the clothing categories
with st.form("app", clear_on_submit = True):
    st.caption("Select image:")

    option = st.selectbox('From sample', ('None', 'dress', 'earrings', 'sweater', 'sunglasses', 'shoe', 'hat', 'heels', 'socks', 'tee', 'bracelet'), index = 0)

    st.text("Or")

    file_name = st.file_uploader(label = "Upload an image", accept_multiple_files = False)  #Returns None until a file is uploaded

    btn_click = st.form_submit_button('Generate')

    st.caption("Application deployed on CPU basic with 16GB RAM")

    if btn_click:
        image = None
        if file_name is not None:
            image = Image.open(file_name)
        elif option != 'None':  #Fixed: "is not" compares identity, not string value
            file_name = os.path.join(sample_img_path, map_sampleid_name[option])
            image = Image.open(file_name)
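
        #Note (assumption): uploaded files may be RGBA or palette images; if the
        #processor rejects them, image = image.convert('RGB') is a safe guard.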

        if image is not None:
            image_col, caption_text = st.columns(2)
            image_col.header("Image")
            caption_text.header("Generated Caption")
            image_col.image(image.resize((252, 252)), use_column_width = True)
            caption_text.text("")
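
            #Note (assumption): newer Streamlit releases deprecate use_column_width
            #above in favor of use_container_width = True.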

            if 'init_model_required' not in st.session_state:
                with st.spinner('Initializing model...'):
                    processor, model, init_model_required = init_model(True)

                    #Cache the initialized model in session state so reruns skip loading
                    st.session_state.init_model_required = init_model_required
                    st.session_state.processor = processor
                    st.session_state.model = model
            else:
                processor = st.session_state.processor
                model = st.session_state.model

            with st.spinner('Generating Caption...'):
                #Preprocess the image
                #Inference on GPU. Running this line on a CPU system raises errors like
                #'"slow_conv2d_cpu" not implemented for 'Half'' or
                #'Input type (float) and bias type (struct c10::Half)':
                #inputs = processor(images = image, return_tensors = "pt").to('cuda', torch.float16)

                #Inference on CPU
                inputs = processor(images = image, return_tensors = "pt")
                pixel_values = inputs.pixel_values

                #Predict the caption for the image
                generated_ids = model.generate(pixel_values = pixel_values, max_length = 10)
                generated_caption = processor.batch_decode(generated_ids, skip_special_tokens = True)[0]

                #Output the predicted text
                caption_text.text(generated_caption)
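
                #Tuning sketch (assumption, not used here): longer, less truncated captions
                #can be tried with standard Hugging Face generate() options, at the cost of
                #slower CPU inference, e.g.:
                #generated_ids = model.generate(pixel_values = pixel_values, max_new_tokens = 30, num_beams = 3)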

#if __name__ == "__main__":
#    main()
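
#To run locally (sketch): save this file as app.py and launch it with
#    streamlit run app.py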