Spaces:
Runtime error
Runtime error
Commit
·
982a4de
0
Parent(s):
Duplicate from Linly-AI/Linly-ChatFlow
Browse filesCo-authored-by: yuhaofeng <yuhaofeng-shiba@users.noreply.huggingface.co>
- .gitattributes +34 -0
- README.md +14 -0
- app.py +54 -0
- config/llama_13b_config.json +21 -0
- config/llama_7b.json +21 -0
- generate.py +143 -0
- model_file/chatllama_7b.bin +3 -0
- model_file/tokenizer.model +3 -0
- models/llama.py +197 -0
- models/norm.py +16 -0
- models/rope.py +30 -0
- models/tokenize.py +40 -0
- requirements.txt +7 -0
- utils.py +143 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Linly ChatFlow 7B
|
3 |
+
emoji: 📉
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.38.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: gpl-3.0
|
11 |
+
duplicated_from: Linly-AI/Linly-ChatFlow
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import gradio as gr
|
7 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
8 |
+
|
9 |
+
|
10 |
+
def init_model():
|
11 |
+
model = AutoModelForCausalLM.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="cuda:0",
|
12 |
+
torch_dtype=torch.bfloat16, trust_remote_code=True)
|
13 |
+
tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
|
14 |
+
return model, tokenizer
|
15 |
+
|
16 |
+
|
17 |
+
def process(message, history):
|
18 |
+
input_prompt = ""
|
19 |
+
for interaction in history:
|
20 |
+
input_prompt = f"{input_prompt} User: {str(interaction[0]).strip(' ')} Bot: {str(interaction[1]).strip(' ')}"
|
21 |
+
input_prompt = f"{input_prompt} ### Instruction:{message.strip()} ### Response:"
|
22 |
+
inputs = tokenizer(input_prompt, return_tensors="pt").to("cuda:0")
|
23 |
+
try:
|
24 |
+
generate_ids = model.generate(inputs.input_ids, max_new_tokens=2048, do_sample=True, top_k=20, top_p=0.84,
|
25 |
+
temperature=1, repetition_penalty=1.15, eos_token_id=2, bos_token_id=1,
|
26 |
+
pad_token_id=0)
|
27 |
+
response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
28 |
+
print('log:', response)
|
29 |
+
response = response.split("### Response:")[-1]
|
30 |
+
return response
|
31 |
+
except:
|
32 |
+
return "Error: 会话超长,请重试!"
|
33 |
+
|
34 |
+
|
35 |
+
if __name__ == '__main__':
|
36 |
+
examples = ["Python和JavaScript编程语言的主要区别是什么?", "影响消费者行为的主要因素是什么?", "请用pytorch实现一个带ReLU激活函数的全连接层的代码",
|
37 |
+
"请用C++编程语言实现“给你两个字符串haystack和needle,在haystack字符串中找出needle字符串的第一个匹配项的下标(下标从 0 开始)。如果needle不是haystack的一部分,则返回-1。",
|
38 |
+
"如何使用ssh -L,请用具体例子说明",
|
39 |
+
"应对压力最有效的方法是什么?"]
|
40 |
+
model, tokenizer = init_model()
|
41 |
+
demo = gr.ChatInterface(
|
42 |
+
process,
|
43 |
+
chatbot=gr.Chatbot(height=600),
|
44 |
+
textbox=gr.Textbox(placeholder="Input", container=False, scale=7),
|
45 |
+
title="Linly ChatFlow",
|
46 |
+
description="",
|
47 |
+
theme="soft",
|
48 |
+
examples=examples,
|
49 |
+
cache_examples=True,
|
50 |
+
retry_btn="Retry",
|
51 |
+
undo_btn="Delete Previous",
|
52 |
+
clear_btn="Clear",
|
53 |
+
)
|
54 |
+
demo.queue(concurrency_count=75).launch()
|
config/llama_13b_config.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"emb_size": 5120,
|
3 |
+
"feedforward_size": 13824,
|
4 |
+
"hidden_size": 5120,
|
5 |
+
"hidden_act": "silu",
|
6 |
+
"heads_num": 40,
|
7 |
+
"layers_num": 40,
|
8 |
+
"dropout": 0.1,
|
9 |
+
"data_processor": "lm",
|
10 |
+
"max_seq_length": 2048,
|
11 |
+
"embedding": ["word"],
|
12 |
+
"remove_transformer_bias": true,
|
13 |
+
"remove_embedding_layernorm": true,
|
14 |
+
"rotary_position_embedding": true,
|
15 |
+
"encoder": "transformer",
|
16 |
+
"feed_forward": "gated",
|
17 |
+
"mask": "causal",
|
18 |
+
"layernorm_positioning": "pre",
|
19 |
+
"layernorm": "rms",
|
20 |
+
"target": ["lm"]
|
21 |
+
}
|
config/llama_7b.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"emb_size": 4096,
|
3 |
+
"feedforward_size": 11008,
|
4 |
+
"hidden_size": 4096,
|
5 |
+
"hidden_act": "silu",
|
6 |
+
"heads_num": 32,
|
7 |
+
"layers_num": 32,
|
8 |
+
"dropout": 0.1,
|
9 |
+
"data_processor": "lm",
|
10 |
+
"max_seq_length": 2048,
|
11 |
+
"embedding": ["word"],
|
12 |
+
"remove_transformer_bias": true,
|
13 |
+
"remove_embedding_layernorm": true,
|
14 |
+
"rotary_position_embedding": true,
|
15 |
+
"encoder": "transformer",
|
16 |
+
"feed_forward": "gated",
|
17 |
+
"mask": "causal",
|
18 |
+
"layernorm_positioning": "pre",
|
19 |
+
"layernorm": "rms",
|
20 |
+
"target": ["lm"]
|
21 |
+
}
|
generate.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
def apply_temperature(scores, tempt):
|
6 |
+
if tempt > 0:
|
7 |
+
scores = scores / tempt
|
8 |
+
return scores
|
9 |
+
|
10 |
+
|
11 |
+
def apply_top_p(scores, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
12 |
+
if top_p > 0 and top_p < 1:
|
13 |
+
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
|
14 |
+
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
15 |
+
|
16 |
+
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
17 |
+
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
18 |
+
if min_tokens_to_keep > 1:
|
19 |
+
# Keep at least min_tokens_to_keep
|
20 |
+
sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
|
21 |
+
|
22 |
+
# scatter sorted tensors to original indexing
|
23 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
24 |
+
1, sorted_indices, sorted_indices_to_remove
|
25 |
+
)
|
26 |
+
scores = scores.masked_fill(indices_to_remove, filter_value)
|
27 |
+
return scores
|
28 |
+
|
29 |
+
|
30 |
+
def apply_top_k(logits, top_k):
|
31 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
32 |
+
if top_k > 0:
|
33 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
34 |
+
indices_to_remove = logits < torch.topk(logits.float(), top_k)[0][..., -1, None]
|
35 |
+
logits[indices_to_remove] = -float("Inf")
|
36 |
+
|
37 |
+
return logits
|
38 |
+
|
39 |
+
def apply_advanced_repetition_penalty(
|
40 |
+
input_ids, scores, penalty_range, penalty_slope, penalty
|
41 |
+
):
|
42 |
+
penalty_range = int(penalty_range)
|
43 |
+
clipped_penalty_range = min(input_ids.shape[-1], penalty_range)
|
44 |
+
|
45 |
+
if penalty != 1.0:
|
46 |
+
if penalty_range > 0:
|
47 |
+
if clipped_penalty_range < input_ids.shape[1]:
|
48 |
+
input_ids = input_ids[..., -clipped_penalty_range:]
|
49 |
+
|
50 |
+
if penalty_slope != 0:
|
51 |
+
_penalty = (
|
52 |
+
torch.arange(
|
53 |
+
penalty_range, dtype=scores.dtype, device=scores.device
|
54 |
+
)
|
55 |
+
/ (penalty_range - 1)
|
56 |
+
) * 2.0 - 1
|
57 |
+
_penalty = (penalty_slope * _penalty) / (
|
58 |
+
1 + torch.abs(_penalty) * (penalty_slope - 1)
|
59 |
+
)
|
60 |
+
_penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
|
61 |
+
penalty = _penalty[..., -clipped_penalty_range:]
|
62 |
+
|
63 |
+
score = torch.gather(scores, 1, input_ids)
|
64 |
+
score = torch.where(score <= 0, score * penalty, score / penalty)
|
65 |
+
scores.scatter_(1, input_ids, score)
|
66 |
+
|
67 |
+
return scores
|
68 |
+
|
69 |
+
|
70 |
+
class LmGeneration:
|
71 |
+
def __init__(self, model, tokenizer):
|
72 |
+
self.model = model
|
73 |
+
self.tokenizer = tokenizer
|
74 |
+
|
75 |
+
def generate(self, args, prompts, cut_off=None, cut_off_times=1):
|
76 |
+
if cut_off is not None:
|
77 |
+
cut_off_times = [cut_off_times for i in range(len(prompts))]
|
78 |
+
batch = len(prompts)
|
79 |
+
assert batch <= args.batch_size
|
80 |
+
|
81 |
+
prompt_tokens = [args.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
|
82 |
+
|
83 |
+
min_prompt_len = min([len(x) for x in prompt_tokens])
|
84 |
+
# max_prompt_len = max([len(x) for x in prompt_tokens])
|
85 |
+
|
86 |
+
total_len = args.seq_length
|
87 |
+
|
88 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
89 |
+
tokens = torch.full((batch, total_len), self.tokenizer.pad_token).to(device).long()
|
90 |
+
for idx, t in enumerate(prompt_tokens):
|
91 |
+
tokens[idx, : len(t)] = torch.tensor(t).long()
|
92 |
+
mask = tokens != self.tokenizer.pad_token
|
93 |
+
start_pos = min_prompt_len
|
94 |
+
prev_pos = 0
|
95 |
+
continue_exsample = [i for i in range(batch)]
|
96 |
+
with torch.no_grad():
|
97 |
+
for cur_pos in range(start_pos, total_len):
|
98 |
+
logits = self.model.forward(tokens[continue_exsample, prev_pos:cur_pos], prev_pos, continue_exsample).float()
|
99 |
+
next_token_scores = apply_top_k(logits, top_k=args.top_k)
|
100 |
+
next_token_scores = apply_top_p(next_token_scores, args.top_p)
|
101 |
+
next_token_scores = apply_temperature(next_token_scores, args.temperature)
|
102 |
+
next_token_scores = apply_advanced_repetition_penalty(
|
103 |
+
tokens[continue_exsample, :cur_pos],
|
104 |
+
next_token_scores,
|
105 |
+
args.repetition_penalty_range,
|
106 |
+
args.repetition_penalty_slope,
|
107 |
+
args.repetition_penalty
|
108 |
+
)
|
109 |
+
scores = F.softmax(next_token_scores, dim=-1)
|
110 |
+
next_token = torch.multinomial(scores, num_samples=1).squeeze(1)
|
111 |
+
next_token = next_token.reshape(-1)
|
112 |
+
next_token = torch.where(
|
113 |
+
mask[continue_exsample, cur_pos], tokens[continue_exsample, cur_pos], next_token
|
114 |
+
)
|
115 |
+
tokens[continue_exsample, cur_pos] = next_token
|
116 |
+
prev_pos = cur_pos
|
117 |
+
# remove eos examples.
|
118 |
+
continue_exsample = []
|
119 |
+
for i, t in enumerate(tokens.tolist()):
|
120 |
+
try:
|
121 |
+
t.index(self.tokenizer.eos_token)
|
122 |
+
except ValueError:
|
123 |
+
if cut_off is not None:
|
124 |
+
if cut_off == self.tokenizer.decode(t[:cur_pos + 1])[-len(cut_off):]:
|
125 |
+
if cut_off_times[i] == 1:
|
126 |
+
continue
|
127 |
+
else:
|
128 |
+
cut_off_times[i] -= 1
|
129 |
+
continue_exsample.append(i)
|
130 |
+
if len(continue_exsample) == 0:
|
131 |
+
break
|
132 |
+
|
133 |
+
decoder = []
|
134 |
+
for i, t in enumerate(tokens.tolist()):
|
135 |
+
t = t[: args.seq_length]
|
136 |
+
try:
|
137 |
+
t = t[: t.index(self.tokenizer.pad_token)]
|
138 |
+
t = t[: t.index(self.tokenizer.eos_token)]
|
139 |
+
except ValueError:
|
140 |
+
pass
|
141 |
+
decoder.append(self.tokenizer.decode(t))
|
142 |
+
|
143 |
+
return decoder
|
model_file/chatllama_7b.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:18ab0a6fb112bcea1fae85716d3b822c1903a6cd95fda5a65fd9fe164b705037
|
3 |
+
size 13476956615
|
model_file/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
3 |
+
size 499723
|
models/llama.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from models.norm import RMSNorm
|
5 |
+
from models.rope import precompute_freqs_cis, apply_rotary_emb
|
6 |
+
import bitsandbytes as bnb
|
7 |
+
import math
|
8 |
+
|
9 |
+
|
10 |
+
class NormalLinear(nn.Linear):
|
11 |
+
def reset_parameters(self) -> None:
|
12 |
+
pass
|
13 |
+
|
14 |
+
|
15 |
+
class BnbInt8Linear(bnb.nn.Linear8bitLt):
|
16 |
+
def __init__(self, *args, **kwargs):
|
17 |
+
super().__init__(has_fp16_weights=False, threshold=6.0, *args, **kwargs)
|
18 |
+
|
19 |
+
def reset_parameters(self) -> None:
|
20 |
+
pass
|
21 |
+
|
22 |
+
|
23 |
+
def get_linear_layer(use_int8):
|
24 |
+
if use_int8:
|
25 |
+
return BnbInt8Linear
|
26 |
+
return NormalLinear
|
27 |
+
|
28 |
+
|
29 |
+
class WordEmbedding(nn.Module):
|
30 |
+
def __init__(self, args):
|
31 |
+
super(WordEmbedding, self).__init__()
|
32 |
+
self.embedding = nn.Embedding(args.vocab_size, args.emb_size)
|
33 |
+
|
34 |
+
def forward(self, src):
|
35 |
+
emb = self.embedding(src)
|
36 |
+
return emb
|
37 |
+
|
38 |
+
|
39 |
+
class MultiHeadedAttention(nn.Module):
|
40 |
+
def __init__(self, args, hidden_size, heads_num, attention_head_size, has_bias=True, use_int8=True):
|
41 |
+
super(MultiHeadedAttention, self).__init__()
|
42 |
+
self.heads_num = heads_num
|
43 |
+
|
44 |
+
self.per_head_size = attention_head_size
|
45 |
+
self.inner_hidden_size = heads_num * attention_head_size
|
46 |
+
|
47 |
+
Linear = get_linear_layer(use_int8)
|
48 |
+
self.linear_layers = nn.ModuleList(
|
49 |
+
[Linear(hidden_size, self.inner_hidden_size, bias=has_bias) for _ in range(3)]
|
50 |
+
)
|
51 |
+
|
52 |
+
self.final_linear = Linear(self.inner_hidden_size, hidden_size, bias=has_bias)
|
53 |
+
|
54 |
+
# add cache to reduce compute source.
|
55 |
+
self.cache_k = torch.zeros(
|
56 |
+
(args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
|
57 |
+
)
|
58 |
+
self.cache_v = torch.zeros(
|
59 |
+
(args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
|
60 |
+
)
|
61 |
+
|
62 |
+
def forward(self, key, value, query, start_pos, continue_exsample, mask, freqs_cis):
|
63 |
+
batch_size, seq_length, _ = query.size()
|
64 |
+
heads_num = self.heads_num
|
65 |
+
per_head_size = self.per_head_size
|
66 |
+
query, key, value = [l(x).view(batch_size, -1, heads_num, per_head_size) \
|
67 |
+
for l, x in zip(self.linear_layers, (query, key, value))]
|
68 |
+
query, key = apply_rotary_emb(query, key, freqs_cis=freqs_cis)
|
69 |
+
if self.cache_k.device != key.device:
|
70 |
+
self.cache_k = self.cache_k.to(key)
|
71 |
+
if self.cache_v.device != value.device:
|
72 |
+
self.cache_v = self.cache_v.to(value)
|
73 |
+
|
74 |
+
self.cache_k[continue_exsample, start_pos: start_pos + seq_length] = key
|
75 |
+
self.cache_v[continue_exsample, start_pos: start_pos + seq_length] = value
|
76 |
+
|
77 |
+
key = self.cache_k[continue_exsample, : start_pos + seq_length]
|
78 |
+
value = self.cache_v[continue_exsample, : start_pos + seq_length]
|
79 |
+
|
80 |
+
query, key, value = [x.transpose(1, 2) for x in (query, key, value)]
|
81 |
+
|
82 |
+
scores = torch.matmul(query, key.transpose(-2, -1))
|
83 |
+
scores = scores / math.sqrt(float(per_head_size))
|
84 |
+
if mask is not None:
|
85 |
+
scores += mask
|
86 |
+
# probs = nn.Softmax(dim=-1)(scores)
|
87 |
+
probs = F.softmax(scores.float(), dim=-1).type_as(query)
|
88 |
+
output = torch.matmul(probs, value).transpose(1, 2).\
|
89 |
+
contiguous().view(batch_size, seq_length, -1)
|
90 |
+
return self.final_linear(output)
|
91 |
+
|
92 |
+
|
93 |
+
class GatedFeedForward(nn.Module):
|
94 |
+
def __init__(self, hidden_size, feedforward_size, has_bias=True, use_int8=True):
|
95 |
+
super(GatedFeedForward, self).__init__()
|
96 |
+
Linear = get_linear_layer(use_int8)
|
97 |
+
self.linear_gate = Linear(hidden_size, feedforward_size, bias=has_bias)
|
98 |
+
self.linear_1 = Linear(hidden_size, feedforward_size, bias=has_bias)
|
99 |
+
self.linear_2 = Linear(feedforward_size, hidden_size, bias=has_bias)
|
100 |
+
self.act = F.silu
|
101 |
+
|
102 |
+
def forward(self, x):
|
103 |
+
# gate = self.act(self.linear_gate(x))
|
104 |
+
gate = self.act(self.linear_gate(x)).type_as(x)
|
105 |
+
inter_linear = self.linear_1(x)
|
106 |
+
inter = gate * inter_linear
|
107 |
+
output = self.linear_2(inter)
|
108 |
+
return output
|
109 |
+
|
110 |
+
|
111 |
+
class TransformerLayer(nn.Module):
|
112 |
+
def __init__(self, args):
|
113 |
+
super(TransformerLayer, self).__init__()
|
114 |
+
|
115 |
+
if hasattr(args, "attention_head_size"):
|
116 |
+
attention_head_size = args.attention_head_size
|
117 |
+
else:
|
118 |
+
attention_head_size = args.hidden_size // args.heads_num
|
119 |
+
|
120 |
+
has_bias = bool(1 - args.remove_transformer_bias)
|
121 |
+
# Multi-head Attention
|
122 |
+
self.self_attn = MultiHeadedAttention(
|
123 |
+
args, args.hidden_size, args.heads_num, attention_head_size, has_bias=has_bias,
|
124 |
+
use_int8=args.use_int8
|
125 |
+
)
|
126 |
+
|
127 |
+
# FFN
|
128 |
+
self.feed_forward = GatedFeedForward(
|
129 |
+
args.hidden_size, args.feedforward_size, has_bias, use_int8=args.use_int8
|
130 |
+
)
|
131 |
+
|
132 |
+
self.layer_norm_1 = RMSNorm(args.hidden_size)
|
133 |
+
self.layer_norm_2 = RMSNorm(args.hidden_size)
|
134 |
+
|
135 |
+
def forward(self, hidden, start_pos, continue_exsample, mask, freqs_cis=None):
|
136 |
+
inter = self.layer_norm_1(hidden)
|
137 |
+
inter = self.self_attn(inter, inter, inter, start_pos, continue_exsample, mask, freqs_cis)
|
138 |
+
hidden = hidden + inter
|
139 |
+
output = self.layer_norm_2(hidden)
|
140 |
+
output = self.feed_forward(output) + hidden
|
141 |
+
return output
|
142 |
+
|
143 |
+
|
144 |
+
class TransformerEncoder(nn.Module):
|
145 |
+
def __init__(self, args):
|
146 |
+
super(TransformerEncoder, self).__init__()
|
147 |
+
self.mask = args.mask
|
148 |
+
self.layers_num = args.layers_num
|
149 |
+
|
150 |
+
self.transformer = nn.ModuleList(
|
151 |
+
[TransformerLayer(args) for _ in range(self.layers_num)]
|
152 |
+
)
|
153 |
+
|
154 |
+
self.layer_norm = RMSNorm(args.hidden_size)
|
155 |
+
self.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)
|
156 |
+
|
157 |
+
def forward(self, emb, start_pos, continue_exsample):
|
158 |
+
batch_size, seq_length, _ = emb.size()
|
159 |
+
mask = None
|
160 |
+
if seq_length > 1:
|
161 |
+
mask = torch.ones(seq_length, seq_length, device=emb.device)
|
162 |
+
mask = torch.tril(mask)
|
163 |
+
mask = (1.0 - mask) * -10000
|
164 |
+
mask = mask.repeat(batch_size, 1, 1, 1)
|
165 |
+
|
166 |
+
hidden = emb
|
167 |
+
freqs_cis = self.freqs_cis[start_pos: start_pos + seq_length].to(hidden.device)
|
168 |
+
|
169 |
+
for i in range(self.layers_num):
|
170 |
+
hidden = self.transformer[i](hidden, start_pos, continue_exsample, mask, freqs_cis=freqs_cis)
|
171 |
+
return self.layer_norm(hidden)
|
172 |
+
|
173 |
+
|
174 |
+
class LmOutput(nn.Module):
|
175 |
+
def __init__(self, args):
|
176 |
+
super(LmOutput, self).__init__()
|
177 |
+
# update: lm output not use int8
|
178 |
+
Linear = get_linear_layer(False)
|
179 |
+
self.lm = Linear(args.hidden_size, args.vocab_size, bias=False)
|
180 |
+
|
181 |
+
def forward(self, x):
|
182 |
+
return self.lm(x[:, -1, :])
|
183 |
+
|
184 |
+
|
185 |
+
class LLaMa(nn.Module):
|
186 |
+
def __init__(self, args):
|
187 |
+
super(LLaMa, self).__init__()
|
188 |
+
self.embedding = WordEmbedding(args)
|
189 |
+
self.encoder = TransformerEncoder(args)
|
190 |
+
self.target = LmOutput(args)
|
191 |
+
|
192 |
+
#@torch.inference_mode()
|
193 |
+
def forward(self, src, start_pos, continue_exsample):
|
194 |
+
emb = self.embedding(src)
|
195 |
+
output = self.encoder(emb, start_pos, continue_exsample)
|
196 |
+
output = self.target(output)
|
197 |
+
return output
|
models/norm.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class RMSNorm(torch.nn.Module):
|
6 |
+
def __init__(self, hidden_size, eps=1e-6):
|
7 |
+
super().__init__()
|
8 |
+
self.eps = eps
|
9 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
10 |
+
|
11 |
+
def _norm(self, x):
|
12 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
output = self._norm(x.float()).type_as(x)
|
16 |
+
return output * self.weight
|
models/rope.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
5 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
6 |
+
t = torch.arange(end, device=freqs.device) # type: ignore
|
7 |
+
freqs = torch.outer(t, freqs).float() # type: ignore
|
8 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
9 |
+
return freqs_cis
|
10 |
+
|
11 |
+
|
12 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
13 |
+
ndim = x.ndim
|
14 |
+
assert 0 <= 1 < ndim
|
15 |
+
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
16 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
17 |
+
return freqs_cis.view(*shape)
|
18 |
+
|
19 |
+
|
20 |
+
def apply_rotary_emb(
|
21 |
+
xq: torch.Tensor,
|
22 |
+
xk: torch.Tensor,
|
23 |
+
freqs_cis: torch.Tensor,
|
24 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
25 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
26 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
27 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
28 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
29 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
30 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
models/tokenize.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copy from
|
2 |
+
# https://github.com/tloen/llama-int8/blob/ce74669c767e42b5082391dd0cfcb621ba40c7f9/llama/tokenizer.py
|
3 |
+
|
4 |
+
from sentencepiece import SentencePieceProcessor
|
5 |
+
from logging import getLogger
|
6 |
+
from typing import List
|
7 |
+
import os
|
8 |
+
|
9 |
+
|
10 |
+
logger = getLogger()
|
11 |
+
|
12 |
+
|
13 |
+
class Tokenizer:
|
14 |
+
def __init__(self, model_path: str):
|
15 |
+
# reload tokenizer
|
16 |
+
assert os.path.isfile(model_path), model_path
|
17 |
+
self.sp_model = SentencePieceProcessor(model_file=model_path)
|
18 |
+
logger.info(f"Reloaded SentencePiece model from {model_path}")
|
19 |
+
|
20 |
+
# BOS / EOS token IDs
|
21 |
+
self.n_words: int = self.sp_model.vocab_size()
|
22 |
+
self.bos_id: int = self.sp_model.bos_id()
|
23 |
+
self.eos_id: int = self.sp_model.eos_id()
|
24 |
+
self.pad_id: int = self.sp_model.pad_id()
|
25 |
+
logger.info(
|
26 |
+
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
|
27 |
+
)
|
28 |
+
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
|
29 |
+
|
30 |
+
def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
|
31 |
+
assert type(s) is str
|
32 |
+
t = self.sp_model.encode(s)
|
33 |
+
if bos:
|
34 |
+
t = [self.bos_id] + t
|
35 |
+
if eos:
|
36 |
+
t = t + [self.eos_id]
|
37 |
+
return t
|
38 |
+
|
39 |
+
def decode(self, t: List[int]) -> str:
|
40 |
+
return self.sp_model.decode(t)
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
+
torch==1.11.0
|
3 |
+
bitsandbytes==0.37.2
|
4 |
+
sentencepiece
|
5 |
+
argparse
|
6 |
+
accelerate==0.21.0
|
7 |
+
transformers==4.31.0
|
utils.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import sys
|
3 |
+
from argparse import Namespace
|
4 |
+
import torch
|
5 |
+
import os
|
6 |
+
|
7 |
+
|
8 |
+
def load_hyperparam(default_args):
|
9 |
+
"""
|
10 |
+
Load arguments form argparse and config file
|
11 |
+
Priority: default options < config file < command line args
|
12 |
+
"""
|
13 |
+
with open(default_args.config_path, mode="r", encoding="utf-8") as f:
|
14 |
+
config_args_dict = json.load(f)
|
15 |
+
|
16 |
+
default_args_dict = vars(default_args)
|
17 |
+
|
18 |
+
command_line_args_dict = {k: default_args_dict[k] for k in [
|
19 |
+
a[2:] for a in sys.argv if (a[:2] == "--" and "local_rank" not in a)
|
20 |
+
]}
|
21 |
+
default_args_dict.update(config_args_dict)
|
22 |
+
default_args_dict.update(command_line_args_dict)
|
23 |
+
args = Namespace(**default_args_dict)
|
24 |
+
|
25 |
+
return args
|
26 |
+
|
27 |
+
|
28 |
+
def _load_state_dict_into_model(model_to_load, model_path, start_prefix=""):
|
29 |
+
# Convert old format to new format if needed from a PyTorch state_dict
|
30 |
+
|
31 |
+
# copy state_dict so _load_from_state_dict can modify it
|
32 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
33 |
+
metadata = getattr(state_dict, "_metadata", None)
|
34 |
+
state_dict = state_dict.copy()
|
35 |
+
state_dict['target.lm.weight'] = state_dict['target.lm.output_layer.weight']
|
36 |
+
del state_dict['target.lm.output_layer.weight']
|
37 |
+
state_dict['embedding.embedding.weight'] = state_dict['embedding.word.embedding.weight']
|
38 |
+
del state_dict['embedding.word.embedding.weight']
|
39 |
+
|
40 |
+
if metadata is not None:
|
41 |
+
metadata['embedding.embedding'] = metadata['embedding.word.embedding']
|
42 |
+
metadata['target.lm'] = metadata['target.lm.output_layer']
|
43 |
+
if metadata.get('embedding.dropout', None) is not None:
|
44 |
+
del metadata['embedding.dropout']
|
45 |
+
del metadata['embedding.word']
|
46 |
+
del metadata['embedding.word.embedding']
|
47 |
+
del metadata['target.lm.output_layer']
|
48 |
+
del metadata['target.lm.softmax']
|
49 |
+
del metadata['target.lm.criterion']
|
50 |
+
state_dict._metadata = metadata
|
51 |
+
|
52 |
+
error_msgs = []
|
53 |
+
|
54 |
+
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
55 |
+
# so we need to apply the function recursively.
|
56 |
+
def load(module, state_dict, prefix=""):
|
57 |
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
58 |
+
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
59 |
+
# Parameters of module and children will start with prefix. We can exit early if there are none in this
|
60 |
+
# state_dict
|
61 |
+
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
62 |
+
import deepspeed
|
63 |
+
# In sharded models, each shard has only part of the full state_dict, so only gather
|
64 |
+
# parameters that are in the current state_dict.
|
65 |
+
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
|
66 |
+
params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
|
67 |
+
if len(params_to_gather) > 0:
|
68 |
+
# because zero3 puts placeholders in model params, this context
|
69 |
+
# manager gathers (unpartitions) the params of the current layer, then loads from
|
70 |
+
# the state dict and then re-partitions them again
|
71 |
+
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
|
72 |
+
if torch.distributed.get_rank() == 0:
|
73 |
+
module._load_from_state_dict(*args)
|
74 |
+
|
75 |
+
for name, child in module._modules.items():
|
76 |
+
if child is not None:
|
77 |
+
load(child, state_dict, prefix + name + ".")
|
78 |
+
|
79 |
+
load(model_to_load, state_dict, prefix=start_prefix)
|
80 |
+
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
|
81 |
+
# it's safe to delete it.
|
82 |
+
del state_dict
|
83 |
+
|
84 |
+
return model_to_load
|
85 |
+
|
86 |
+
|
87 |
+
def convert_normal_parameter_to_int8(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None):
|
88 |
+
import bitsandbytes as bnb
|
89 |
+
modules_to_not_convert = ["lm"] if modules_to_not_convert is None else modules_to_not_convert
|
90 |
+
for name, module in model.named_children():
|
91 |
+
if current_key_name is None:
|
92 |
+
current_key_name = []
|
93 |
+
current_key_name.append(name)
|
94 |
+
|
95 |
+
if len(list(module.children())) > 0:
|
96 |
+
convert_normal_parameter_to_int8(module, threshold, modules_to_not_convert, current_key_name)
|
97 |
+
|
98 |
+
if isinstance(module, bnb.nn.Linear8bitLt) and name not in modules_to_not_convert:
|
99 |
+
# Check if the current key is not in the `modules_to_not_convert`
|
100 |
+
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
|
101 |
+
model._modules[name].weight = bnb.nn.Int8Params(
|
102 |
+
module.weight.data,
|
103 |
+
requires_grad=False,
|
104 |
+
has_fp16_weights=False
|
105 |
+
)
|
106 |
+
# Force requires grad to False to avoid unexpected errors
|
107 |
+
model._modules[name].requires_grad_(False)
|
108 |
+
# Remove the last key for recursion
|
109 |
+
current_key_name.pop(-1)
|
110 |
+
return model
|
111 |
+
|
112 |
+
|
113 |
+
def load_model(model, model_path):
|
114 |
+
if os.path.isdir(model_path):
|
115 |
+
index_filename = os.path.join(model_path, 'pytorch_model.bin.index.json')
|
116 |
+
with open(index_filename, "r") as f:
|
117 |
+
index = json.loads(f.read())
|
118 |
+
shard_filenames = sorted(set(index["weight_map"].values()))
|
119 |
+
shard_filenames = [os.path.join(model_path, f) for f in shard_filenames]
|
120 |
+
for shard_file in shard_filenames:
|
121 |
+
shard_checkpoint = torch.load(shard_file, map_location='cpu')
|
122 |
+
for name, parameter in model.named_parameters():
|
123 |
+
if shard_checkpoint.get(name, None) is not None:
|
124 |
+
if 'target' in name:
|
125 |
+
parameter.data = shard_checkpoint['target.lm.output_layer.weight']
|
126 |
+
elif 'embedding' in name:
|
127 |
+
parameter.data = shard_checkpoint['embedding.word.embedding.weight']
|
128 |
+
else:
|
129 |
+
parameter.data = shard_checkpoint[name]
|
130 |
+
parameter.requires_grad = False
|
131 |
+
del shard_checkpoint
|
132 |
+
else:
|
133 |
+
checkpoint = torch.load(model_path, map_location='cpu')
|
134 |
+
for parameter_name, parameter in model.named_parameters():
|
135 |
+
if 'target' in parameter_name:
|
136 |
+
parameter.data = checkpoint['target.lm.output_layer.weight']
|
137 |
+
elif 'embedding' in parameter_name:
|
138 |
+
parameter.data = checkpoint['embedding.word.embedding.weight']
|
139 |
+
else:
|
140 |
+
parameter.data = checkpoint[parameter_name]
|
141 |
+
parameter.requires_grad = False
|
142 |
+
del checkpoint
|
143 |
+
return model
|