betterdataai committed
Commit 0228f64 · verified · 1 Parent(s): 1c5d4d7

Upload 2 files

Files changed (2)
  1. inference.py +121 -0
  2. requirements.txt +9 -0
inference.py ADDED
@@ -0,0 +1,121 @@
+ import json
+ 
+ from unsloth import FastLanguageModel
+ 
+ # Generation settings
+ max_seq_length = 4096
+ dtype = None  # None lets Unsloth auto-select float16/bfloat16 for the GPU
+ load_in_4bit = False
+ 
+ # Load model and tokenizer once at startup
+ model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name="betterdataai/large-tabular-model",
+     max_seq_length=max_seq_length,
+     dtype=dtype,
+     load_in_4bit=load_in_4bit,
+ )
+ FastLanguageModel.for_inference(model)  # enable Unsloth's fast inference mode
+ 
+ 
+ def prompt_transformation(prompt):
+     """Rewrite a natural language query into the formalized tabular prompt."""
+     initial_prompt = """
+ We have the following natural language query:
+ "{}"
+ 
+ Transform the above natural language query into a formalized prompt format. The format should include:
+ 
+ 1. A sentence summarizing the objective.
+ 2. A description of the columns, including their data types and examples.
+ 3. Four example rows of the dataset in CSV format.
+ 
+ An example of this format is as follows; please focus only on the format, not the content:
+ 
+ "You are tasked with generating a synthetic dataset based on the following description. The dataset represents employee information. The dataset should include the following columns:
+ 
+ - NAME (String): Employee's full name, consisting of a first and last name (e.g., "John Doe", "Maria Lee", "Wei Zhang").
+ - GENDER (String): Employee's gender (e.g., "Male", "Female").
+ - EMAIL (String): Employee's email address, following the standard format.
+ - CITY (String): City where the employee resides (e.g., "New York", "London", "Beijing").
+ - COUNTRY (String): Country where the employee resides (e.g., "USA", "UK", "China").
+ - SALARY (Float): Employee's annual salary, a value between 30000 and 150000 (e.g., 55000.0, 75000.0).
+ 
+ Here are some examples:
+ NAME,GENDER,EMAIL,CITY,COUNTRY,SALARY
+ John Doe,Male,john.doe@example.com,New York,USA,56000.0
+ Maria Lee,Female,maria.lee@nus.edu.sg,London,UK,72000.0
+ Wei Zhang,Male,wei.zhang@meta.com,Beijing,China,65000.0
+ Sara Smith,Female,sara.smith@orange.fr,Paris,France,85000.0"
+ 
+ Here is the transformed query from the given natural language query:
+ """
+ 
+     messages = [
+         {"role": "system", "content": initial_prompt.format(prompt)},
+         {"role": "user", "content": "Transform the given natural language text to the designated format."},
+     ]
+ 
+     inputs = tokenizer.apply_chat_template(
+         messages,
+         tokenize=True,
+         add_generation_prompt=True,  # required for generation
+         return_tensors="pt",
+     ).to("cuda")
+ 
+     output_ids = model.generate(
+         input_ids=inputs,
+         max_new_tokens=4096,
+         use_cache=True,
+         temperature=1.5,
+         min_p=0.1,
+     )
+ 
+     # Decode only the newly generated tokens, skipping the prompt
+     generated_ids = output_ids[0][inputs.shape[1]:]
+     return tokenizer.decode(generated_ids, skip_special_tokens=True)
+ 
+ 
+ def table_generation(prompt):
+     """Generate table rows from a formalized prompt."""
+     messages = [
+         {"role": "system", "content": prompt},
+         {"role": "user", "content": "create 20 data rows"},
+     ]
+ 
+     inputs = tokenizer.apply_chat_template(
+         messages,
+         tokenize=True,
+         add_generation_prompt=True,  # required for generation
+         return_tensors="pt",
+     ).to("cuda")
+ 
+     output_ids = model.generate(
+         input_ids=inputs,
+         max_new_tokens=4096,
+         use_cache=True,
+         temperature=1.5,
+         min_p=0.1,
+     )
+ 
+     # Decode only the newly generated tokens, skipping the prompt
+     generated_ids = output_ids[0][inputs.shape[1]:]
+     return tokenizer.decode(generated_ids, skip_special_tokens=True)
+ 
+ 
+ def predict(input_data):
+     """
+     Inference endpoint entry point.
+ 
+     Expects input_data as a JSON string or dict with a "query" key holding
+     the natural language query. Returns a JSON string with the generated
+     table under "result", or an "error" key on malformed input.
+     """
+     try:
+         data = json.loads(input_data) if isinstance(input_data, str) else input_data
+         user_query = data.get("query", "")
+     except Exception:
+         return json.dumps({
+             "error": "Invalid input format. Please provide a JSON payload with a 'query' field."
+         })
+ 
+     # Transform the user query into the formalized prompt format
+     transformed_prompt = prompt_transformation(user_query)
+     # Generate the table using the transformed prompt
+     generated_table = table_generation(transformed_prompt)
+ 
+     return json.dumps({"result": generated_table})
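
For reference, a minimal smoke test of this entry point might look like the sketch below. It assumes a CUDA-capable GPU (inference.py moves inputs to "cuda" unconditionally) and a successful model download; the example query string is purely illustrative, not from this commit.

import json
from inference import predict

payload = json.dumps({"query": "Generate customer records with name, age, and signup date"})
response = json.loads(predict(payload))
print(response.get("result", response.get("error")))
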
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ unsloth
+ pandas
+ datasets
+ trl
+ scipy
+ transformers>=4.30.0
+ peft>=0.4.0
+ accelerate>=0.20.0
+ torch>=2.0
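
Note that inference.py requires a CUDA GPU regardless of how these dependencies are installed, since tensors are moved to "cuda" unconditionally. If this commit is meant for Hugging Face Inference Endpoints, custom handlers there conventionally live in a handler.py exposing an EndpointHandler class; a hypothetical thin wrapper over predict (not part of this commit) could look like:

# handler.py -- hypothetical sketch, not part of this commit
import json
from inference import predict

class EndpointHandler:
    def __init__(self, path=""):
        # Model and tokenizer are loaded at import time in inference.py,
        # so there is nothing further to initialize here.
        pass

    def __call__(self, data):
        # Inference Endpoints typically delivers the JSON body as a dict,
        # with user content under "inputs"; fall back to the raw dict.
        payload = data.get("inputs", data)
        return json.loads(predict(payload))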