cwkuo committed on
Commit
d8c6a57
·
1 Parent(s): 6855619

Reduce GPU memory usage by keeping only the necessary modules of query_enc

Browse files
Files changed (2) hide show
  1. app.py +9 -6
  2. model/gptk.py +1 -1
app.py CHANGED
@@ -4,8 +4,10 @@ import time
4
  import gradio as gr
5
  import requests
6
  import numpy as np
 
7
 
8
  import torch
 
9
  import open_clip
10
  import faiss
11
  from transformers import TextIteratorStreamer
@@ -96,7 +98,7 @@ def add_text(state: Conversation, text, image):
96
  def search(image, pos, topk, knwl_db, knwl_idx):
97
  with torch.cuda.amp.autocast():
98
  image = query_trans(image).unsqueeze(0).to(device)
99
- query = query_enc.encode_image(image, normalize=True)
100
  query = query.cpu().numpy()
101
 
102
  _, I = knwl_idx.search(query, 4*topk)
@@ -372,15 +374,16 @@ def build_knowledge():
372
  "act": get_knwl('knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
373
  "attr": get_knwl('knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
374
  }
 
375
 
376
- return knwl_db
377
 
378
 
379
  def build_query_model():
380
  query_enc, _, query_trans = open_clip.create_model_and_transforms(
381
- "ViT-g-14", pretrained="laion2b_s34b_b88k", precision='fp16', device=device
382
  )
383
- query_enc = query_enc.eval()
384
 
385
  return query_enc, query_trans
386
 
@@ -388,7 +391,7 @@ def build_query_model():
388
  def build_gptk_model():
389
  _, gptk_trans = get_gptk_image_transform()
390
  topk = {"whole": 60, "five": 24, "nine": 16}
391
- gptk_model = get_gptk_model(d_knwl=1024, topk=topk)
392
  gptk_ckpt = "model/ckpt/gptk-vicuna7b.pt"
393
  gptk_ckpt = torch.load(gptk_ckpt, map_location="cpu")
394
  gptk_model.load_state_dict(gptk_ckpt, strict=False)
@@ -402,8 +405,8 @@ if torch.cuda.is_available():
402
  else:
403
  device = torch.device("cpu")
404
 
 
405
  gptk_model, gptk_trans, topk = build_gptk_model()
406
  query_enc, query_trans = build_query_model()
407
- knwl_db = build_knowledge()
408
  demo = build_demo()
409
  demo.queue().launch()
 
4
  import gradio as gr
5
  import requests
6
  import numpy as np
7
+ from pathlib import Path
8
 
9
  import torch
10
+ import torch.nn.functional as F
11
  import open_clip
12
  import faiss
13
  from transformers import TextIteratorStreamer
 
98
  def search(image, pos, topk, knwl_db, knwl_idx):
99
  with torch.cuda.amp.autocast():
100
  image = query_trans(image).unsqueeze(0).to(device)
101
+ query = F.normalize(query_enc(image), dim=-1)
102
  query = query.cpu().numpy()
103
 
104
  _, I = knwl_idx.search(query, 4*topk)
 
374
  "act": get_knwl('knowledge/(dataset-action)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
375
  "attr": get_knwl('knowledge/(dataset-attribute)(clip-model-ViT-g-14)(dbscan)(eps-0.15)(ms-1)'),
376
  }
377
+ d_knwl = knwl_db["obj"][0].feature.shape[-1]
378
 
379
+ return knwl_db, d_knwl
380
 
381
 
382
  def build_query_model():
383
  query_enc, _, query_trans = open_clip.create_model_and_transforms(
384
+ "ViT-g-14", pretrained="laion2b_s34b_b88k", precision='fp16'
385
  )
386
+ query_enc = query_enc.visual.to(device).eval()
387
 
388
  return query_enc, query_trans
389
 
 
391
  def build_gptk_model():
392
  _, gptk_trans = get_gptk_image_transform()
393
  topk = {"whole": 60, "five": 24, "nine": 16}
394
+ gptk_model = get_gptk_model(d_knwl=d_knwl, topk=topk)
395
  gptk_ckpt = "model/ckpt/gptk-vicuna7b.pt"
396
  gptk_ckpt = torch.load(gptk_ckpt, map_location="cpu")
397
  gptk_model.load_state_dict(gptk_ckpt, strict=False)
 
405
  else:
406
  device = torch.device("cpu")
407
 
408
+ knwl_db, d_knwl = build_knowledge()
409
  gptk_model, gptk_trans, topk = build_gptk_model()
410
  query_enc, query_trans = build_query_model()
 
411
  demo = build_demo()
412
  demo.queue().launch()
model/gptk.py CHANGED
@@ -49,7 +49,7 @@ class GPTK(nn.Module):
49
  llm_config.gradient_checkpointing = True
50
  llm_config.use_cache = True
51
  quantization_config = BitsAndBytesConfig(
52
- load_in_4bit=True,
53
  llm_int8_threshold=6.0,
54
  llm_int8_has_fp16_weight=False,
55
  bnb_4bit_compute_dtype=torch.float16,
 
49
  llm_config.gradient_checkpointing = True
50
  llm_config.use_cache = True
51
  quantization_config = BitsAndBytesConfig(
52
+ load_in_8bit=True,
53
  llm_int8_threshold=6.0,
54
  llm_int8_has_fp16_weight=False,
55
  bnb_4bit_compute_dtype=torch.float16,