update models
Browse files- app.py +0 -2
- model/__init__.py +0 -9
- model/agent.py +0 -76
- model/openlamm.py +3 -3
app.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
from transformers import AutoModel, AutoTokenizer
|
2 |
from copy import deepcopy
|
3 |
-
import os
|
4 |
-
import ipdb
|
5 |
import gradio as gr
|
6 |
import mdtex2html
|
7 |
from model.openlamm import LAMMPEFTModel
|
|
|
1 |
from transformers import AutoModel, AutoTokenizer
|
2 |
from copy import deepcopy
|
|
|
|
|
3 |
import gradio as gr
|
4 |
import mdtex2html
|
5 |
from model.openlamm import LAMMPEFTModel
|
model/__init__.py
CHANGED
@@ -1,10 +1 @@
|
|
1 |
-
# from .agent import DeepSpeedAgent
|
2 |
from .openlamm import LAMMPEFTModel
|
3 |
-
|
4 |
-
|
5 |
-
# def load_model(args):
|
6 |
-
# agent_name = args['models'][args['model']]['agent_name']
|
7 |
-
# model_name = args['models'][args['model']]['model_name']
|
8 |
-
# model = globals()[model_name](**args)
|
9 |
-
# agent = globals()[agent_name](model, args)
|
10 |
-
# return agent
|
|
|
|
|
1 |
from .openlamm import LAMMPEFTModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/agent.py
DELETED
@@ -1,76 +0,0 @@
|
|
1 |
-
from header import *
|
2 |
-
from torch.utils.tensorboard import SummaryWriter
|
3 |
-
|
4 |
-
|
5 |
-
class DeepSpeedAgent:
|
6 |
-
|
7 |
-
def __init__(self, model, args):
|
8 |
-
super(DeepSpeedAgent, self).__init__()
|
9 |
-
self.args = args
|
10 |
-
self.model = model
|
11 |
-
self.writer = SummaryWriter(args['log_path'])
|
12 |
-
if args['stage'] == 2:
|
13 |
-
self.load_stage_1_parameters(args["delta_ckpt_path"])
|
14 |
-
print(f'[!] load stage 1 checkpoint from {args["delta_ckpt_path"]}')
|
15 |
-
|
16 |
-
# load config parameters of deepspeed
|
17 |
-
ds_params = json.load(open(self.args['ds_config_path']))
|
18 |
-
ds_params['scheduler']['params']['total_num_steps'] = self.args['total_steps']
|
19 |
-
ds_params['scheduler']['params']['warmup_num_steps'] = max(10, int(self.args['total_steps'] * self.args['warmup_rate']))
|
20 |
-
self.ds_engine, self.optimizer, _ , _ = deepspeed.initialize(
|
21 |
-
model=self.model,
|
22 |
-
model_parameters=self.model.parameters(),
|
23 |
-
config_params=ds_params,
|
24 |
-
dist_init_required=True,
|
25 |
-
args=types.SimpleNamespace(**args)
|
26 |
-
)
|
27 |
-
|
28 |
-
@torch.no_grad()
|
29 |
-
def predict(self, batch):
|
30 |
-
self.model.eval()
|
31 |
-
string = self.model.generate_one_sample(batch)
|
32 |
-
return string
|
33 |
-
|
34 |
-
def train_model(self, batch, current_step=0, pbar=None):
|
35 |
-
self.ds_engine.module.train()
|
36 |
-
loss, mle_acc = self.ds_engine(batch)
|
37 |
-
|
38 |
-
self.ds_engine.backward(loss)
|
39 |
-
self.ds_engine.step()
|
40 |
-
pbar.set_description(f'[!] loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}')
|
41 |
-
pbar.update(1)
|
42 |
-
if self.args['local_rank'] == 0 and self.args['log_path'] and current_step % self.args['logging_step'] == 0:
|
43 |
-
elapsed = pbar.format_dict['elapsed']
|
44 |
-
rate = pbar.format_dict['rate']
|
45 |
-
remaining = (pbar.total - pbar.n) / rate if rate and pbar.total else 0
|
46 |
-
remaining = str(datetime.timedelta(seconds=remaining))
|
47 |
-
self.writer.add_scalar('train/loss', loss.item(), current_step)
|
48 |
-
self.writer.add_scalar('train/token_acc', mle_acc*100, current_step)
|
49 |
-
logging.info(f'[!] progress: {round(pbar.n/pbar.total, 5)}; remaining time: {remaining}; loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}')
|
50 |
-
|
51 |
-
mle_acc *= 100
|
52 |
-
return mle_acc
|
53 |
-
|
54 |
-
def save_model(self, path, current_step):
|
55 |
-
# only save trainable model parameters
|
56 |
-
param_grad_dic = {
|
57 |
-
k: v.requires_grad for (k, v) in self.ds_engine.module.named_parameters()
|
58 |
-
}
|
59 |
-
state_dict = self.ds_engine.module.state_dict()
|
60 |
-
checkpoint = OrderedDict()
|
61 |
-
for k, v in self.ds_engine.module.named_parameters():
|
62 |
-
if v.requires_grad:
|
63 |
-
checkpoint[k] = v
|
64 |
-
if current_step <= 0:
|
65 |
-
torch.save(checkpoint, f'{path}/pytorch_model.pt')
|
66 |
-
else:
|
67 |
-
torch.save(checkpoint, f'{path}/pytorch_model_ep{current_step}.pt')
|
68 |
-
# save tokenizer
|
69 |
-
self.model.llama_tokenizer.save_pretrained(path)
|
70 |
-
# save configuration
|
71 |
-
self.model.llama_model.config.save_pretrained(path)
|
72 |
-
print(f'[!] save model into {path}')
|
73 |
-
|
74 |
-
def load_stage_1_parameters(self, path):
|
75 |
-
delta_ckpt = torch.load(path, map_location=torch.device('cpu'))
|
76 |
-
self.model.load_state_dict(delta_ckpt, strict=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/openlamm.py
CHANGED
@@ -21,7 +21,7 @@ from .CLIP import load as load_clip
|
|
21 |
from .PROCESS import data
|
22 |
from .modeling_llama import LlamaForCausalLM
|
23 |
from .utils.pcl_utils import MEAN_COLOR_RGB, RandomCuboid, random_sampling
|
24 |
-
|
25 |
|
26 |
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
27 |
|
@@ -139,9 +139,9 @@ def make_prompt_start(system_header=False, vision_type='image', task_type='norma
|
|
139 |
PROMPT_START = f'### Human: {VISION_TAGS["sov"][vision_type]}'
|
140 |
if system_header:
|
141 |
if task_type == 'normal':
|
142 |
-
return f"{
|
143 |
else:
|
144 |
-
return [f"{
|
145 |
else:
|
146 |
return PROMPT_START
|
147 |
|
|
|
21 |
from .PROCESS import data
|
22 |
from .modeling_llama import LlamaForCausalLM
|
23 |
from .utils.pcl_utils import MEAN_COLOR_RGB, RandomCuboid, random_sampling
|
24 |
+
from .conversations import conversation_dict, default_conversation
|
25 |
|
26 |
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
27 |
|
|
|
139 |
PROMPT_START = f'### Human: {VISION_TAGS["sov"][vision_type]}'
|
140 |
if system_header:
|
141 |
if task_type == 'normal':
|
142 |
+
return f"{default_conversation.system}\n\n" + PROMPT_START
|
143 |
else:
|
144 |
+
return [f"{conversation_dict[task]}\n\n" + PROMPT_START for task in task_type]
|
145 |
else:
|
146 |
return PROMPT_START
|
147 |
|