Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,938 Bytes
4f83ec0 fbe940a 4f83ec0 72a89db 4f83ec0 451395b 028b74d a563465 451395b 028b74d a563465 451395b 028b74d 4f83ec0 72a89db fbe940a 4f83ec0 fbe940a 4f83ec0 61755fe 4f83ec0 3642076 4f83ec0 7375eb9 4f83ec0 72a89db 4f83ec0 a563465 4f83ec0 72a89db 4f83ec0 61755fe 4f83ec0 451395b 4f83ec0 451395b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import json
import logging
import regex
import time
from pathlib import Path
from typing import Annotated, Iterator
import ijson
import outlines
import torch
from pydantic import BaseModel, StringConstraints, conlist, conset
from outlines import generate, models
from outlines.generate.api import SequenceGenerator
from transformers import AutoTokenizer
from fsm import replace_fields
from samplers import PenalizedMultinomialSampler
from utils import StringIteratorIO
logger = logging.getLogger(__name__)

logger.warning("Loading model...")
model_id = "google/gemma-2b-it"
# model_id = "Qwen/Qwen1.5-0.5B-Chat"

# Use Apple's Metal backend when torch reports it; otherwise assume CUDA.
device = "mps" if torch.backends.mps.is_available() else "cuda"
model = models.transformers(model_id, device=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# `sampler` drives sample generation (see samples_generator_template); the
# low-temperature one is used for the column-name generator.
sampler = PenalizedMultinomialSampler()
low_temperature_sampler = PenalizedMultinomialSampler(temperature=0.3)

disallowed_patterns = [regex.compile(r"\p{Han}")]  # focus on english for now

# Classify every vocabulary entry in a single pass, decoding each token once.
# - empty_tokens: tokens that decode to whitespace-only / empty text.
# - disallowed_tokens: tokens whose decoded text *starts* with a disallowed
#   pattern (`match` anchors at the beginning of the string).
empty_tokens = []
disallowed_tokens = []
for token_id in range(tokenizer.vocab_size):
    token_text = tokenizer.decode([token_id], skip_special_tokens=True)
    if not token_text.strip():
        empty_tokens.append(token_id)
    if any(pattern.match(token_text) for pattern in disallowed_patterns):
        disallowed_tokens.append(token_id)

# Allow at most one repeat of blank tokens and forbid disallowed ones entirely.
# Only `sampler` gets these limits; `low_temperature_sampler` is unpenalized.
sampler.set_max_repeats(empty_tokens, 1)
sampler.set_max_repeats(disallowed_tokens, 0)
# The Sample & Dataset models are just templates with placeholder fields.
# (Comments rather than docstrings are used here: pydantic would copy a class
# docstring into the JSON schema that the generation FSM is compiled from.)
class Sample(BaseModel):
    # We use get_samples_generator() to replace the placeholders with the requested fields.
    # The number of placeholders caps how many real columns a sample can have
    # (stream_jsonl_file zips the column names with Sample.model_fields).
    # PS: don't add StringConstraints with max_length to these fields, since it
    # creates an FSM that is too big.
    ABCDabcd12: str
    EFGHefgh34: str
    IJKLijkl56: str
    MNOPmnop78: str
    QRSTqrst90: str
# Top-level JSON object wrapping the generated samples in a "data" array.
class Dataset(BaseModel):
    # We use get_samples_generator() to set the length to infinity.
    # NOTE(review): the 2..3 bounds presumably only keep the template FSM small;
    # replace_fields(make_infinite_loop=True) is what lifts them — confirm in fsm.py.
    data: conlist(Sample, min_length=2, max_length=3) # type: ignore
# Template generator compiled once for the placeholder schema; its fsm, model and
# sampler are reused by get_samples_generator() for each requested column set.
samples_generator_template = generate.json(model, Dataset, sampler=sampler)
# Schema for model-proposed column names: a set (conset) of 2 to
# len(Sample.model_fields) - 1 strings constrained to [a-z0-9_]+.
class Columns(BaseModel):
    # NOTE(review): the "- 1" bound means one Sample placeholder is never filled by
    # generated columns; presumably replace_fields copes with fewer new_fields than
    # placeholders — confirm.
    columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1) # type: ignore
# JSON-constrained generator for Columns, using the 0.3-temperature sampler
# (the max-repeat limits configured above apply only to `sampler`, not this one).
columns_generator = generate.json(model, Columns, sampler=low_temperature_sampler)
def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
    """Build a samples generator whose JSON keys are ``new_fields``.

    The template FSM (compiled from ``Dataset``/``Sample``) uses placeholder
    keys; ``replace_fields`` swaps them for ``new_fields`` and, with
    ``make_infinite_loop=True``, lets the ``data`` array keep growing so the
    caller decides how many samples to consume.

    Args:
        new_fields: column names substituted for the ``Sample`` placeholders.

    Returns:
        A ``SequenceGenerator`` reusing the template's model and sampler on
        the module-level ``device``.
    """
    fsm = replace_fields(  # replace the placeholder fields by the real fields
        fsm=samples_generator_template.fsm,
        model=Sample,
        new_fields=new_fields,
        tokenizer=tokenizer,
        make_infinite_loop=True,  # to generate as many samples as we want
    )
    return SequenceGenerator(
        fsm=fsm,
        model=samples_generator_template.model,
        sampler=samples_generator_template.sampler,
        device=device,
    )
# Prompt asking the model to propose column names for a dataset named after
# `filename`. With @outlines.prompt the docstring below IS the template
# ({{ filename }} is substituted at call time), so its text is runtime behavior.
@outlines.prompt
def columns_prompt(filename: str):
    """I would like to create a JSON file named {{ filename }}.json for a dataset of realistic data.
    Give an example of column names / columns for this dataset to populate a SQL schema.
    Please reply in JSON format and place the columns in a field named "columns".
    """
# Prompt asking for sample rows under a "data" field with the given columns,
# followed by the user's free-form `prompt`. The docstring is the outlines
# template, so its text is runtime behavior.
# NOTE: the name is misspelled ("prommpt") but kept as-is for existing callers.
@outlines.prompt
def samples_prommpt(filename: str, prompt: str, columns: str):
    """I would like to create a JSON file named {{ filename }}.json for a dataset of realistic data.
    Give an example of content using a JSON field named "data" with samples with columns {{ columns }}.
    {{ prompt }}
    """
def _chat_prompt(content: str) -> str:
    """Render ``content`` as a single user turn with the tokenizer's chat template.

    Returns text (``tokenize=False``). ``add_generation_prompt=True`` asks the
    template to append the assistant-turn opener so generation starts at the reply.
    """
    messages = [{"role": "user", "content": content}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )


def _generate_columns(filename: str, prompt: str, rng: torch.Generator, start: float) -> list[str]:
    """Let the model propose column names for ``filename`` and return them in a new list.

    Names are parsed incrementally from the streamed JSON as the items of the
    top-level ``"columns"`` array.

    Args:
        filename: dataset name (file stem) inserted into the prompt.
        prompt: user prompt; only echoed in the log messages here.
        rng: torch generator, shared with the later samples generation.
        start: ``time.time()`` reference for the elapsed-time log.
    """
    columns: list[str] = []
    text = _chat_prompt(columns_prompt(filename=filename))
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns...")
    columns_generator_tokens = columns_generator.stream(text, rng=rng)
    columns.extend(ijson.items(StringIteratorIO(columns_generator_tokens), "columns.item", buf_size=16))
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns... DONE (total={time.time() - start:.02f}s)")
    return columns


def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
    """Stream up to ``size`` generated samples as JSON Lines.

    Args:
        filename: target file name or path; only its stem is used in the prompts.
        prompt: extra free-form instructions appended to the samples prompt.
        columns: column names to use. When empty (or ``None``) the model proposes
            them first. The caller's list is never modified.
        seed: seed for the torch generator used by both generation streams.
        size: maximum number of samples to yield.

    Yields:
        One JSON-encoded sample per line, each terminated by a newline.
    """
    filename = Path(filename).stem
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
    _start = time.time()
    rng = torch.Generator(device=model.device)
    rng.manual_seed(seed)
    if not columns:
        # Generate into a fresh list: appending to the argument used to leak the
        # generated names into the caller's list (and crashed when it was None).
        columns = _generate_columns(filename, prompt, rng, _start)
    # Cap each name at len(placeholder) tokens (the placeholders are 10-character
    # keys) and keep at most one column per Sample placeholder, since zip() stops
    # at the shorter input. Presumably replace_fields needs the new keys to fit the
    # placeholder's budget — TODO confirm in fsm.py.
    columns = [
        tokenizer.decode(tokenizer.encode(column, add_special_tokens=False)[:len(orig_field)], skip_special_tokens=True)
        for column, orig_field in zip(columns, Sample.model_fields)
    ]
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide...")
    samples_generator = get_samples_generator(new_fields=columns)
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide... DONE (total={time.time() - _start:.02f}s)")
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples...")
    text = _chat_prompt(samples_prommpt(filename=filename, prompt=prompt, columns="'" + "', '".join(columns) + "'"))
    samples_generator_tokens = samples_generator.stream(text, rng=rng)
    # The samples FSM is built with make_infinite_loop=True, so stop after `size`
    # records. zip() consumes range(size) first, so no extra sample is pulled.
    for _, sample in zip(range(size), ijson.items(StringIteratorIO(samples_generator_tokens), "data.item", buf_size=4)):
        yield json.dumps(sample, ensure_ascii=False) + "\n"
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")