huangjy-pku commited on
Commit
d37fab3
β€’
1 Parent(s): 7978a78

fix hf logging

Browse files
Files changed (3) hide show
  1. model/cfg.yaml β†’ cfg.yaml +2 -2
  2. model/leo_agent.py +10 -13
  3. utils.py +50 -37
model/cfg.yaml β†’ cfg.yaml RENAMED
@@ -1,4 +1,5 @@
1
- use_ckpt: hf
 
2
  hf_ckpt_path: [huangjy-pku/embodied-generalist, weights/leo_noact_hf.pth]
3
  local_ckpt_path: /mnt/huangjiangyong/leo/hf_assets/weights/leo_noact_lora.pth
4
  model:
@@ -6,7 +7,6 @@ model:
6
  # vision modules omitted
7
  llm:
8
  name: Vicuna7B
9
- use_ckpt: hf
10
  hf_cfg_path: huangjy-pku/vicuna-7b
11
  local_cfg_path: /mnt/huangjiangyong/vicuna-7b
12
  truncation_side: right
 
1
+ launch_mode: hf # hf or local
2
+ hf_log_path: embodied-generalist/leo_demo_log
3
  hf_ckpt_path: [huangjy-pku/embodied-generalist, weights/leo_noact_hf.pth]
4
  local_ckpt_path: /mnt/huangjiangyong/leo/hf_assets/weights/leo_noact_lora.pth
5
  model:
 
7
  # vision modules omitted
8
  llm:
9
  name: Vicuna7B
 
10
  hf_cfg_path: huangjy-pku/vicuna-7b
11
  local_cfg_path: /mnt/huangjiangyong/vicuna-7b
12
  truncation_side: right
model/leo_agent.py CHANGED
@@ -14,16 +14,13 @@ def disabled_train(self, mode=True):
14
  class LeoAgentLLM(nn.Module):
15
  def __init__(self, cfg):
16
  super().__init__()
17
- if hasattr(cfg, 'model'):
18
- cfg = cfg.model
19
-
20
  # LLM
21
- if cfg.llm.use_ckpt == 'hf':
22
- llm_cfg_path = snapshot_download(cfg.llm.hf_cfg_path)
23
  else:
24
- llm_cfg_path = cfg.llm.local_cfg_path
25
  self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, use_fast=False,
26
- truncation_side=cfg.llm.truncation_side)
27
  self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
28
  self.llm_tokenizer.add_special_tokens({'bos_token': '<s>'})
29
  self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
@@ -37,18 +34,18 @@ class LeoAgentLLM(nn.Module):
37
  self.llm_model.train = disabled_train
38
 
39
  # LoRA-based LLM fine-tuning
40
- if cfg.llm.lora.flag:
41
  lora_config = LoraConfig(
42
- r=cfg.llm.lora.rank,
43
- lora_alpha=cfg.llm.lora.alpha,
44
- target_modules=cfg.llm.lora.target_modules,
45
- lora_dropout=cfg.llm.lora.dropout,
46
  bias='none',
47
  modules_to_save=[],
48
  )
49
  self.llm_model = get_peft_model(self.llm_model, peft_config=lora_config)
50
 
51
- self.max_context_len = cfg.llm.max_context_len
52
 
53
  @property
54
  def device(self):
 
14
  class LeoAgentLLM(nn.Module):
15
  def __init__(self, cfg):
16
  super().__init__()
 
 
 
17
  # LLM
18
+ if cfg.launch_mode == 'hf':
19
+ llm_cfg_path = snapshot_download(cfg.model.llm.hf_cfg_path)
20
  else:
21
+ llm_cfg_path = cfg.model.llm.local_cfg_path
22
  self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, use_fast=False,
23
+ truncation_side=cfg.model.llm.truncation_side)
24
  self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
25
  self.llm_tokenizer.add_special_tokens({'bos_token': '<s>'})
26
  self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
 
34
  self.llm_model.train = disabled_train
35
 
36
  # LoRA-based LLM fine-tuning
37
+ if cfg.model.llm.lora.flag:
38
  lora_config = LoraConfig(
39
+ r=cfg.model.llm.lora.rank,
40
+ lora_alpha=cfg.model.llm.lora.alpha,
41
+ target_modules=cfg.model.llm.lora.target_modules,
42
+ lora_dropout=cfg.model.llm.lora.dropout,
43
  bias='none',
44
  modules_to_save=[],
45
  )
46
  self.llm_model = get_peft_model(self.llm_model, peft_config=lora_config)
47
 
48
+ self.max_context_len = cfg.model.llm.max_context_len
49
 
50
  @property
51
  def device(self):
utils.py CHANGED
@@ -3,11 +3,12 @@ import datetime
3
  import json
4
  import os
5
  import time
 
6
 
7
  import gradio as gr
8
  import torch
9
  import yaml
10
- from huggingface_hub import hf_hub_download
11
  from omegaconf import OmegaConf
12
 
13
  from model.leo_agent import LeoAgentLLM
@@ -27,33 +28,34 @@ OBJECTS_PROMPT = "Objects (including you) in the scene:"
27
  TASK_PROMPT = "USER: {instruction} ASSISTANT:"
28
  OBJ_FEATS_DIR = 'assets/obj_features'
29
 
30
-
31
- def load_agent():
32
- # build model
33
- with open('model/cfg.yaml') as f:
34
- cfg = yaml.safe_load(f)
35
- cfg = OmegaConf.create(cfg)
36
- agent = LeoAgentLLM(cfg)
37
-
38
- # load checkpoint
39
- if cfg.use_ckpt == 'hf':
40
- ckpt_path = hf_hub_download(cfg.hf_ckpt_path[0], cfg.hf_ckpt_path[1])
41
- else:
42
- ckpt_path = cfg.local_ckpt_path
43
- ckpt = torch.load(ckpt_path, map_location='cpu')
44
- agent.load_state_dict(ckpt, strict=False)
45
-
46
- agent.eval()
47
- agent.to(DEVICE)
48
- return agent
49
-
50
- agent = load_agent()
51
-
52
-
53
- def get_log_fname():
54
- t = datetime.datetime.now()
55
- fname = os.path.join(LOG_DIR, f'{t.year}-{t.month:02d}-{t.day:02d}.json')
56
- return fname
 
57
 
58
 
59
  def change_scene(dropdown_scene: str):
@@ -139,17 +141,28 @@ def vote_response(
139
  'type': vote_type,
140
  'scene': dropdown_scene,
141
  'mode': dropdown_conversation_mode,
142
- 'dialogue': chatbot,
143
  }
144
- fname = get_log_fname()
145
- if os.path.exists(fname):
146
- with open(fname) as f:
147
- logs = json.load(f)
148
- logs.append(this_log)
 
 
 
 
 
 
149
  else:
150
- logs = [this_log]
151
- with open(fname, 'w') as f:
152
- json.dump(logs, f, indent=2)
 
 
 
 
 
153
 
154
 
155
  def upvote_response(
 
3
  import json
4
  import os
5
  import time
6
+ from uuid import uuid4
7
 
8
  import gradio as gr
9
  import torch
10
  import yaml
11
+ from huggingface_hub import CommitScheduler, hf_hub_download
12
  from omegaconf import OmegaConf
13
 
14
  from model.leo_agent import LeoAgentLLM
 
28
  TASK_PROMPT = "USER: {instruction} ASSISTANT:"
29
  OBJ_FEATS_DIR = 'assets/obj_features'
30
 
31
+ with open('cfg.yaml') as f:
32
+ cfg = yaml.safe_load(f)
33
+ cfg = OmegaConf.create(cfg)
34
+
35
+ # build model
36
+ agent = LeoAgentLLM(cfg)
37
+
38
+ # load checkpoint
39
+ if cfg.launch_mode == 'hf':
40
+ ckpt_path = hf_hub_download(cfg.hf_ckpt_path[0], cfg.hf_ckpt_path[1])
41
+ else:
42
+ ckpt_path = cfg.local_ckpt_path
43
+ ckpt = torch.load(ckpt_path, map_location='cpu')
44
+ agent.load_state_dict(ckpt, strict=False)
45
+ agent.eval()
46
+ agent.to(DEVICE)
47
+
48
+ os.makedirs(LOG_DIR, exist_ok=True)
49
+ t = datetime.datetime.now()
50
+ log_fname = os.path.join(LOG_DIR, f'{t.year}-{t.month:02d}-{t.day:02d}-{uuid4()}.json')
51
+
52
+ if cfg.launch_mode == 'hf':
53
+ scheduler = CommitScheduler(
54
+ repo_id=cfg.hf_log_path,
55
+ repo_type='dataset',
56
+ folder_path=LOG_DIR,
57
+ path_in_repo=LOG_DIR,
58
+ )
59
 
60
 
61
  def change_scene(dropdown_scene: str):
 
141
  'type': vote_type,
142
  'scene': dropdown_scene,
143
  'mode': dropdown_conversation_mode,
144
+ 'dialogue': [chatbot[-1]] if 'Single-round' in dropdown_conversation_mode else chatbot,
145
  }
146
+
147
+ if cfg.launch_mode == 'hf':
148
+ with scheduler.lock: # use scheduler
149
+ if os.path.exists(log_fname):
150
+ with open(log_fname) as f:
151
+ logs = json.load(f)
152
+ logs.append(this_log)
153
+ else:
154
+ logs = [this_log]
155
+ with open(log_fname, 'w') as f:
156
+ json.dump(logs, f, indent=2)
157
  else:
158
+ if os.path.exists(log_fname):
159
+ with open(log_fname) as f:
160
+ logs = json.load(f)
161
+ logs.append(this_log)
162
+ else:
163
+ logs = [this_log]
164
+ with open(log_fname, 'w') as f:
165
+ json.dump(logs, f, indent=2)
166
 
167
 
168
  def upvote_response(