Spaces:
Sleeping
Sleeping
revert batching
Browse files- generate.py +7 -29
- gradio_app.py +4 -4
generate.py
CHANGED
@@ -3,7 +3,6 @@ import json
|
|
3 |
import logging
|
4 |
import regex
|
5 |
import time
|
6 |
-
from itertools import chain, islice
|
7 |
from pathlib import Path
|
8 |
from typing import Annotated, Iterator
|
9 |
|
@@ -23,16 +22,14 @@ logger = logging.getLogger(__name__)
|
|
23 |
|
24 |
|
25 |
logger.warning("Loading model...")
|
|
|
|
|
26 |
if torch.backends.mps.is_available():
|
27 |
device = "mps"
|
28 |
-
|
29 |
-
batch_size = 1 # batching generates duplicates
|
30 |
else:
|
31 |
device = "cuda"
|
32 |
-
model_id =
|
33 |
-
batch_size = 1 # batching generates duplicates
|
34 |
-
|
35 |
-
model = models.transformers(model_id, device=device)
|
36 |
|
37 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
38 |
sampler = PenalizedMultinomialSampler()
|
@@ -98,24 +95,6 @@ def samples_prommpt(filename: str, prompt: str, columns: str):
|
|
98 |
{{ prompt }}
|
99 |
"""
|
100 |
|
101 |
-
|
102 |
-
def stream_json_objects_from_batched_tokens_generator(batched_tokens_generator: Iterator[list[str]], json_field: str) -> Iterator[dict]:
|
103 |
-
first_batch = next(batched_tokens_generator)
|
104 |
-
batch_size = len(first_batch)
|
105 |
-
streams = [""] * batch_size
|
106 |
-
skips = [0] * batch_size
|
107 |
-
for tokens_batch in chain([first_batch], batched_tokens_generator):
|
108 |
-
for stream_idx, token in enumerate(tokens_batch):
|
109 |
-
streams[stream_idx] += token
|
110 |
-
if '"' in token or "}" in token:
|
111 |
-
try:
|
112 |
-
for stream_sample in islice(ijson.items(StringIteratorIO(streams[stream_idx].__iter__()), json_field + ".item", buf_size=1), skips[stream_idx], None):
|
113 |
-
yield stream_sample
|
114 |
-
skips[stream_idx] = +1
|
115 |
-
except ijson.IncompleteJSONError:
|
116 |
-
pass
|
117 |
-
|
118 |
-
|
119 |
def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
|
120 |
filename = Path(filename).stem
|
121 |
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
|
@@ -155,8 +134,7 @@ def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int,
|
|
155 |
tokenize=False,
|
156 |
add_generation_prompt=True
|
157 |
)
|
158 |
-
|
159 |
-
|
160 |
-
for _, sample in zip(range(size), stream_json_objects_from_batched_tokens_generator(batched_samples_generator_tokens, json_field=json_field)):
|
161 |
yield json.dumps(sample, ensure_ascii=False) + "\n"
|
162 |
-
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")
|
|
|
3 |
import logging
|
4 |
import regex
|
5 |
import time
|
|
|
6 |
from pathlib import Path
|
7 |
from typing import Annotated, Iterator
|
8 |
|
|
|
22 |
|
23 |
|
24 |
logger.warning("Loading model...")
|
25 |
+
model_id = "google/gemma-2b-it"
|
26 |
+
# model_id = "Qwen/Qwen1.5-0.5B-Chat"
|
27 |
if torch.backends.mps.is_available():
|
28 |
device = "mps"
|
29 |
+
model = models.transformers(model_id, device=device)
|
|
|
30 |
else:
|
31 |
device = "cuda"
|
32 |
+
model = models.transformers(model_id, device=device)
|
|
|
|
|
|
|
33 |
|
34 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
35 |
sampler = PenalizedMultinomialSampler()
|
|
|
95 |
{{ prompt }}
|
96 |
"""
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
|
99 |
filename = Path(filename).stem
|
100 |
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
|
|
|
134 |
tokenize=False,
|
135 |
add_generation_prompt=True
|
136 |
)
|
137 |
+
samples_generator_tokens = samples_generator.stream(text, rng=rng)
|
138 |
+
for _, sample in zip(range(size), ijson.items(StringIteratorIO(samples_generator_tokens), "data.item", buf_size=4)):
|
|
|
139 |
yield json.dumps(sample, ensure_ascii=False) + "\n"
|
140 |
+
logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")
|
gradio_app.py
CHANGED
@@ -6,11 +6,11 @@ import io
|
|
6 |
import pandas as pd
|
7 |
import spaces
|
8 |
|
9 |
-
from generate import model_id, stream_jsonl_file
|
10 |
|
11 |
-
MAX_SIZE = 20
|
12 |
DEFAULT_SEED = 42
|
13 |
-
DEFAULT_SIZE =
|
14 |
|
15 |
@spaces.GPU(duration=120)
|
16 |
def stream_output(query: str, continue_content: str = ""):
|
@@ -87,4 +87,4 @@ with gr.Blocks() as demo:
|
|
87 |
generate_more_button.click(stream_more_output, filename_comp, outputs)
|
88 |
|
89 |
|
90 |
-
demo.launch()
|
|
|
6 |
import pandas as pd
|
7 |
import spaces
|
8 |
|
9 |
+
from generate import model_id, stream_jsonl_file
|
10 |
|
11 |
+
MAX_SIZE = 20
|
12 |
DEFAULT_SEED = 42
|
13 |
+
DEFAULT_SIZE = 3
|
14 |
|
15 |
@spaces.GPU(duration=120)
|
16 |
def stream_output(query: str, continue_content: str = ""):
|
|
|
87 |
generate_more_button.click(stream_more_output, filename_comp, outputs)
|
88 |
|
89 |
|
90 |
+
demo.launch()
|