Update Gradio app with multiple files
models.py CHANGED
@@ -77,11 +77,15 @@ def stream_generate_response(prompt: str, history: list) -> Generator[str, None,
     input_ids = inputs.input_ids.to(model.device)
     attention_mask = inputs.attention_mask.to(model.device)

+    # Store initial input length
+    initial_length = input_ids.shape[-1]
+
     # Generate with streaming using yield-based approach
     accumulated_text = ""
+    generated_tokens = 0

     # Generate tokens incrementally
-
+    while generated_tokens < MAX_NEW_TOKENS:
         with torch.no_grad():
             outputs = model(
                 input_ids=input_ids,
@@ -120,9 +124,9 @@ def stream_generate_response(prompt: str, history: list) -> Generator[str, None,
         input_ids = torch.cat([input_ids, next_token], dim=-1)
         attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

-        #
-
-        break
+        # Increment generated tokens counter
+        generated_tokens += 1

     # Final yield to ensure complete text
-
+    if accumulated_text:
+        yield accumulated_text.strip()