File size: 5,435 Bytes
4f83ec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a563465
 
028b74d
a563465
 
028b74d
a563465
 
028b74d
4f83ec0
 
 
 
 
 
 
 
 
61755fe
 
 
 
 
4f83ec0
 
 
 
 
 
 
 
 
 
 
7375eb9
4f83ec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a563465
4f83ec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61755fe
 
 
 
 
4f83ec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134

import json
import logging
import time
from typing import Annotated, Iterator

import ijson
import outlines
import torch
from pydantic import BaseModel, StringConstraints, conlist, conset
from outlines import generate, models
from outlines.generate.api import SequenceGenerator
from transformers import AutoTokenizer

from fsm import replace_fields
from samplers import PenalizedMultinomialSampler
from utils import StringIteratorIO

logger = logging.getLogger(__name__)


logger.warning("Loading model...")
# model_id = "google/gemma-2b-it"
model_id = "Qwen/Qwen1.5-0.5B-Chat"

# Pick the best available accelerator. Previously CUDA was assumed whenever MPS
# was missing, which crashed on CPU-only machines; fall back to CPU instead.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = models.transformers(model_id, device=device)

# Plain Hugging Face tokenizer (distinct from the outlines model wrapper), used for
# chat templating and token-level manipulation of column names.
tokenizer = AutoTokenizer.from_pretrained(model_id)
sampler = PenalizedMultinomialSampler()
# Every token id in range(tokenizer.vocab_size) whose decoded text is empty or
# whitespace-only. The sampler limits these to 1 repeat; presumably this keeps
# the model from padding the JSON output with whitespace runs — confirm against
# PenalizedMultinomialSampler.set_max_repeats.
empty_tokens = [
    token_id
    for token_id in range(tokenizer.vocab_size)
    if not tokenizer.decode([token_id]).strip()
]
sampler.set_max_repeats(empty_tokens, 1)

# The Sample & Dataset models below are just templates with placeholder fields

class Sample(BaseModel):
    # One row of the generated dataset, declared with placeholder field names.
    # (Documented with comments, not a docstring: a class docstring would be added
    # to the JSON schema as "description".)
    # get_samples_generator() passes this model to replace_fields(), which swaps these
    # placeholder names for the requested column names in the compiled FSM.
    # Field order and count matter: stream_file() zips the column names with
    # Sample.model_fields, so at most this many columns are used, and each name is
    # cut to at most len(placeholder) tokens.
    ABCDabcd12: str
    EFGHefgh34: str
    IJKLijkl56: str
    MNOPmnop78: str
    QRSTqrst90: str
    # PS: don't use StringConstraints with max_length here since it creates a fsm that is too big


class Dataset(BaseModel):
    # Top-level JSON object of the form {"data": [<Sample>, ...]}.
    # min_length/max_length only shape the template FSM: get_samples_generator()
    # calls replace_fields(..., make_infinite_loop=True) so the "data" list can keep
    # growing, and the consumer decides how many items to read.
    data: conlist(Sample, min_length=2, max_length=3)  # type: ignore


# Generator compiled for the placeholder Dataset schema. It is not used directly:
# get_samples_generator() hands its FSM, model and sampler to replace_fields() /
# SequenceGenerator to build a generator with the real column names.
samples_generator_template = generate.json(model, Dataset, sampler=sampler)

class Columns(BaseModel):
    # Schema for the column-name suggestion step: a set (so no duplicates) of
    # 2 to len(Sample.model_fields) - 1 strings matching [a-z0-9_]+.
    # NOTE(review): the "- 1" means one Sample placeholder can never be filled by
    # generated columns; unclear whether that is intentional — confirm.
    columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1)  # type: ignore

# Used by stream_file() to invent column names when the caller supplies none.
columns_generator = generate.json(model, Columns, sampler=sampler)

def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
    """Build a sample generator whose JSON keys are ``new_fields``.

    The FSM of ``samples_generator_template`` was compiled for the placeholder
    field names declared on ``Sample``. ``replace_fields`` substitutes them with
    ``new_fields``. With ``make_infinite_loop=True``, the ``data`` list has no
    upper bound, so the consumer decides how many samples to read.

    Args:
        new_fields: Column names that replace the ``Sample`` placeholders.

    Returns:
        A ``SequenceGenerator`` that reuses the template's model and sampler on
        the module-level ``device``.
    """
    fsm = replace_fields(  # replace the placeholder fields by the real fields
        fsm=samples_generator_template.fsm,
        model=Sample,
        new_fields=new_fields,
        tokenizer=tokenizer,
        make_infinite_loop=True,  # to generate as many samples as we want
    )
    return SequenceGenerator(
        fsm=fsm,
        model=samples_generator_template.model,
        sampler=samples_generator_template.sampler,
        device=device,
    )


# Prompt asking the model to propose column names for <filename>.json.
# With @outlines.prompt the docstring below IS the template (rendered with the
# function's arguments), so editing it changes what is sent to the model —
# keep documentation in comments like this one.
@outlines.prompt
def columns_prompt(filename: str):
    """I would like to create a JSON file named {{ filename }}.json for a dataset of realistic data.
    Give an example of column names / columns for this dataset to populate a SQL schema.
    Please reply in JSON format and place the columns in a field named "columns".
    """

# Prompt asking the model to produce samples for <filename>.json with the given
# columns (pre-formatted as a quoted, comma-separated string) plus the user's extra
# instructions. With @outlines.prompt the docstring below IS the template, so it
# must not be edited for documentation purposes.
@outlines.prompt
def samples_prompt(filename: str, prompt: str, columns: str):
    """I would like to create a JSON file named {{ filename }}.json for a dataset of realistic data.
    Give an example of content using a JSON field named "data" with samples with columns {{ columns }}.
    {{ prompt }}
    """


# Backward-compatible alias for the original, misspelled public name.
samples_prommpt = samples_prompt

def _chat_text(content: str) -> str:
    """Render ``content`` as a single user turn using the tokenizer's chat template."""
    messages = [{"role": "user", "content": content}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )


def stream_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
    """Stream up to ``size`` generated samples for ``filename`` as JSON Lines.

    Args:
        filename: Name interpolated into the prompts as ``<filename>.json``.
        prompt: Extra instructions appended to the sample-generation prompt.
        columns: Column names to use. When empty (or ``None``), column names are
            first generated with ``columns_generator``. The caller's list is never
            modified.
        seed: Seed of the torch generator shared by both generation steps.
        size: Maximum number of samples to yield.

    Yields:
        One JSON-encoded sample per string, each terminated by a newline.
    """
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
    _start = time.time()
    rng = torch.Generator(device=model.device)
    rng.manual_seed(seed)
    if not columns:
        text = _chat_text(columns_prompt(filename=filename))
        logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns...")
        columns_generator_tokens = columns_generator.stream(text, rng=rng)
        # Build a fresh list rather than appending to the argument: appending would
        # mutate the caller's (possibly shared) empty list and fails on None.
        columns = list(ijson.items(StringIteratorIO(columns_generator_tokens), "columns.item", buf_size=16))
        logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns... DONE (total={time.time() - _start:.02f}s)")

    max_columns = len(Sample.model_fields)
    if len(columns) > max_columns:
        logger.warning(f"stream_response({filename=}) - only the first {max_columns} of {len(columns)} columns are used")
    # zip() keeps at most one column per Sample placeholder, and each name is cut to
    # at most len(placeholder) tokens (presumably a replace_fields requirement —
    # confirm in fsm.replace_fields).
    columns = [
        tokenizer.decode(tokenizer.encode(column, add_special_tokens=False)[:len(orig_field)], skip_special_tokens=True)
        for column, orig_field in zip(columns, Sample.model_fields)
    ]

    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide...")
    samples_generator = get_samples_generator(new_fields=columns)
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide... DONE (total={time.time() - _start:.02f}s)")
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples...")
    text = _chat_text(samples_prommpt(filename=filename, prompt=prompt, columns="'" + "', '".join(columns) + "'"))
    samples_generator_tokens = samples_generator.stream(text, rng=rng)
    # range() comes first so zip() stops before pulling (and generating) an extra
    # sample once `size` samples have been yielded.
    for _, sample in zip(range(size), ijson.items(StringIteratorIO(samples_generator_tokens), "data.item", buf_size=4)):
        yield json.dumps(sample, ensure_ascii=False) + "\n"
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")