Crystalcareai committed
Commit 3ec0166
1 Parent(s): 120f09f

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +26 -21
modeling_quiet.py CHANGED
@@ -1836,7 +1836,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 elif ahead_idx >= self.n_ahead - 1:
                     if labels is not None: # we're in the talk phase
                         cur_talk_n = ahead_idx - (self.n_ahead - 1) + 1
-                        # print("Setting rm to labels", cur_talk_n, "during", ahead_idx)
                         shift_labels = labels[..., cur_talk_n:].contiguous().to(probabilities_2d.device)
                         padding = torch.full_like(
                             labels[..., :cur_talk_n],
@@ -1848,44 +1847,50 @@ class QuietForCausalLM(QuietPreTrainedModel):
                             [shift_labels, padding],
                             dim=-1
                         )
-
-                        # print((new_rm_tokens > self.vocab_size - 1).any().item())
                         new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
-
-                        # Now safely convert rm tokens to one-hot
                         probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype)
+                        skip_sampling = True
                     else:
                         continue
+
                 temperature = self.gumbel_temperature if self.training else 0.001
                 prev_sample_probs = sample_probs
                 sample_probs = probabilities_2d
+
                 if ahead_idx < self.n_ahead - 1 and not skip_sampling:
                     probabilities_2d = F.gumbel_softmax(sample_probs, tau=temperature, hard=True, dim=-1)
                     if self.gumbel_detach:
                         probabilities_2d = probabilities_2d.detach()
-                    sampled_token_history.append(probabilities_2d.argmax(dim=-1).detach().cpu())
+                sampled_token_history.append(probabilities_2d.argmax(dim=-1).detach().cpu())
+
                 # convert rm logits directly to embeddings
                 contains_start = self.use_start_thought_token and (probabilities_2d[..., self.start_token_id].sum() > 0)
                 contains_end = self.use_end_thought_token and (probabilities_2d[..., self.end_token_id].sum() > 0)
                 contains_thought = contains_start or contains_end
 
-                if not contains_thought:
-                    with torch.set_grad_enabled(not self.train_only_thinking_embedding):
-                        inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype))
-                else:
-                    thought_id = self.start_token_id if contains_start else self.end_token_id
-                    cur_thought_embedding = start_embedding if contains_start else end_embedding
-                    if self.use_reparam_for_thought_embeddings:
-                        inputs_embeds = torch.randn(batch_size, seq_len, self.model.config.hidden_size, device=input_ids.device, dtype=cur_thought_embedding.dtype)
-                        inputs_embeds = inputs_embeds * torch.exp(cur_thought_embedding[1]) + cur_thought_embedding[0]
-                        if contains_start:
-                            sampled_start = inputs_embeds.clone().detach()
+                # Flash Attention modification
+                if self._attn_implementation == "flash_attention_2":
+                    probabilities_2d = probabilities_2d.view(batch_size, seq_len, -1)
+
+                if contains_thought:
+                    thought_id = self.start_token_id if contains_start else self.end_token_id
+                    cur_thought_embedding = start_embedding if contains_start else end_embedding
+
+                    if self.use_reparam_for_thought_embeddings:
+                        inputs_embeds = torch.randn(batch_size, seq_len, self.model.config.hidden_size, device=input_ids.device, dtype=cur_thought_embedding.dtype)
+                        inputs_embeds = inputs_embeds * torch.exp(cur_thought_embedding[1]) + cur_thought_embedding[0]
+                        if contains_start:
+                            sampled_start = inputs_embeds.clone().detach()
+                        else:
+                            sampled_end = inputs_embeds.clone().detach()
                     else:
-                            sampled_end = inputs_embeds.clone().detach()
+                        inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
+                        inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
                 else:
-                        inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
-                inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
-                inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
+                    with torch.set_grad_enabled(not self.train_only_thinking_embedding):
+                        inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype))
+
+                    inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
 
                 if len(attention_mask.shape) == 2:
                     breakpoint()
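
For context, the region touched by this commit combines two patterns: straight-through Gumbel-Softmax sampling of a token distribution that is then projected through the embedding matrix (the `probabilities_2d @ self.model.embed_tokens.weight` line), and a reparameterized thought embedding built from a learned mean/log-std pair (`cur_thought_embedding[0]`, `cur_thought_embedding[1]`). The snippet below is a minimal, self-contained sketch of those two patterns only; the tensor shapes, variable names, and temperature value are illustrative stand-ins, not the configuration used by modeling_quiet.py.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only -- not the real model configuration.
vocab_size, hidden_size = 32, 8
batch_size, seq_len = 2, 4

embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)

# Stand-in for probabilities_2d: one row of logits per (batch, position) pair.
logits = torch.randn(batch_size * seq_len, vocab_size, requires_grad=True)

# Straight-through Gumbel-Softmax: hard=True returns a hard one-hot in the
# forward pass while gradients flow through the soft relaxation.
tau = 0.9  # plays the role of self.gumbel_temperature during training
one_hot = F.gumbel_softmax(logits, tau=tau, hard=True, dim=-1)

# Multiplying the one-hot rows by the embedding matrix is a differentiable
# embedding lookup, mirroring `probabilities_2d @ embed_tokens.weight`.
inputs_embeds = one_hot @ embed_tokens.weight          # (batch*seq, hidden)
inputs_embeds = inputs_embeds.view(batch_size, seq_len, -1)

# Reparameterized thought embedding: a learned [mean, log_std] pair is sampled
# as mean + exp(log_std) * noise, so gradients reach both parameters.
thought_embedding = torch.randn(2, hidden_size, requires_grad=True)  # [mean, log_std]
noise = torch.randn(batch_size, seq_len, hidden_size)
thought_embeds = noise * torch.exp(thought_embedding[1]) + thought_embedding[0]

print(inputs_embeds.shape, thought_embeds.shape)  # both torch.Size([2, 4, 8])
```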