akhaliq (HF Staff) committed
Commit 9eb416b · verified · 1 Parent(s): d57d400

Update Gradio app with multiple files

Files changed (1)
  1. models.py +9 -5
models.py CHANGED
@@ -77,11 +77,15 @@ def stream_generate_response(prompt: str, history: list) -> Generator[str, None,
     input_ids = inputs.input_ids.to(model.device)
     attention_mask = inputs.attention_mask.to(model.device)
 
+    # Store initial input length
+    initial_length = input_ids.shape[-1]
+
     # Generate with streaming using yield-based approach
     accumulated_text = ""
+    generated_tokens = 0
 
     # Generate tokens incrementally
-    for _ in range(MAX_NEW_TOKENS):
+    while generated_tokens < MAX_NEW_TOKENS:
         with torch.no_grad():
             outputs = model(
                 input_ids=input_ids,
@@ -120,9 +124,9 @@ def stream_generate_response(prompt: str, history: list) -> Generator[str, None,
         input_ids = torch.cat([input_ids, next_token], dim=-1)
         attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
 
-        # Stop if we've reached max tokens
-        if input_ids.shape[-1] >= input_ids.shape[-1] + MAX_NEW_TOKENS:
-            break
+        # Increment generated tokens counter
+        generated_tokens += 1
 
     # Final yield to ensure complete text
-    yield accumulated_text.strip()
+    if accumulated_text:
+        yield accumulated_text.strip()
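
For context on the fix itself: the removed stop check compared input_ids.shape[-1] against itself plus MAX_NEW_TOKENS, so it could never be true for any positive MAX_NEW_TOKENS; the commit bounds the loop with an explicit generated_tokens counter instead. Below is a minimal, self-contained sketch of that counter-based loop, assuming a causal LM loaded via Hugging Face transformers. The model name ("gpt2"), the MAX_NEW_TOKENS value, and the greedy/EOS handling are illustrative assumptions, not taken from the parts of models.py outside this diff.

# Minimal sketch of the corrected counter-based streaming loop.
# Assumptions (not from the commit): model "gpt2", MAX_NEW_TOKENS = 64,
# greedy decoding, and stopping early on the tokenizer's EOS token.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_NEW_TOKENS = 64  # placeholder; the app defines its own constant

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_generate(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)

    initial_length = input_ids.shape[-1]  # prompt length, before generation
    accumulated_text = ""
    generated_tokens = 0

    # The counter makes the bound explicit; the old check compared
    # input_ids.shape[-1] against itself + MAX_NEW_TOKENS and never fired.
    while generated_tokens < MAX_NEW_TOKENS:
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Greedy choice: highest-probability next token, shape [1, 1]
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        if next_token.item() == tokenizer.eos_token_id:
            break
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
        generated_tokens += 1
        # Decode only the newly generated suffix and stream it out
        accumulated_text = tokenizer.decode(
            input_ids[0, initial_length:], skip_special_tokens=True
        )
        yield accumulated_text

    # Final yield, guarded as in the commit, so an immediate EOS yields nothing
    if accumulated_text:
        yield accumulated_text.strip()

Iterating the generator (for partial in stream_generate("Hello"): ...) receives progressively longer text, which is the shape Gradio's streaming chat interfaces expect. Like the diff, this sketch recomputes the full forward pass on every step; a production loop would reuse past_key_values to cache attention, but that is outside the scope of this change.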