x54-729 commited on
Commit
90d3af9
1 Parent(s): 31117fd

Small change to chat prompt

Browse files
Files changed (1) hide show
  1. modeling_internlm.py +3 -5
modeling_internlm.py CHANGED
@@ -96,7 +96,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
96
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
97
  super().__init__()
98
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
99
- self.register_buffer("inv_freq", inv_freq)
100
 
101
  # Build here to make `torch.jit.trace` work.
102
  self.max_seq_len_cached = max_position_embeddings
@@ -769,9 +769,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
769
  def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
770
  prompt = ""
771
  for record in history:
772
- prompt += f"""<s><|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
773
- if len(prompt) == 0:
774
- prompt += "<s>"
775
  prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
776
  return tokenizer([prompt], return_tensors="pt")
777
 
@@ -995,4 +993,4 @@ class InternLMForSequenceClassification(InternLMPreTrainedModel):
995
  past_key_values=transformer_outputs.past_key_values,
996
  hidden_states=transformer_outputs.hidden_states,
997
  attentions=transformer_outputs.attentions,
998
- )
 
96
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
97
  super().__init__()
98
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
99
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
100
 
101
  # Build here to make `torch.jit.trace` work.
102
  self.max_seq_len_cached = max_position_embeddings
 
769
  def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
770
  prompt = ""
771
  for record in history:
772
+ prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
 
 
773
  prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
774
  return tokenizer([prompt], return_tensors="pt")
775
 
 
993
  past_key_values=transformer_outputs.past_key_values,
994
  hidden_states=transformer_outputs.hidden_states,
995
  attentions=transformer_outputs.attentions,
996
+ )