rishiraj committed on
Commit
4efce51
1 Parent(s): a229104

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -15
app.py CHANGED
@@ -7,17 +7,16 @@ from alignment import (
7
  get_tokenizer,
8
  )
9
 
10
- def template(base_model, trained_adapter, token):
11
- data_args = DataArguments(chat_template=None, dataset_mixer={'HuggingFaceH4/no_robots': 1.0}, dataset_splits=['train_sft', 'test_sft'], max_train_samples=None, max_eval_samples=None, preprocessing_num_workers=12, truncation_side=None)
12
- model_args = ModelArguments(base_model_revision=None, model_name_or_path='mistralai/Mistral-7B-v0.1', model_revision='main', model_code_revision=None, torch_dtype='auto', trust_remote_code=True, use_flash_attention_2=True, use_peft=True, lora_r=64, lora_alpha=16, lora_dropout=0.1, lora_target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'], lora_modules_to_save=None, load_in_8bit=False, load_in_4bit=True, bnb_4bit_quant_type='nf4', use_bnb_nested_quant=False)
 
13
 
14
  ###############
15
  # Load datasets
16
  ###############
17
  raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
18
- logger.info(
19
- f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
20
- )
21
 
22
  ################
23
  # Load tokenizer
@@ -31,9 +30,15 @@ def template(base_model, trained_adapter, token):
31
  train_dataset = raw_datasets["train"]
32
  eval_dataset = raw_datasets["test"]
33
 
 
 
 
 
 
 
34
  with gr.Blocks() as demo:
35
- gr.Markdown("## AutoTrain Merge Adapter")
36
- gr.Markdown("Please duplicate this space and attach a GPU in order to use it.")
37
  token = gr.Textbox(
38
  label="Hugging Face Write Token",
39
  value="",
@@ -42,23 +47,44 @@ with gr.Blocks() as demo:
42
  interactive=True,
43
  type="password",
44
  )
45
- base_model = gr.Textbox(
46
- label="Base Model (e.g. meta-llama/Llama-2-7b-chat-hf)",
47
- value="",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  lines=1,
49
  max_lines=1,
50
  interactive=True,
51
  )
52
- trained_adapter = gr.Textbox(
53
- label="Trained Adapter Model (e.g. username/autotrain-my-llama)",
54
  value="",
55
  lines=1,
56
  max_lines=1,
57
  interactive=True,
58
  )
59
- submit = gr.Button(value="Merge & Push")
60
  op = gr.Markdown(interactive=False)
61
- submit.click(merge, inputs=[base_model, trained_adapter, token], outputs=[op])
62
 
63
 
64
  if __name__ == "__main__":
 
7
  get_tokenizer,
8
  )
9
 
10
+
11
+ def reformat(dataset_name, train_split, test_split, model_name, upload_name, token):
12
+ data_args = DataArguments(chat_template=None, dataset_mixer={dataset_name: 1.0}, dataset_splits=[train_split, test_split], max_train_samples=None, max_eval_samples=None, preprocessing_num_workers=12, truncation_side=None)
13
+ model_args = ModelArguments(base_model_revision=None, model_name_or_path=model_name, model_revision='main', model_code_revision=None, torch_dtype='auto', trust_remote_code=True, use_flash_attention_2=True, use_peft=True, lora_r=64, lora_alpha=16, lora_dropout=0.1, lora_target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'], lora_modules_to_save=None, load_in_8bit=False, load_in_4bit=True, bnb_4bit_quant_type='nf4', use_bnb_nested_quant=False)
14
 
15
  ###############
16
  # Load datasets
17
  ###############
18
  raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
19
+ output = f"Dataset successfully formatted and pushed! Dataset and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
 
 
20
 
21
  ################
22
  # Load tokenizer
 
30
  train_dataset = raw_datasets["train"]
31
  eval_dataset = raw_datasets["test"]
32
 
33
+ raw_dataset.push_to_hub(upload_name)
34
+ return gr.Markdown.update(
35
+ value=output
36
+ )
37
+
38
+
39
  with gr.Blocks() as demo:
40
+ gr.Markdown("## Dataset Chat Template")
41
+ gr.Markdown("Format Datasets like HuggingFaceH4/no_robots to be AutoTrain compatible.")
42
  token = gr.Textbox(
43
  label="Hugging Face Write Token",
44
  value="",
 
47
  interactive=True,
48
  type="password",
49
  )
50
+ dataset_name = gr.Textbox(
51
+ label="Dataset Name (e.g. HuggingFaceH4/no_robots)",
52
+ value="HuggingFaceH4/no_robots",
53
+ lines=1,
54
+ max_lines=1,
55
+ interactive=True,
56
+ )
57
+ train_split = gr.Textbox(
58
+ label="Train Split Name (e.g. train_sft)",
59
+ value="train_sft",
60
+ lines=1,
61
+ max_lines=1,
62
+ interactive=True,
63
+ )
64
+ test_split = gr.Textbox(
65
+ label="Test Split Name (e.g. test_sft)",
66
+ value="test_sft",
67
+ lines=1,
68
+ max_lines=1,
69
+ interactive=True,
70
+ )
71
+ model_name = gr.Textbox(
72
+ label="Model Name (e.g. mistralai/Mistral-7B-v0.1)",
73
+ value="mistralai/Mistral-7B-v0.1",
74
  lines=1,
75
  max_lines=1,
76
  interactive=True,
77
  )
78
+ upload_name = gr.Textbox(
79
+ label="Your Dataset Name (e.g. rishiraj/no_robots)",
80
  value="",
81
  lines=1,
82
  max_lines=1,
83
  interactive=True,
84
  )
85
+ submit = gr.Button(value="Apply Template & Push")
86
  op = gr.Markdown(interactive=False)
87
+ submit.click(reformat, inputs=[dataset_name, train_split, test_split, model_name, upload_name, token], outputs=[op])
88
 
89
 
90
  if __name__ == "__main__":