Commit 35d5984 (parent: 1b20edf) by myownskyW7

Speed up chat

Files changed (2):
1. modeling_InternLM_XComposer.py +7 -0
2. modeling_utils.py +36 -25
modeling_InternLM_XComposer.py CHANGED

@@ -103,6 +103,13 @@ conversation
 
         self.eoh = '<TOKENS_UNUSED_0>' # end of human
         self.eoa = '<TOKENS_UNUSED_1>' # end of assistant
+        stop_words_ids = [
+            torch.tensor([103027]).to(config.device),
+            torch.tensor([103028]).to(config.device),
+        ]
+        stopping_criteria = StoppingCriteriaList(
+            [StoppingCriteriaSub(stops=stop_words_ids)])
+        self.gen_config['stopping_criteria'] = stopping_criteria
 
     def maybe_autocast(self, dtype=torch.float16):
        # if on cpu, don't use autocast
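The added lines build the StoppingCriteriaList once, at construction time, and cache it in self.gen_config, so every later chat turn can stop decoding as soon as the model emits the end-of-human (103027) or end-of-assistant (103028) token instead of running to the full token budget, which is presumably where the "Speed up chat" gain comes from. A minimal sketch of how such a cached config could be consumed at generation time; `model` and `gen_config` are placeholder names, not the repository's exact attributes:

import torch

def generate_reply(model, input_ids, gen_config):
    # gen_config is assumed to hold standard generate() kwargs
    # (max_new_tokens, do_sample, ...) plus the cached 'stopping_criteria'
    # built in __init__ above, so decoding halts at token 103027 or 103028.
    with torch.no_grad():
        return model.generate(input_ids=input_ids, **gen_config)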
modeling_utils.py CHANGED

@@ -2,6 +2,7 @@ import logging
 import math
 import os
 from contextlib import contextmanager
+from transformers import StoppingCriteria, StoppingCriteriaList
 
 import timm.models.hub as timm_hub
 import torch
@@ -32,6 +33,7 @@ def download_cached_file(url, check_hash=True, progress=False):
     Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
     If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
     """
+
     def get_cached_file_path():
         # a hack to sync the file path across processes
         parts = torch.hub.urlparse(url)
@@ -74,49 +76,58 @@ def all_logging_disabled(highest_level=logging.CRITICAL):
 
 
 class LoRALinear(nn.Linear):
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 bias: bool = True,
-                 device=None,
-                 dtype=None,
-                 lora_r=8,
-                 lora_alpha=16,
-                 lora_dropout=0.05,
-                 **kwargs) -> None:
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+        lora_r=8,
+        lora_alpha=16,
+        lora_dropout=0.05,
+        **kwargs
+    ) -> None:
         super().__init__(in_features, out_features, bias, device, dtype)
         self.lora_r = lora_r
         self.lora_alpha = lora_alpha
-        if lora_dropout > 0.:
+        if lora_dropout > 0.0:
             self.lora_dropout = nn.Dropout(p=lora_dropout)
         else:
             self.lora_dropout = lambda x: x
         self.lora_scaling = self.lora_alpha / self.lora_r
 
-        self.lora_A = nn.Linear(in_features,
-                                self.lora_r,
-                                bias=False,
-                                device=device,
-                                dtype=dtype)
-        self.lora_B = nn.Linear(self.lora_r,
-                                out_features,
-                                bias=False,
-                                device=device,
-                                dtype=dtype)
+        self.lora_A = nn.Linear(
+            in_features, self.lora_r, bias=False, device=device, dtype=dtype
+        )
+        self.lora_B = nn.Linear(
+            self.lora_r, out_features, bias=False, device=device, dtype=dtype
+        )
 
         self.reset_parameters()
 
     def reset_parameters(self):
-        if hasattr(self, 'lora_A'):
+        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)
-           #print ("lora weight init {} {}".format(torch.mean(self.lora_A.weight), torch.mean(self.lora_B.weight)))
 
     def forward(self, x):
         orig_type = x.dtype
         res = super().forward(x)
         x = x.float()
-        res += self.lora_B(self.lora_A(
-            self.lora_dropout(x))) * self.lora_scaling
+        res += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.lora_scaling
         return res.to(orig_type)
+
+
+class StoppingCriteriaSub(StoppingCriteria):
+    def __init__(self, stops=[], encounters=1):
+        super().__init__()
+        self.stops = stops
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+        for stop in self.stops:
+            if torch.all((stop == input_ids[:, -len(stop):])).item():
+                return True
+
+        return False
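The new StoppingCriteriaSub compares each configured stop sequence against the tail of the generated ids and reports a match once every sequence in the batch ends with it; with the batch size of 1 used in chat, generation therefore stops right at the eoh/eoa token. A small self-contained check of that logic (no model required); the import path assumes modeling_utils.py is on the Python path:

import torch
from modeling_utils import StoppingCriteriaSub

stops = [torch.tensor([103027]), torch.tensor([103028])]
criterion = StoppingCriteriaSub(stops=stops)

# Tail of the sequence matches a stop id -> True, generate() would halt here.
print(criterion(torch.tensor([[1, 5, 9, 103028]]), scores=None))
# No stop id at the tail -> False, generation continues.
print(criterion(torch.tensor([[1, 5, 9, 42]]), scores=None))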