Update app.py
Browse files
app.py
CHANGED
@@ -8,10 +8,10 @@ from PIL import Image
|
|
8 |
|
9 |
|
10 |
class _MLPVectorProjector(nn.Module):
|
11 |
-
def __init__(
|
12 |
self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
|
13 |
):
|
14 |
-
super(_MLPVectorProjector, self).__init__()
|
15 |
self.mlps = nn.ModuleList()
|
16 |
for _ in range(width):
|
17 |
mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
|
@@ -59,8 +59,13 @@ def encode_image(image_path):
|
|
59 |
return img_embedding
|
60 |
|
61 |
#Get the projection model
|
|
|
|
|
62 |
|
63 |
#Get the fine-tuned phi-2 model
|
|
|
|
|
|
|
64 |
|
65 |
|
66 |
def example_inference(input_text, count): #, image, img_qn, audio):
|
@@ -87,6 +92,7 @@ def textMode(text, count):
|
|
87 |
|
88 |
def imageMode(image, question):
|
89 |
image_embedding = encode_image(image)
|
|
|
90 |
return "In progress"
|
91 |
|
92 |
def audioMode(audio):
|
@@ -120,7 +126,7 @@ with gr.Blocks() as demo:
|
|
120 |
text_output = gr.Textbox(label="Chat GPT like text")
|
121 |
with gr.Tab("Image mode"):
|
122 |
with gr.Row():
|
123 |
-
image_input = gr.Image()
|
124 |
image_text_input = gr.Textbox(placeholder="Enter a question/prompt around the image", label="Question/Prompt")
|
125 |
image_button = gr.Button("Submit")
|
126 |
image_text_output = gr.Textbox(label="Answer")
|
|
|
8 |
|
9 |
|
10 |
class _MLPVectorProjector(nn.Module):
|
11 |
+
def __init__(
|
12 |
self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
|
13 |
):
|
14 |
+
super(_MLPVectorProjector, self).__init__()
|
15 |
self.mlps = nn.ModuleList()
|
16 |
for _ in range(width):
|
17 |
mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
|
|
|
59 |
return img_embedding
|
60 |
|
61 |
# Get the projection head.
# Maps image-encoder embeddings (512-d) into the LM hidden size (2560-d);
# constructor args (1 layer, width 4) must match what the checkpoint was
# trained with.
img_proj_head = _MLPVectorProjector(512, 2560, 1, 4).to("cuda")
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoint files.
img_proj_head.load_state_dict(torch.load('projection_finetuned.pth'))

# Get the fine-tuned phi-2 language model.
phi2_finetuned = AutoModelForCausalLM.from_pretrained(
    "phi2_adaptor_fineTuned", trust_remote_code=True,
    torch_dtype=torch.float32,
).to("cuda")
|
69 |
|
70 |
|
71 |
def example_inference(input_text, count): #, image, img_qn, audio):
|
|
|
92 |
|
93 |
def imageMode(image, question):
    """Image-mode handler: embed the image and project it into the LM space.

    Full image+question inference is not wired up yet, so the returned
    answer is a placeholder; `question` is currently unused.
    """
    # Embed the uploaded image (see encode_image).
    img_emb = encode_image(image)
    # Project into the language model's hidden space. The result is computed
    # but not consumed yet — WIP scaffolding for the full pipeline.
    imgToTextEmb = img_proj_head(img_emb)
    return "In progress"
|
97 |
|
98 |
def audioMode(audio):
|
|
|
126 |
text_output = gr.Textbox(label="Chat GPT like text")
|
127 |
with gr.Tab("Image mode"):
|
128 |
with gr.Row():
|
129 |
+
image_input = gr.Image(type="filepath")
|
130 |
image_text_input = gr.Textbox(placeholder="Enter a question/prompt around the image", label="Question/Prompt")
|
131 |
image_button = gr.Button("Submit")
|
132 |
image_text_output = gr.Textbox(label="Answer")
|