kz209 committed on
Commit
309f86b
1 Parent(s): f276c92
Files changed (1)
  1. utils/model.py +37 -30
utils/model.py CHANGED
@@ -60,35 +60,42 @@ class Model(torch.nn.Module):
         input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)

         if streaming:
-            # Process each input separately
-            for single_input_ids in input_ids:
-                # Set up the initial generation parameters
-                gen_kwargs = {
-                    "input_ids": single_input_ids.unsqueeze(0),
-                    "max_new_tokens": max_length,
-                    "do_sample": True,
-                    "temperature": temp,
-                    "eos_token_id": self.tokenizer.eos_token_id,
-                }
-
-                # Generate and yield tokens one by one
-                unfinished_sequences = single_input_ids.unsqueeze(0)
-                while unfinished_sequences.shape[1] < gen_kwargs["max_new_tokens"]:
-                    with torch.no_grad():
-                        output = self.model.generate(**gen_kwargs, max_new_tokens=1, return_dict_in_generate=True, output_scores=True)
-
-                    next_token_logits = output.scores[0][0]
-                    next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
-                    unfinished_sequences = torch.cat([unfinished_sequences, next_token], dim=-1)
-
-                    # Yield the newly generated token
-                    yield self.tokenizer.decode(next_token[0], skip_special_tokens=True)
-
-                    if next_token.item() == self.tokenizer.eos_token_id:
-                        break
-
-                    # Update input_ids for the next iteration
-                    gen_kwargs["input_ids"] = unfinished_sequences
+            # Set up the initial generation parameters
+            gen_kwargs = {
+                "input_ids": input_ids,
+                "do_sample": True,
+                "temperature": temp,
+                "eos_token_id": self.tokenizer.eos_token_id,
+                "max_new_tokens": 1,  # Generate one token at a time
+                "return_dict_in_generate": True,
+                "output_scores": True
+            }
+
+            # Generate and yield tokens one by one
+            generated_tokens = 0
+            batch_size = input_ids.shape[0]
+            active_sequences = torch.arange(batch_size)
+
+            while generated_tokens < max_length and len(active_sequences) > 0:
+                with torch.no_grad():
+                    output = self.model.generate(**gen_kwargs)
+
+                next_tokens = output.sequences[:, -1].unsqueeze(-1)
+
+                # Yield the newly generated tokens for each sequence in the batch
+                for i, token in zip(active_sequences, next_tokens):
+                    yield i, self.tokenizer.decode(token[0], skip_special_tokens=True)
+
+                # Update input_ids for the next iteration
+                gen_kwargs["input_ids"] = torch.cat([gen_kwargs["input_ids"], next_tokens], dim=-1)
+                generated_tokens += 1
+
+                # Check for completed sequences
+                completed = (next_tokens.squeeze(-1) == self.tokenizer.eos_token_id).nonzero().squeeze(-1)
+                active_sequences = torch.tensor([i for i in active_sequences if i not in completed])
+                if len(active_sequences) > 0:
+                    gen_kwargs["input_ids"] = gen_kwargs["input_ids"][active_sequences]
+
         else:
             # Non-streaming generation (unchanged)
             outputs = self.model.generate(
@@ -98,4 +105,4 @@ class Model(torch.nn.Module):
                 temperature=temp,
                 eos_token_id=self.tokenizer.eos_token_id,
             )
-        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
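
Note on consuming the new streaming path: the streaming branch now yields (sequence_index, token_text) pairs for the whole batch instead of one prompt's tokens at a time, so callers have to regroup tokens by index. Below is a minimal consumption sketch, not part of this commit: the import path, the Model() construction, and the method name gen are assumptions (the commit does not show the method's name or full signature); only the names content_list, streaming, temp, and max_length appear in the diff above.

from collections import defaultdict

from utils.model import Model  # assumed import path matching utils/model.py

model = Model()  # hypothetical construction; __init__ is not shown in this commit
prompts = ["Summarize the first report.", "Summarize the second report."]

pieces = defaultdict(list)
# The generator yields (sequence_index, token_text); indices come from torch.arange(batch_size)
for idx, token_text in model.gen(prompts, streaming=True, temp=0.7, max_length=128):
    pieces[int(idx)].append(token_text)
    print(f"[seq {int(idx)}] {token_text}", flush=True)

# Naive reassembly; exact spacing may differ from batch_decode depending on the tokenizer
full_texts = {i: "".join(tokens) for i, tokens in pieces.items()}

Yielding (index, token) pairs keeps a single batched generation loop on the model side, at the cost of pushing per-prompt reassembly onto the caller.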