kz209 commited on
Commit
f276c92
·
1 Parent(s): 9dfac6e
Files changed (2) hide show
  1. utils/model.py +32 -17
  2. utils/multiple_stream.py +7 -7
utils/model.py CHANGED
@@ -60,27 +60,42 @@ class Model(torch.nn.Module):
60
  input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
61
 
62
  if streaming:
63
- # Prepare streamers for each input
64
- streamers = [TextStreamer(self.tokenizer, skip_prompt=True) for _ in content_list]
65
-
66
- # Stream the output token by token for each input text
67
- for i, streamer in enumerate(streamers):
68
- for output in self.model.generate(
69
- input_ids[i].unsqueeze(0), # Process each input separately
70
- max_new_tokens=max_length,
71
- do_sample=True,
72
- temperature=temp,
73
- eos_token_id=self.tokenizer.eos_token_id,
74
- return_dict_in_generate=True,
75
- output_scores=True,
76
- streamer=streamer):
77
- yield output # TextStreamer automatically handles the streaming, no need to manually handle the output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  else:
 
79
  outputs = self.model.generate(
80
  input_ids,
81
  max_new_tokens=max_length,
82
  do_sample=True,
83
  temperature=temp,
84
- eos_token_id=self.tokenizer.eos_token_id
85
  )
86
- return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
 
60
  input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
61
 
62
  if streaming:
63
+ # Process each input separately
64
+ for single_input_ids in input_ids:
65
+ # Set up the initial generation parameters
66
+ gen_kwargs = {
67
+ "input_ids": single_input_ids.unsqueeze(0),
68
+ "max_new_tokens": max_length,
69
+ "do_sample": True,
70
+ "temperature": temp,
71
+ "eos_token_id": self.tokenizer.eos_token_id,
72
+ }
73
+
74
+ # Generate and yield tokens one by one
75
+ unfinished_sequences = single_input_ids.unsqueeze(0)
76
+ while unfinished_sequences.shape[1] < gen_kwargs["max_new_tokens"]:
77
+ with torch.no_grad():
78
+ output = self.model.generate(**gen_kwargs, max_new_tokens=1, return_dict_in_generate=True, output_scores=True)
79
+
80
+ next_token_logits = output.scores[0][0]
81
+ next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
82
+ unfinished_sequences = torch.cat([unfinished_sequences, next_token], dim=-1)
83
+
84
+ # Yield the newly generated token
85
+ yield self.tokenizer.decode(next_token[0], skip_special_tokens=True)
86
+
87
+ if next_token.item() == self.tokenizer.eos_token_id:
88
+ break
89
+
90
+ # Update input_ids for the next iteration
91
+ gen_kwargs["input_ids"] = unfinished_sequences
92
  else:
93
+ # Non-streaming generation (unchanged)
94
  outputs = self.model.generate(
95
  input_ids,
96
  max_new_tokens=max_length,
97
  do_sample=True,
98
  temperature=temp,
99
+ eos_token_id=self.tokenizer.eos_token_id,
100
  )
101
+ return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
utils/multiple_stream.py CHANGED
@@ -26,13 +26,13 @@ def stream_data(content_list, model):
26
  # Use the gen method to handle batch generation
27
  while True:
28
  updated = False
29
- for i, content in enumerate(content_list):
30
- try:
31
- word = next(model.gen([content], streaming=True)) # Wrap content in a list to match expected input type
32
- outputs[i] += word
33
- updated = True
34
- except StopIteration:
35
- pass
36
 
37
  if not updated:
38
  break
 
26
  # Use the gen method to handle batch generation
27
  while True:
28
  updated = False
29
+ #for i, content in enumerate(content_list):
30
+ try:
31
+ words = next(model.gen(content_list, streaming=True)) # Wrap content in a list to match expected input type
32
+ outputs = [outputs[i].append(f" {words[i]}") for i in range(len(content_list))]
33
+ updated = True
34
+ except StopIteration:
35
+ pass
36
 
37
  if not updated:
38
  break