DeepLearning101 commited on
Commit
ea0fb2f
1 Parent(s): 0d88c28

Upload 6 files

Browse files
models/instruction_prompting/causal_incontext.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./")
3
+ sys.path.append("../")
4
+ sys.path.append("../../")
5
+ sys.path.append("../../../")
6
+ import torch
7
+ import torch.nn as nn
8
+ import transformers
9
+ import torch.nn.functional as F
10
+ from typing import Optional, Tuple, Union
11
+ from torch.nn import CrossEntropyLoss
12
+ from transformers import AutoModelForCausalLM
13
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
14
+ from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel, GPT2Model, GPT2LMHeadModel
15
+ from tools.runner_utils.log_util import logging
16
+ from tools.model_utils.parameter_freeze import ParameterFreeze
17
+
18
+ logger = logging.getLogger(__name__)
19
+ freezer = ParameterFreeze()
20
+
21
+ """
22
+ Function: Use Causal LM to prompt for cls
23
+ Notes:
24
+ - For classification, the model only calculate the loss at the position of label, the other position is set as -100
25
+ - During inference, generate result at the last position.
26
+ """
27
+
28
+
29
+ class GPT2ForInContextLearning(GPT2PreTrainedModel):
30
+ _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
31
+
32
+ def __init__(self, config):
33
+ super().__init__(config)
34
+ self.transformer = GPT2Model(config)
35
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
36
+
37
+ # if self.config.use_freezing:
38
+ # self.transformer = freezer.freeze_lm(self.transformer)
39
+
40
+ # Model parallel
41
+ self.model_parallel = False
42
+ self.device_map = None
43
+
44
+ # Initialize weights and apply final processing
45
+ self.post_init()
46
+
47
+ def get_output_embeddings(self):
48
+ return self.lm_head
49
+
50
+ def set_output_embeddings(self, new_embeddings):
51
+ self.lm_head = new_embeddings
52
+
53
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
54
+ token_type_ids = kwargs.get("token_type_ids", None)
55
+ # only last token for inputs_ids if past is defined in kwargs
56
+ if past:
57
+ input_ids = input_ids[:, -1].unsqueeze(-1)
58
+ if token_type_ids is not None:
59
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
60
+
61
+ attention_mask = kwargs.get("attention_mask", None)
62
+ position_ids = kwargs.get("position_ids", None)
63
+
64
+ if attention_mask is not None and position_ids is None:
65
+ # create position_ids on the fly for batch generation
66
+ position_ids = attention_mask.long().cumsum(-1) - 1
67
+ position_ids.masked_fill_(attention_mask == 0, 1)
68
+ if past:
69
+ position_ids = position_ids[:, -1].unsqueeze(-1)
70
+ else:
71
+ position_ids = None
72
+ return {
73
+ "input_ids": input_ids,
74
+ "past_key_values": past,
75
+ "use_cache": kwargs.get("use_cache"),
76
+ "position_ids": position_ids,
77
+ "attention_mask": attention_mask,
78
+ "token_type_ids": token_type_ids,
79
+ }
80
+
81
+ def forward(
82
+ self,
83
+ input_ids: Optional[torch.LongTensor] = None,
84
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
85
+ attention_mask: Optional[torch.FloatTensor] = None,
86
+ token_type_ids: Optional[torch.LongTensor] = None,
87
+ position_ids: Optional[torch.LongTensor] = None,
88
+ head_mask: Optional[torch.FloatTensor] = None,
89
+ inputs_embeds: Optional[torch.FloatTensor] = None,
90
+ encoder_hidden_states: Optional[torch.Tensor] = None,
91
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
92
+ labels: Optional[torch.LongTensor] = None,
93
+ use_cache: Optional[bool] = None,
94
+ output_attentions: Optional[bool] = None,
95
+ output_hidden_states: Optional[bool] = None,
96
+ return_dict: Optional[bool] = None,
97
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
98
+ r"""
99
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
100
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
101
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
102
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
103
+ """
104
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
105
+
106
+ transformer_outputs = self.transformer(
107
+ input_ids,
108
+ past_key_values=past_key_values,
109
+ attention_mask=attention_mask,
110
+ token_type_ids=token_type_ids,
111
+ position_ids=position_ids,
112
+ head_mask=head_mask,
113
+ inputs_embeds=inputs_embeds,
114
+ encoder_hidden_states=encoder_hidden_states,
115
+ encoder_attention_mask=encoder_attention_mask,
116
+ use_cache=use_cache,
117
+ output_attentions=output_attentions,
118
+ output_hidden_states=output_hidden_states,
119
+ return_dict=return_dict,
120
+ )
121
+ hidden_states = transformer_outputs[0]
122
+
123
+ # Set device for model parallelism
124
+ if self.model_parallel:
125
+ torch.cuda.set_device(self.transformer.first_device)
126
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
127
+
128
+ lm_logits = self.lm_head(hidden_states)
129
+
130
+ loss = None
131
+ if labels is not None:
132
+ # Shift so that tokens < n predict n
133
+ shift_logits = lm_logits[..., :-1, :].contiguous()
134
+ shift_labels = labels[..., 1:].contiguous()
135
+ # print("shift_labels=", shift_labels)
136
+ # Flatten the tokens
137
+ loss_fct = CrossEntropyLoss()
138
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
139
+
140
+ if not return_dict:
141
+ output = (lm_logits,) + transformer_outputs[1:]
142
+ return ((loss,) + output) if loss is not None else output
143
+
144
+ return CausalLMOutputWithCrossAttentions(
145
+ loss=loss,
146
+ logits=lm_logits,
147
+ past_key_values=transformer_outputs.past_key_values,
148
+ hidden_states=transformer_outputs.hidden_states,
149
+ attentions=transformer_outputs.attentions,
150
+ cross_attentions=transformer_outputs.cross_attentions,
151
+ )
152
+
153
+
154
+ @staticmethod
155
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
156
+ """
157
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
158
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
159
+ beam_idx at every generation step.
160
+ """
161
+ return tuple(
162
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
163
+ for layer_past in past
164
+ )
165
+
166
+
167
+
168
+
169
+ if __name__ == "__main__":
170
+ from transformers import GPT2Tokenizer
171
+ tokenizer = GPT2Tokenizer.from_pretrained("/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2")
172
+ model = GPT2ForInContextLearning.from_pretrained("/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2")
173
+
174
+ # In-Context Learning for classification
175
+ # input_text = "The capital city of China is Beijing. \n\n The capital city of Japan is Tokyo. \n\n The capital city of America is"
176
+ input_text = "What are follows emotions? \n\n Input: The book is very nice.\n Output: Great. \n\n Input: I never eat chocolate!\n Output:"
177
+ # input_text = "This film is wonderful.\n Great."
178
+ tokenizer.pad_token = tokenizer.eos_token
179
+ inputs = tokenizer(input_text, return_tensors="pt")
180
+ input_len = inputs["input_ids"].shape[-1]
181
+ gen_output = model.generate(**inputs, max_length=input_len + 10)
182
+ gen_result = tokenizer.decode(gen_output[0])
183
+ print("classification result:\n", gen_result)
184
+
185
+ # In-Context Learning for generation
186
+ input_text = "Please tell me what is the transformer? "
187
+ # input_text = "This film is wonderful.\n Great."
188
+ tokenizer.pad_token = tokenizer.eos_token
189
+ inputs = tokenizer(input_text, return_tensors="pt")
190
+ input_len = inputs["input_ids"].shape[-1]
191
+ gen_output = model.generate(**inputs, max_length=input_len + 60)
192
+ gen_result = tokenizer.decode(gen_output[0])
193
+ print("generation result:\n", gen_result)
models/instruction_prompting/incontext.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+ from typing import Optional, Tuple
4
+ from turtle import forward
5
+ from torch.nn import CrossEntropyLoss
6
+ from transformers import AutoModelForCausalLM
7
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
8
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model
9
+
10
+
11
+ class GPT2ForInContextClassification(GPT2LMHeadModel):
12
+
13
+ def forward(
14
+ self,
15
+ input_ids: Optional[torch.LongTensor] = None, # input token id
16
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
17
+ attention_mask: Optional[torch.FloatTensor] = None,
18
+ token_type_ids: Optional[torch.LongTensor] = None,
19
+ labels: Optional[torch.LongTensor] = None,
20
+ label_masks: Optional[torch.LongTensor] = None, # mask=1 means it should be calculated loss
21
+ options :Optional[list] = None, # 如果是分类任务,则可以添加候选label
22
+ output_attentions=None,
23
+ output_hidden_states=None,
24
+ return_dict=None,
25
+ ):
26
+ assert len(input_ids.shape) == 3 and input_ids.shape[1] == len(options) # [n, option_size, len]
27
+ batch_size = input_ids.shape[0]
28
+ option_size = input_ids.shape[1]
29
+ input_ids = input_ids.view(-1, input_ids.shape[1], input_ids.shape[2]) # [n*option_size, len]
30
+ attention_mask = attention_mask.view(-1, input_ids.shape[1], input_ids.shape[2]) if attention_mask is not None else None # [n*option_size, len]
31
+ token_type_ids = token_type_ids.view(-1, input_ids.shape[1], input_ids.shape[2]) if token_type_ids is not None else None# [n*option_size, len]
32
+ # labels = labels.view(-1, input_ids.shape[1], input_ids.shape[2]) # [n*option_size, len]
33
+
34
+ transformer_outputs = self.transformer(
35
+ input_ids,
36
+ past_key_values=past_key_values,
37
+ attention_mask=attention_mask,
38
+ token_type_ids=token_type_ids,
39
+ output_attentions=output_attentions,
40
+ output_hidden_states=output_hidden_states,
41
+ return_dict=return_dict,
42
+ )
43
+ hidden_states = transformer_outputs[0] # [n*option_size, len, hidden_size]
44
+ lm_logits = self.lm_head(hidden_states) # [n*option_size, len, vocab_size]
45
+ lm_logits = lm_logits.view(batch_size, option_size, input_ids.shape[-1], -1) # [n, option_size, len, vocab_size]
46
+
47
+ # print("len(input_ids)=", len(input_ids[0]))
48
+ # print("input_ids[-1]=", input_ids[0][-1])
49
+ print("lm_logits.shape=", lm_logits.shape)
50
+
51
+ losses = list()
52
+ if labels is not None:
53
+ for label, lm_logit in zip(labels, lm_logits):
54
+ # label: [option_size, len]
55
+ # lm_logit: [option_size, len, vocab_size]
56
+ shift_logits = lm_logit[..., :-1, :].contiguous()
57
+ # print("shift_logits.shape=", shift_logits.shape)
58
+ shift_labels = label[..., 1:].contiguous()
59
+ # print("shift_labels=", shift_labels)
60
+ # print("shift_labels.shape=", shift_labels.shape)
61
+ # Flatten the tokens
62
+ loss_fct = CrossEntropyLoss()
63
+ print("shift_logits.shape=", shift_logits.shape)
64
+ print("shift_labels.shape=", shift_labels.shape)
65
+ loss = [loss_fct(shift_logit.view(-1, shift_logit.size(-1)), shift_label.view(-1)) for shift_logit, shift_label in zip(shift_logits, shift_labels)]
66
+ loss = torch.stack(loss)
67
+ # print("loss=", loss)
68
+ if label_masks is not None:
69
+ loss = loss.view(lm_logits.size(0), lm_logits.size(1)) * label_masks # [option_size, len]
70
+ loss = torch.sum(loss, axis=1) / torch.sum(label_mask, axis=1) # [option_size]
71
+ losses.append(loss)
72
+ losses = torch.stack(losses) # [n, option_size]
73
+ # 将各个option的loss视为logit,loss越小,对应的概率应越大
74
+ loss_logits = torch.softmax(-losses, -1) # [n, option_size]
75
+ print("losses.shape=", losses.shape)
76
+ print("loss_logits.shape=", loss_logits.shape)
77
+
78
+ if not return_dict:
79
+ output = (lm_logits,) + transformer_outputs[1:]
80
+ return ((loss,) + output) if loss is not None else output
81
+
82
+ return CausalLMOutputWithCrossAttentions(
83
+ loss=losses,
84
+ logits=loss_logits,
85
+ past_key_values=transformer_outputs.past_key_values,
86
+ hidden_states=transformer_outputs.hidden_states,
87
+ attentions=transformer_outputs.attentions,
88
+ cross_attentions=transformer_outputs.cross_attentions,
89
+ )
90
+
91
+ if __name__ == "__main__":
92
+ from transformers import GPT2Tokenizer
93
+ tokenizer = GPT2Tokenizer.from_pretrained("/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2")
94
+ model = GPT2ForInContextClassification.from_pretrained("/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2")
95
+ # input_text = "The capital city of China is Beijing. The capital city of Japan is Tokyo. The capital city of America"
96
+ input_text1 = "What are follows emotions? \n\n Input: The book is very nice.\n Output: Great. \n\n Input: I never eat chocolate!\n Output: Bad. \n\n Input: This film is not wonderful.\n Output: Great"
97
+ input_text2 = "What are follows emotions? \n\n Input: The book is very nice.\n Output: Great. \n\n Input: I never eat chocolate!\n Output: Bad. \n\n Input: This film is not wonderful.\n Output: Bad"
98
+ # input_text = "This film is wonderful.\n Great."
99
+ # input_text = "Mr. Chen was born in Shanghai. Obama was born in US. Jinping Xi was born in China."
100
+ tokenizer.pad_token = tokenizer.eos_token
101
+ inputs = tokenizer(
102
+ [input_text1, input_text2], return_tensors="pt",
103
+ max_length=60,
104
+ padding="max_length")
105
+ inputs["input_ids"] = inputs["input_ids"].view(-1, inputs["input_ids"].shape[0], inputs["input_ids"].shape[1])
106
+ # inputs["token_type_ids"] = inputs["token_type_ids"].view(-1, inputs["input_ids"].shape[0], inputs["input_ids"].shape[1])
107
+ inputs["attention_mask"] = inputs["attention_mask"].view(-1, inputs["input_ids"].shape[0], inputs["input_ids"].shape[1])
108
+ inputs["labels"] = inputs["input_ids"]
109
+ inputs["options"] = torch.Tensor([[0, 1], [0, 1]]).long()
110
+ print(inputs["input_ids"].shape)
111
+ label_mask = torch.zeros([1, 2, inputs["input_ids"].shape[2]])
112
+ # print(label_mask)
113
+ label_mask[0][0][20] = 1
114
+ label_mask[0][1][20] = 1
115
+ print(label_mask)
116
+ output = model(**inputs, return_dict=True)
117
+ # print(output["last_hidden_state"])
118
+ # print(output["last_hidden_state"].size())
119
+ # print(output["logits"])
120
+ # print(output["logits"].size())
121
+ losses, logits = output["loss"], output["logits"]
122
+ print("loss=", losses)
123
+ print("logits=", logits)
124
+ # gen_output = model.generate(**inputs, max_length=60)
125
+ # for i in range(len(gen_output)):
126
+ # gen_result = tokenizer.decode(gen_output[i])
127
+ # print("gen_result=", gen_result[len(inputs["input_ids"]):])
models/instruction_prompting/test.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM
2
+ from transformers import GPT2Tokenizer
3
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2LMHeadModel
4
+
5
+ if __name__ == "__main__":
6
+ gpt2_tokenizer: GPT2Tokenizer = GPT2Tokenizer.from_pretrained("/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2")
7
+ # gpt2_model = GPT2LMHeadModel.from_pretrained("/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2")
8
+ # # input_text = "The capital city of China is Beijing. The capital city of Japan is Tokyo. The capital city of America"
9
+ # input_text = "What are follows emotions? \n\n The book is very nice.\n great. \n\n I never eat chocolate!\n bad. \n\n This film is wonderful.\n Great"
10
+ # # input_text = "Mr. Chen was born in Shanghai. Obama was born in US. Trump was born in"
11
+ # inputs = gpt2_tokenizer(input_text, return_tensors="pt")
12
+ # print(inputs)
13
+ # output = gpt2_model(**inputs)
14
+ # # print(output["last_hidden_state"])
15
+ # # print(output["last_hidden_state"].size())
16
+ # print(output["logits"])
17
+ # print(output["logits"].size())
18
+ # gen_output = gpt2_model.generate(**inputs, max_length=60)
19
+ # # gen_result = gpt2_tokenizer.convert_ids_to_tokens(gen_output[0])
20
+ # gen_result = gpt2_tokenizer.decode(gen_output[0])
21
+ # print(gen_result)
22
+
23
+
24
+ gpt2_tokenizer(
25
+ [["What are follows emotions?", "What are follows emotions?"], ["What are follows emotions?"]],
26
+ truncation=True,
27
+ max_length=30,
28
+ padding="max_length",
29
+ return_offsets_mapping=True
30
+ )
models/reinforcement_learning/actor.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2023/5/6 3:53 p.m.
3
+ # @Author : JianingWang
4
+ # @File : actor.py
5
+
6
+ from typing import Optional, Tuple, Union
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from transformers import AutoModelForCausalLM, AutoConfig
12
+ from models.basic_modules.generation import generate
13
+
14
+
15
+ """
16
+ Actor model.
17
+ """
18
+ class Actor(nn.Module):
19
+ """
20
+ Actor model base class.
21
+
22
+ Args:
23
+ model (nn.Module): Actor Model.
24
+ """
25
+
26
+ def __init__(self, model: nn.Module) -> None:
27
+ self.model = model
28
+
29
+ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
30
+ log_probs = F.log_softmax(logits, dim=-1)
31
+ log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
32
+ return log_probs_labels.squeeze(-1)
33
+
34
+ """
35
+ For generative model, needs generate function.
36
+ """
37
+ @torch.no_grad()
38
+ def generate(
39
+ self,
40
+ input_ids: torch.Tensor,
41
+ return_action_mask: bool = True,
42
+ **kwargs
43
+ ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
44
+ sequences = generate(self.model, input_ids, **kwargs)
45
+ attention_mask = None
46
+ pad_token_id = kwargs.get('pad_token_id', None)
47
+ if pad_token_id is not None:
48
+ attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
49
+ if not return_action_mask:
50
+ return sequences, attention_mask, None
51
+ input_len = input_ids.size(1)
52
+ eos_token_id = kwargs.get('eos_token_id', None)
53
+ if eos_token_id is None:
54
+ action_mask = torch.ones_like(sequences, dtype=torch.bool)
55
+ else:
56
+ # left padding may be applied, only mask action
57
+ action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
58
+ action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
59
+ action_mask[:, :input_len] = False
60
+ action_mask = action_mask[:, 1:]
61
+ return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
62
+
63
+ def forward(self,
64
+ sequences: torch.LongTensor,
65
+ num_actions: int,
66
+ attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
67
+ """Returns action log probs
68
+ """
69
+ output = self.model(sequences, attention_mask=attention_mask)
70
+ logits = output['logits']
71
+ log_probs = self.log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
72
+ return log_probs[:, -num_actions:]
73
+
74
+ def get_base_model(self):
75
+ return self.model
76
+
77
+
78
+ """
79
+ Causal LM as a actor, e.g., GPT-2, OPT, BLOOM, etc.
80
+ """
81
+ class CausalActor(Actor):
82
+ """
83
+ Causal LM Actor model.
84
+
85
+ Args:
86
+ pretrained (str): Pretrained model name or path.
87
+ config (AutoConfig): Model config.
88
+ checkpoint (bool): Enable gradient checkpointing.
89
+ """
90
+
91
+ def __init__(self,
92
+ pretrained: str = None,
93
+ config: Optional[AutoConfig] = None,
94
+ checkpoint: bool = False) -> None:
95
+ if pretrained is not None:
96
+ model = AutoModelForCausalLM.from_pretrained(pretrained)
97
+ elif config is not None:
98
+ model = AutoModelForCausalLM(config)
99
+ else:
100
+ model = AutoModelForCausalLM(AutoConfig())
101
+ if checkpoint:
102
+ model.gradient_checkpointing_enable()
103
+ super().__init__(model)
models/reinforcement_learning/critic.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2023/5/6 4:12 p.m.
3
+ # @Author : JianingWang
4
+ # @File : critic.py
5
+
6
+ from typing import Optional, Tuple, Union
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from transformers import AutoModel, AutoConfig
12
+ from models.basic_modules.generation import generate
13
+
14
+
15
+ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
16
+ tensor = tensor * mask
17
+ tensor = tensor.sum(dim=dim)
18
+ mask_sum = mask.sum(dim=dim)
19
+ mean = tensor / (mask_sum + 1e-8)
20
+ return mean
21
+
22
+
23
+ """
24
+ Critic model.
25
+ """
26
+ class Critic(nn.Module):
27
+ """
28
+ Critic model base class.
29
+
30
+ Args:
31
+ model (nn.Module): Critic model.
32
+ value_head (nn.Module): Value head to get value.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ model: nn.Module,
38
+ value_head: nn.Module,
39
+ use_action_mask: bool = False,
40
+ ) -> None:
41
+
42
+ self.model = model
43
+ self.value_head = value_head # critic layer for predict value function
44
+ self.use_action_mask = use_action_mask
45
+
46
+ def forward(self,
47
+ sequences: torch.LongTensor,
48
+ action_mask: Optional[torch.Tensor] = None,
49
+ attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
50
+ outputs = self.model(sequences, attention_mask=attention_mask)
51
+ last_hidden_states = outputs['last_hidden_state']
52
+
53
+ values = self.value_head(last_hidden_states).squeeze(-1)
54
+
55
+ if action_mask is not None and self.use_action_mask:
56
+ num_actions = action_mask.size(1)
57
+ prompt_mask = attention_mask[:, :-num_actions]
58
+ values = values[:, :-num_actions]
59
+ value = masked_mean(values, prompt_mask, dim=1)
60
+ return value
61
+
62
+ values = values[:, :-1]
63
+ value = values.mean(dim=1)
64
+ return value
65
+
66
+
67
+ """
68
+ Auto Model for Critic
69
+ """
70
+ class AutoModelCritic(Critic):
71
+ """
72
+ AutoModel Critic model.
73
+
74
+ Args:
75
+ pretrained (str): Pretrained model name or path.
76
+ config (AutoConfig): Model config.
77
+ checkpoint (bool): Enable gradient checkpointing.
78
+ """
79
+
80
+ def __init__(self,
81
+ pretrained: Optional[str] = None,
82
+ config: Optional[AutoConfig] = None,
83
+ checkpoint: bool = False,
84
+ lora_rank: int = 0,
85
+ lora_train_bias: str = 'none',
86
+ **kwargs) -> None:
87
+ if pretrained is not None:
88
+ model = AutoModel.from_pretrained(pretrained)
89
+ elif config is not None:
90
+ model = AutoModel(config)
91
+ else:
92
+ model = AutoModel(AutoConfig())
93
+ if checkpoint:
94
+ model.gradient_checkpointing_enable()
95
+ value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
96
+ super().__init__(model, value_head, **kwargs)
models/reinforcement_learning/reward_model.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2023/5/6 4:29 p.m.
3
+ # @Author : JianingWang
4
+ # @File : reward_model.py
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from transformers import AutoModel, AutoConfig
11
+ from loss.rl_loss import LogSigLoss, LogExpLoss
12
+ from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel
13
+ from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel, GPT2Model
14
+
15
+ """
16
+ RoERTa for Reward Model
17
+ """
18
+ class RobertaForReward(RobertaPreTrainedModel):
19
+ """
20
+ Reward model base class.
21
+
22
+ Args:
23
+ model (nn.Module): Reward model.
24
+ value_head (nn.Module): Value head to get reward score.
25
+ """
26
+
27
+ def __init__(self, config) -> None:
28
+ super().__init__(config)
29
+ self.config = config
30
+ self.roberta = RobertaModel(config)
31
+ self.value_head = nn.Linear(self.config.n_embd, 1)
32
+ self.init_weights()
33
+
34
+ def forward(
35
+ self,
36
+ chosen_sequences: torch.LongTensor,
37
+ chosen_attention_mask: Optional[torch.Tensor],
38
+ rejected_sequences: Optional[torch.LongTensor] = None,
39
+ rejected_attention_mask: Optional[torch.Tensor] = None,
40
+ ) -> torch.Tensor:
41
+ # obtain reward value of chosen sequence
42
+ chosen_outputs = self.roberta(chosen_sequences, attention_mask=chosen_attention_mask)
43
+ chosen_last_hidden_states = chosen_outputs['last_hidden_state']
44
+ chosen_values = self.value_head(chosen_last_hidden_states)[:, :-1]
45
+ chosen_values = chosen_values.mean(dim=1).squeeze(1) # ensure shape is (B)
46
+
47
+ return_dict = {
48
+ "chosen_values": chosen_values,
49
+ }
50
+ # if has rejected, obtain reward of rejected sequence, and calculate the loss
51
+ if rejected_sequences is not None:
52
+ rejected_outputs = self.roberta(rejected_sequences, attention_mask=rejected_attention_mask)
53
+ rejected_last_hidden_states = rejected_outputs['last_hidden_state']
54
+ rejected_values = self.value_head(rejected_last_hidden_states)[:, :-1]
55
+ rejected_values = rejected_values.mean(dim=1).squeeze(1) # ensure shape is (B)
56
+ return_dict["rejected_values"] = rejected_values
57
+
58
+ loss_fn = LogSigLoss()
59
+ loss = loss_fn(chosen_values, rejected_values)
60
+
61
+ return_dict["loss"] = loss
62
+
63
+ return return_dict
64
+
65
+
66
+ """
67
+ GPT2 for Reward Model
68
+ """
69
+ class GPT2ForReward(GPT2PreTrainedModel):
70
+ _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
71
+ """
72
+ Reward model base class.
73
+
74
+ Args:
75
+ model (nn.Module): Reward model.
76
+ value_head (nn.Module): Value head to get reward score.
77
+ """
78
+
79
+ def __init__(self, config) -> None:
80
+ super().__init__(config)
81
+ self.config = config
82
+ self.transformer = GPT2Model(config)
83
+ self.value_head = nn.Linear(self.config.n_embd, 1)
84
+
85
+ # Model parallel
86
+ self.model_parallel = False
87
+ self.device_map = None
88
+
89
+ self.post_init()
90
+
91
+ def forward(
92
+ self,
93
+ chosen_sequences: torch.LongTensor,
94
+ chosen_attention_mask: Optional[torch.Tensor],
95
+ rejected_sequences: Optional[torch.LongTensor] = None,
96
+ rejected_attention_mask: Optional[torch.Tensor] = None,
97
+ ) -> torch.Tensor:
98
+ # obtain reward value of chosen sequence
99
+ chosen_outputs = self.transformer(chosen_sequences, attention_mask=chosen_attention_mask)
100
+ chosen_last_hidden_states = chosen_outputs['last_hidden_state']
101
+ chosen_values = self.value_head(chosen_last_hidden_states)[:, :-1]
102
+ chosen_values = chosen_values.mean(dim=1).squeeze(1) # ensure shape is (B)
103
+
104
+ return_dict = {
105
+ "chosen_values": chosen_values,
106
+ }
107
+ # if has rejected, obtain reward of rejected sequence, and calculate the loss
108
+ if rejected_sequences is not None:
109
+ rejected_outputs = self.transformer(rejected_sequences, attention_mask=rejected_attention_mask)
110
+ rejected_last_hidden_states = rejected_outputs['last_hidden_state']
111
+ rejected_values = self.value_head(rejected_last_hidden_states)[:, :-1]
112
+ rejected_values = rejected_values.mean(dim=1).squeeze(1) # ensure shape is (B)
113
+ return_dict["rejected_values"] = rejected_values
114
+ loss_fn = LogSigLoss()
115
+ loss = loss_fn(chosen_values, rejected_values)
116
+
117
+ return_dict["loss"] = loss
118
+
119
+ return return_dict
120
+
121
+ @staticmethod
122
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
123
+ """
124
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
125
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
126
+ beam_idx at every generation step.
127
+ """
128
+ return tuple(
129
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
130
+ for layer_past in past
131
+ )