Update modeling_codeshell.py
modeling_codeshell.py  CHANGED  (+207 -5)
@@ -30,14 +30,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """PyTorch CodeShell model."""
+import os
 import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Callable
+from threading import Thread
+from queue import Queue
+
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel, PretrainedConfig
+from transformers.generation.utils import GenerationConfig
+
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -50,7 +57,6 @@ from transformers.utils import (
 )
 from .configuration_codeshell import CodeShellConfig
 
-
 # Fused kernels
 # Use separate functions for each case because conditionals prevent kernel fusion.
 # TODO: Could have better fused kernels depending on scaling, dropout and head mask.
@@ -446,7 +452,7 @@ class CodeShellPreTrainedModel(PreTrainedModel):
 
     # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->Shell
    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module,
+        if isinstance(module, CodeShellModel):
             module.gradient_checkpointing = value
 
 
@@ -739,6 +745,62 @@ class CodeShellModel(CodeShellPreTrainedModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
         )
+
+class EndOfFunctionCriteria(StoppingCriteria):
+    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
+    def __init__(self, input_lengths, eof_strings, tokenizer):
+        self.input_lengths = input_lengths
+        self.eof_strings = eof_strings
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs):
+        """Returns true if all generated sequences contain any of the end-of-function strings."""
+        decoded_generations = []
+        for _input_ids, input_length in zip(input_ids, self.input_lengths):
+            decoded_generations.append(self.tokenizer.decode(_input_ids[input_length:]))
+        done = []
+        for decoded_generation in decoded_generations:
+            done.append(
+                any(
+                    [
+                        stop_string in decoded_generation
+                        for stop_string in self.eof_strings
+                    ]
+                )
+            )
+        return all(done)
+
+class TextIterStreamer:
+    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.skip_special_tokens = skip_special_tokens
+        self.tokens = []
+        self.text_queue = Queue()
+        self.next_tokens_are_prompt = True
+
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+        else:
+            if len(value.shape) > 1:
+                value = value[0]
+            self.tokens.extend(value.tolist())
+            self.text_queue.put(
+                self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+    def end(self):
+        self.text_queue.put(None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get()
+        if value is None:
+            raise StopIteration()
+        else:
+            return value
 
 
 @add_start_docstrings(
@@ -762,10 +824,10 @@ class CodeShellForCausalLM(CodeShellPreTrainedModel):
     def quantize(self, bits: int):
         try:
             import bitsandbytes
-            from .quantizer import
+            from .quantizer import quantize
         except ImportError:
             raise ImportError(f"Needs bitsandbytes to run quantize.")
-        return
+        return quantize(self, bits)
 
     def get_output_embeddings(self):
         return self.lm_head
@@ -882,3 +944,143 @@ class CodeShellForCausalLM(CodeShellPreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+
+    def build_chat_input(self, query, history, tokenizer, max_new_tokens=None):
+        user_name = "\n## human:"
+        ai_name = "\n## assistant: "
+        stop = '|<end>|'
+
+        prompt = ''
+        for q, r in history:
+            prompt += f"{user_name}{q}{stop}"
+            prompt += f"{ai_name}{r}{stop}"
+        prompt += f"{user_name}{query}{stop}"
+        prompt += ai_name.rstrip()
+
+        max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
+        max_input_tokens = self.config.n_positions - max_new_tokens
+
+        input_tokens = tokenizer.encode(prompt)
+        input_tokens = input_tokens[-max_input_tokens:]  # truncate left
+        return torch.LongTensor([input_tokens]).to(self.device)
+
+    def chat(self, query, history, tokenizer, stream=False,
+             generation_config: Optional[GenerationConfig]=None):
+        generation_config = generation_config or self.generation_config
+        input_ids = self.build_chat_input(query, history, tokenizer, generation_config.max_new_tokens)
+        stopping_criteria = StoppingCriteriaList(
+            [EndOfFunctionCriteria([len(input_ids[0])], ['|<end>|', '|end|', '<|endoftext|>'], tokenizer)]
+        )
+
+        if stream:
+            streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+            Thread(target=self.generate, kwargs=dict(
+                inputs=input_ids, streamer=streamer,
+                stopping_criteria=stopping_criteria,
+                generation_config=generation_config,
+            )).start()
+            return streamer
+        else:
+            outputs = self.generate(input_ids, generation_config=generation_config, stopping_criteria=stopping_criteria)
+            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
+            return response
+
+    def generate_stream(self, prompt, tokenizer, generation_config=None, **kwargs):
+        generation_config = generation_config or self.generation_config
+        max_input_tokens = self.config.n_positions - self.generation_config.max_new_tokens
+
+        input_ids = tokenizer.encode(prompt)
+        input_ids = input_ids[-max_input_tokens:]  # truncate left
+
+        stopping_criteria = StoppingCriteriaList(
+            [EndOfFunctionCriteria([len(input_ids[0])], ['|end|', '|<end>|', '<|endoftext|>'], tokenizer)]
+        )
+
+        streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        Thread(target=self.generate, kwargs=dict(
+            inputs=input_ids, stopping_criteria=stopping_criteria, **kwargs
+        )).start()
+        return streamer
+
+
+class CodeShell4bitForCausalLM(CodeShellForCausalLM):
+    def __init__(self, config):
+        CodeShellPreTrainedModel.__init__(self, config)
+        self.transformer = CodeShellModel(config)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        try:
+            import bitsandbytes
+            from .quantizer import quantize_offline
+            quantize_offline(self)
+        except ImportError:
+            raise ImportError(f"Needs bitsandbytes to run quantize.")
+
+        self.post_init()
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        *model_args,
+        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        ignore_mismatched_sizes: bool = False,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        use_safetensors: bool = None,
+        **kwargs,
+    ):
+        if not isinstance(config, PretrainedConfig):
+            config_path = config if config is not None else pretrained_model_name_or_path
+            config, _ = cls.config_class.from_pretrained(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=False,
+                proxies=None,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder="",
+                _from_auto=False,
+                _from_pipeline=None,
+                **kwargs,
+            )
+
+        # Load config if we don't provide a configuration
+        from .quantizer import load_state_dict_for_qunantied_model
+        model = cls(config)
+        state_dict = torch.load(os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin'), map_location="cpu")
+        model = load_state_dict_for_qunantied_model(model, state_dict)
+        model.eval()
+
+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate():
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=False,
+                    proxies=None,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder="",
+                    _from_auto=False,
+                    _from_pipeline=None,
+                    **kwargs,
+                )
+            except (OSError, TypeError):
+                pass
+
+        device_map = kwargs.pop("device_map", None)
+        if device_map is not None:
+            model = model.to(torch.device(device_map))
+
+        return model
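For reference, the chat() and quantize() entry points added by this diff can be exercised roughly as follows. This is a minimal usage sketch, not part of the commit: the checkpoint name and prompts are placeholders, and it assumes the checkpoint ships this modeling code and is loaded with trust_remote_code=True.

# Minimal usage sketch (assumption: checkpoint name and prompts are illustrative placeholders).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "WisdomShell/CodeShell-7B-Chat"  # placeholder checkpoint carrying this custom modeling code
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float16)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

history = []  # list of (query, response) pairs, consumed by build_chat_input()

# Blocking call: returns the decoded assistant reply as a string.
reply = model.chat("Write a function that reverses a string.", history, tokenizer)
print(reply)

# Streaming call: returns a TextIterStreamer that yields the decoded text generated so far.
for partial_text in model.chat("Explain that function.", history, tokenizer, stream=True):
    print(partial_text, end="\r")

# Optional in-place quantization (requires bitsandbytes and the bundled .quantizer module).
# model = model.quantize(4)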