iTab-LLM
iTab-LLM is Llama-2 7B further pretrained on a massive collection of tables, dedicated to predictive tasks on tabular data. For details of the model, please refer to our paper: Unleashing the Potential of Large Language Models for Predictive Tabular Tasks in Data Science.
Demo Usage
Classification
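from transformers import LlamaForSequenceClassification, LlamaTokenizer

model_name_or_path = "OldBirdAZ/itab-llm"
tokenizer_name_or_path = model_name_or_path  # assumes the tokenizer files live in the same repo
num_labels = 2  # set to the number of classes in your task

model = LlamaForSequenceClassification.from_pretrained(
    model_name_or_path,
    num_labels=num_labels,
)
tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name_or_path)
# Llama-2 has no padding token, so reuse EOS for padding batched inputs
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id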
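The snippet above loads the classifier but does not show how a table record becomes model input. A minimal sketch, assuming each record is serialized as a one-row Markdown table in the same layout as the tables in this card's zero-shot prompt example; `serialize_record` is a hypothetical helper, and whether this matches the exact format used during pretraining is an assumption.

```python
def serialize_record(columns, values):
    """Render one record as a Markdown table: header, separator, one data row.

    Hypothetical helper: the layout mirrors the Markdown tables in this card's
    zero-shot prompt example; it is not an official iTab-LLM format.
    """
    if len(columns) != len(values):
        raise ValueError("columns and values must have the same length")
    header = "| " + " | ".join(str(c) for c in columns) + " |"
    separator = "| " + " | ".join("--------" for _ in columns) + " |"
    row = "| " + " | ".join(str(v) for v in values) + " |"
    return "\n".join([header, separator, row])


# Example: serialize a single (made-up) record
text = serialize_record(["age", "income"], [42, 55000.0])
print(text)
```

The resulting strings can then be tokenized in batches, e.g. `tokenizer(texts, padding=True, truncation=True, return_tensors="pt")`, and passed to the model; padding works because the pad token was set to EOS above.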
Regression
You can build a model similar to LlamaForSequenceClassification that outputs a single numerical value, and fine-tune it by minimizing the mean squared error (MSE) loss.
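To make the idea concrete, the toy sketch below (plain Python, no model) mimics what such a regression head does: take the hidden state of the last non-padding token, as LlamaForSequenceClassification does, project it to one number with a linear layer, and score the predictions with MSE. The hidden states, weights, targets, and function names are made up for illustration.

```python
def last_token_pool(hidden_states, attention_mask):
    """Return the hidden vector of the last non-padding token (right padding assumed)."""
    last = max(i for i, m in enumerate(attention_mask) if m == 1)
    return hidden_states[last]


def linear_head(vector, weights, bias):
    """Project a hidden vector to a single scalar: w . h + b."""
    return sum(w * h for w, h in zip(weights, vector)) + bias


def mse(predictions, targets):
    """Mean squared error, the loss minimized when fine-tuning for regression."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


# Two toy sequences with 2-dim hidden states; the second one is right-padded.
batch = [
    ([[0.1, 0.2], [0.3, 0.4], [1.0, 2.0]], [1, 1, 1]),
    ([[0.5, 0.5], [2.0, 1.0], [0.0, 0.0]], [1, 1, 0]),
]
weights, bias = [1.0, 0.5], 0.0
preds = [linear_head(last_token_pool(h, m), weights, bias) for h, m in batch]
targets = [2.0, 3.0]
print(preds, mse(preds, targets))  # [2.0, 2.5] 0.125
```

With transformers itself, a comparable setup is `LlamaForSequenceClassification.from_pretrained(model_name_or_path, num_labels=1)`: when `num_labels` is 1 and `problem_type` is unset, the model treats the task as regression and computes an MSE loss against float `labels`.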
Zero-shot Prediction
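import torch
import tensor_parallel as tp
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name_or_path = "OldBirdAZ/itab-llm"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)  # assumes the tokenizer is in the same repo
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    attn_implementation="flash_attention_2",  # requires the flash-attn package and a supported GPU
    torch_dtype=torch.bfloat16,
)
# Split the model across the visible GPUs with tensor parallelism
model = tp.tensor_parallel(model, sharded=True)
prompt_str = "YOUR-PROMPT"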
# fillin_missing_val_prompt_str = "### Instruction: Please fill in the missing value(s) in the table in Markdown format. The missing values are marked with placeholders: <missing_value_0>, <missing_value_1>, <missing_value_2>, ... The description of this table is: Historical cryptocurrency prices for the top 50 coins, including Open, High, Low, Volume, and Change % for each date.\n\n### Input:\n| low | high | sno | open | vol. | change % | date | price |\n| -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- |\n| 59.81 | 63.035 | 768.0 | 61.703 | 4320000.0 | -0.78 | 2018-09-30 | 61.224 |\n| 59.472 | 62.137 | 769.0 | 61.225 | 4480000.0 | -1.34 | 2018-10-01 | 60.401 |\n| 59.231 | 61.835 | 770.0 | 60.392 | 4430000.0 | -1.32 | 2018-10-02 | 59.606 |\n| 56.745 | 59.704 | 771.0 | 59.606 | 4780000.0 | -3.46 | 2018-10-03 | 57.541 |\n| 57.457 | 60.079 | 772.0 | 57.562 | 3260000.0 | 1.51 | 2018-10-04 | 58.411 |\n| 57.672 | 59.82 | 773.0 | 58.41 | 4630000.0 | 1.01 | 2018-10-05 | 59.001 |\n| 56.692 | 59.169 | 774.0 | 59.065 | 4730000.0 | -1.78 | 2018-10-06 | 57.951 |\n| 56.986 | 58.533 | 775.0 | 57.951 | 1500000.0 | 0.56 | 2018-10-07 | 58.273 |\n| 57.693 | 60.163 | 776.0 | 58.274 | 2260000.0 | 2.37 | 2018-10-08 | 59.655 |\n| 58.523 | 59.887 | 777.0 | 59.655 | 2230000.0 | -1.15 | 2018-10-09 | 58.968 |\n| 50.968 | 54.149 | 780.0 | 51.263 | 2170000.0 | 4.99 | 2018-10-12 | 53.806 |\n| 52.233 | 61.175 | 783.0 | 52.545 | 3210000.0 | 6.48 | 2018-10-15 | 55.951 |\n| 54.619 | 56.809 | 784.0 | 55.95 | 1860000.0 | -0.97 | 2018-10-16 | 55.408 |\n| 54.339 | 55.571 | 785.0 | 55.416 | 1960000.0 | 0.14 | 2018-10-17 | 55.484 |\n| 52.965 | 54.47 | 787.0 | 53.431 | 2190000.0 | 0.4 | 2018-10-19 | 53.638 |\n| 53.476 | 54.751 | 789.0 | 54.233 | 2400000.0 | -0.94 | 2018-10-21 | 53.721 |\n| 52.78 | 54.176 | 790.0 | 53.686 | 2240000.0 | -1.15 | 2018-10-22 | 53.105 |\n| 50.648 | 54.694 | 791.0 | 53.122 | 2470000.0 | 0.41 | 2018-10-23 | 53.32 |\n| 52.894 | 53.933 | 792.0 | 53.325 | 2910000.0 | -0.44 | 2018-10-24 | 53.088 |\n| 52.677 | 53.293 | 793.0 | 53.094 | 2100000.0 | -0.16 | 2018-10-25 | 53.003 |\n| 52.382 | 53.447 | 794.0 | 53.003 | 2570000.0 | -0.75 | 2018-10-26 | 52.607 |\n| 51.994 | 53.202 | 795.0 | 52.608 | 2190000.0 | -0.46 | 2018-10-27 | 52.364 |\n| 48.161 | 52.501 | 797.0 | 51.958 | 2710000.0 | -5.14 | 2018-10-29 | 49.314 |\n| 48.915 | 50.071 | 798.0 | 49.307 | 1520000.0 | 0.0 | 2018-10-30 | 49.314 |\n| 48.209 | 50.652 | 799.0 | 49.317 | 2510000.0 | 1.28 | 2018-10-31 | 49.943 |\n| 49.851 | 50.781 | 800.0 | 49.952 | 2050000.0 | 1.31 | 2018-11-01 | <missing_value_0> |\n| 48.241 | 52.113 | 801.0 | 50.595 | 2280000.0 | 2.21 | 2018-11-02 | 51.716 |\n| 48.437 | 56.253 | 803.0 | 51.115 | 2280000.0 | 6.76 | 2018-11-04 | 54.573 |\n| 52.94 | 55.129 | 804.0 | 54.572 | 2020000.0 | -1.84 | 2018-11-05 | 53.57 |\n| 51.256 | 56.477 | 805.0 | 53.567 | 1690000.0 | 5.28 | 2018-11-06 | 56.401 |\n\n### Response: "
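max_dec_len = 32  # placeholder: maximum number of new tokens to generate; adjust to your task
prompt = tokenizer(prompt_str, return_tensors="pt")
input_ids = prompt["input_ids"].to(model.device)
with torch.no_grad():
    response_result = model.generate(
        input_ids,
        max_new_tokens=max_dec_len,
        output_scores=True,
        return_dict_in_generate=True,
        num_return_sequences=1,
        remove_invalid_values=True,
    )
# Decode only the newly generated tokens, skipping the prompt
response = tokenizer.decode(response_result["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True).strip()
# Keep the first line of the response as the prediction
result = {"generated_text": response.split("\n")[0].strip()}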
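Writing the fill-in-missing-values prompt by hand is error-prone. A minimal sketch, assuming you want to reproduce the template of the commented example prompt above (instruction, table description, Markdown table with `<missing_value_i>` placeholders, then `### Response: `). `build_fill_missing_prompt` is a hypothetical helper, and `None` marks a missing cell.

```python
# Instruction text copied from the example prompt in this card
INSTRUCTION = (
    "### Instruction: Please fill in the missing value(s) in the table in Markdown format. "
    "The missing values are marked with placeholders: "
    "<missing_value_0>, <missing_value_1>, <missing_value_2>, ... "
    "The description of this table is: "
)


def build_fill_missing_prompt(description, columns, rows):
    """Build a fill-in-missing-values prompt; None cells become numbered placeholders."""
    lines = [
        "| " + " | ".join(columns) + " |",
        "| " + " | ".join("--------" for _ in columns) + " |",
    ]
    n_missing = 0
    for row in rows:
        cells = []
        for value in row:
            if value is None:
                # Number placeholders in reading order: row by row, left to right
                cells.append(f"<missing_value_{n_missing}>")
                n_missing += 1
            else:
                cells.append(str(value))
        lines.append("| " + " | ".join(cells) + " |")
    table = "\n".join(lines)
    return f"{INSTRUCTION}{description}\n\n### Input:\n{table}\n\n### Response: "


prompt_str = build_fill_missing_prompt(
    "Historical cryptocurrency prices for the top 50 coins, including Open, High, Low, Volume, and Change % for each date.",
    ["low", "high", "open", "price"],
    [[59.81, 63.035, 61.703, 61.224], [59.472, 62.137, 61.225, None]],
)
print(prompt_str)
```

The returned string can be passed as `prompt_str` to the generation code above.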
Ethical Considerations and Limitations
This model is Llama-2 7B further pretrained on tables. Because the pretraining data were mainly collected from Kaggle, if you use this model you must rigorously follow Kaggle's terms and licensing agreements and adhere to legal and ethical standards. You must also comply with the license and requirements of Llama-2 7B. Testing conducted to date has been in English and has not covered, nor could it cover, all scenarios. For these reasons, as with all LLMs, iTab-LLM's potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate, biased, or otherwise objectionable responses to user prompts. Therefore, before deploying this model or any application built on it, developers should perform safety testing and tuning tailored to their specific use of the model.