Spaces:
Runtime error
Runtime error
cwkuo
committed on
Commit
·
d8c6a57
1
Parent(s):
6855619
reduce GPU memory by keeping only the necessary modules of query_enc
Browse files
- app.py +9 -6
- model/gptk.py +1 -1
app.py
CHANGED
@@ -4,8 +4,10 @@ import time
|
|
4 |
import gradio as gr
|
5 |
import requests
|
6 |
import numpy as np
|
|
|
7 |
|
8 |
import torch
|
|
|
9 |
import open_clip
|
10 |
import faiss
|
11 |
from transformers import TextIteratorStreamer
|
@@ -96,7 +98,7 @@ def add_text(state: Conversation, text, image):
|
|
96 |
def search(image, pos, topk, knwl_db, knwl_idx):
|
97 |
with torch.cuda.amp.autocast():
|
98 |
image = query_trans(image).unsqueeze(0).to(device)
|
99 |
-
query =
|
100 |
query = query.cpu().numpy()
|
101 |
|
102 |
_, I = knwl_idx.search(query, 4*topk)
|
@@ -372,15 +374,16 @@ def build_knowledge():
|
|
372 |
"act": get_knwl('knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
|
373 |
"attr": get_knwl('knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
|
374 |
}
|
|
|
375 |
|
376 |
-
return knwl_db
|
377 |
|
378 |
|
379 |
def build_query_model():
|
380 |
query_enc, _, query_trans = open_clip.create_model_and_transforms(
|
381 |
-
"ViT-g-14", pretrained="laion2b_s34b_b88k", precision='fp16'
|
382 |
)
|
383 |
-
query_enc = query_enc.eval()
|
384 |
|
385 |
return query_enc, query_trans
|
386 |
|
@@ -388,7 +391,7 @@ def build_query_model():
|
|
388 |
def build_gptk_model():
|
389 |
_, gptk_trans = get_gptk_image_transform()
|
390 |
topk = {"whole": 60, "five": 24, "nine": 16}
|
391 |
-
gptk_model = get_gptk_model(d_knwl=
|
392 |
gptk_ckpt = "model/ckpt/gptk-vicuna7b.pt"
|
393 |
gptk_ckpt = torch.load(gptk_ckpt, map_location="cpu")
|
394 |
gptk_model.load_state_dict(gptk_ckpt, strict=False)
|
@@ -402,8 +405,8 @@ if torch.cuda.is_available():
|
|
402 |
else:
|
403 |
device = torch.device("cpu")
|
404 |
|
|
|
405 |
gptk_model, gptk_trans, topk = build_gptk_model()
|
406 |
query_enc, query_trans = build_query_model()
|
407 |
-
knwl_db = build_knowledge()
|
408 |
demo = build_demo()
|
409 |
demo.queue().launch()
|
|
|
4 |
import gradio as gr
|
5 |
import requests
|
6 |
import numpy as np
|
7 |
+
from pathlib import Path
|
8 |
|
9 |
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
import open_clip
|
12 |
import faiss
|
13 |
from transformers import TextIteratorStreamer
|
|
|
98 |
def search(image, pos, topk, knwl_db, knwl_idx):
|
99 |
with torch.cuda.amp.autocast():
|
100 |
image = query_trans(image).unsqueeze(0).to(device)
|
101 |
+
query = F.normalize(query_enc(image), dim=-1)
|
102 |
query = query.cpu().numpy()
|
103 |
|
104 |
_, I = knwl_idx.search(query, 4*topk)
|
|
|
374 |
"act": get_knwl('knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
|
375 |
"attr": get_knwl('knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
|
376 |
}
|
377 |
+
d_knwl = knwl_db["obj"][0].feature.shape[-1]
|
378 |
|
379 |
+
return knwl_db, d_knwl
|
380 |
|
381 |
|
382 |
def build_query_model():
|
383 |
query_enc, _, query_trans = open_clip.create_model_and_transforms(
|
384 |
+
"ViT-g-14", pretrained="laion2b_s34b_b88k", precision='fp16'
|
385 |
)
|
386 |
+
query_enc = query_enc.visual.to(device).eval()
|
387 |
|
388 |
return query_enc, query_trans
|
389 |
|
|
|
391 |
def build_gptk_model():
|
392 |
_, gptk_trans = get_gptk_image_transform()
|
393 |
topk = {"whole": 60, "five": 24, "nine": 16}
|
394 |
+
gptk_model = get_gptk_model(d_knwl=d_knwl, topk=topk)
|
395 |
gptk_ckpt = "model/ckpt/gptk-vicuna7b.pt"
|
396 |
gptk_ckpt = torch.load(gptk_ckpt, map_location="cpu")
|
397 |
gptk_model.load_state_dict(gptk_ckpt, strict=False)
|
|
|
405 |
else:
|
406 |
device = torch.device("cpu")
|
407 |
|
408 |
+
knwl_db, d_knwl = build_knowledge()
|
409 |
gptk_model, gptk_trans, topk = build_gptk_model()
|
410 |
query_enc, query_trans = build_query_model()
|
|
|
411 |
demo = build_demo()
|
412 |
demo.queue().launch()
|
model/gptk.py
CHANGED
@@ -49,7 +49,7 @@ class GPTK(nn.Module):
|
|
49 |
llm_config.gradient_checkpointing = True
|
50 |
llm_config.use_cache = True
|
51 |
quantization_config = BitsAndBytesConfig(
|
52 |
-
|
53 |
llm_int8_threshold=6.0,
|
54 |
llm_int8_has_fp16_weight=False,
|
55 |
bnb_4bit_compute_dtype=torch.float16,
|
|
|
49 |
llm_config.gradient_checkpointing = True
|
50 |
llm_config.use_cache = True
|
51 |
quantization_config = BitsAndBytesConfig(
|
52 |
+
load_in_8bit=True,
|
53 |
llm_int8_threshold=6.0,
|
54 |
llm_int8_has_fp16_weight=False,
|
55 |
bnb_4bit_compute_dtype=torch.float16,
|