Fixed the issue with loading the model for inference on a CPU device
The model was trained on GPU with bitsandbytes and PEFT, but bitsandbytes works only on GPU devices, so the model initialization and the input dtype are modified to work on CPU.
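In short: keep the 8-bit bitsandbytes path when a GPU is available, and fall back to a plain full-precision load on CPU before attaching the PEFT adapter. A minimal sketch of that idea, assuming the same local checkpoint paths as app.py (the load_model_for_device helper and the explicit device check are illustrative, not the actual app.py code):

import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from peft import PeftModel

preprocess_ckp = "Salesforce/blip2-opt-2.7b"            #Processor checkpoint
base_model_ckp = "./model/blip2-opt-2.7b-fp16-sharded"  #Local base model checkpoint (as in app.py)
peft_model_ckp = "./model/blip2_peft"                   #Local PEFT adapter checkpoint (as in app.py)

def load_model_for_device():
    #Illustrative helper, not part of app.py: pick the load path based on the available device
    processor = Blip2Processor.from_pretrained(preprocess_ckp)

    if torch.cuda.is_available():
        #GPU path: 8-bit quantization via bitsandbytes, as in the original GPU-only code
        model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp, load_in_8bit = True, device_map = "auto")
    else:
        #CPU path: bitsandbytes 8-bit loading is not available, so load in full precision
        model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp)

    #Attach the fine-tuned PEFT adapter on top of the base model
    model = PeftModel.from_pretrained(model, peft_model_ckp)
    model.eval()
    return processor, model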
app.py
CHANGED
@@ -4,25 +4,32 @@ from peft import PeftModel
import streamlit as st
from PIL import Image
import torch
+import os

preprocess_ckp = "Salesforce/blip2-opt-2.7b" #Checkpoint used to preprocess the image
base_model_ckp = "./model/blip2-opt-2.7b-fp16-sharded" #Base model checkpoint path
peft_model_ckp = "./model/blip2_peft" #PEFT model checkpoint path
-
+sample_img_path = "./sample_images/"
+
#init_model_required = True
-processor = None
-model = None
+#processor = None
+#model = None

-def init_model():
+#def init_model():

#if init_model_required:

-
-
+#Preprocess input
+processor = Blip2Processor.from_pretrained(preprocess_ckp)
+
+#Model
+#Inference on GPU device. Fails on a CPU-only system, because "load_in_8bit" is a setting of the bitsandbytes library and works only on GPU
+#model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp, load_in_8bit = True, device_map = "auto")

-
-
-
+#Inference on CPU device
+model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp)
+
+model = PeftModel.from_pretrained(model, peft_model_ckp)

#init_model_required = False

@@ -32,10 +39,16 @@ def main():

    st.title("Fashion Image Caption using BLIP2")

-    init_model()
+    #init_model()

+    #Select from a few sample images for the clothing categories
+    option = st.selectbox('Sample images ?', ('cap', 'tee', 'dress'))
    file_name = st.file_uploader("Upload image")

+    if file_name is None and option is not None:
+
+        file_name = os.path.join(sample_img_path, option)
+
    if file_name is not None:

        image_col, caption_text = st.columns(2)
@@ -45,7 +58,12 @@ def main():
        image_col.image(image, use_column_width = True)

        #Preprocess the image
-
+        #Inference on GPU. Running this line on a CPU system gives errors like: "slow_conv2d_cpu" not implemented for 'Half', and: Input type (float) and bias type (struct c10::Half)
+        #inputs = processor(images = image, return_tensors = "pt").to('cuda', torch.float16)
+
+        #Inference on CPU
+        inputs = processor(images = image, return_tensors = "pt")
+
        pixel_values = inputs.pixel_values

        #Predict the caption for the image
@@ -56,6 +74,5 @@ def main():
        caption_text.header("Generated Caption")
        caption_text.text(generated_caption)

-
if __name__ == "__main__":
    main()
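The lines of main() between pixel_values and the caption display are not part of this diff. For reference only, a CPU-side generate-and-decode step would look roughly like the sketch below; the caption_on_cpu helper, the torch.no_grad() wrapper and the max_length value are assumptions for illustration, not code taken from app.py:

import torch

def caption_on_cpu(processor, model, image):
    #Sketch of the CPU inference path: no .to('cuda', torch.float16) cast on the inputs
    #image: a PIL.Image opened from the uploaded or sample file
    inputs = processor(images = image, return_tensors = "pt")   #stays in float32 on CPU
    pixel_values = inputs.pixel_values

    with torch.no_grad():
        generated_ids = model.generate(pixel_values = pixel_values, max_length = 25)

    #Decode the generated token ids into the caption text
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens = True)[0]
    return generated_caption.strip()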