mavilov committed
Commit ba050cc · 1 Parent(s): d3f71c9

Rewrite locally

Files changed (2):
  1. app.py +35 -15
  2. requirements.txt +8 -7
app.py CHANGED
@@ -1,8 +1,9 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, DataCollatorForSeq2Seq
+from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
 from trl import SFTTrainer, SFTConfig
 from datasets import load_dataset
 from peft import LoraConfig, get_peft_model
 import torch
+import os

 # -------------------------
 # Load dataset
@@ -13,37 +14,56 @@ dataset = load_dataset("mavilov/convos", split="train")
 # Load model and tokenizer
 # -------------------------
 model_id = "swiss-ai/Apertus-8B-2509"
+model_kwargs = {}

-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_use_double_quant=True,
-    bnb_4bit_compute_dtype=torch.float16,  # FP16 compute
-)
+if torch.backends.mps.is_available():
+    print("⚡ Using Apple MPS backend (Metal)")
+    model_kwargs = {
+        "dtype": torch.float16,
+        "device_map": {"": "mps"},  # force load directly on MPS
+        "offload_folder": "./offload",
+        "low_cpu_mem_usage": True,  # avoid meta tensors
+    }
+elif torch.cuda.is_available():
+    print("⚡ Using CUDA with bitsandbytes quantization")
+    from transformers import BitsAndBytesConfig
+    bnb_config = BitsAndBytesConfig(
+        load_in_8bit=True,
+        llm_int8_threshold=6.0
+    )
+    model_kwargs["quantization_config"] = bnb_config
+    model_kwargs["device_map"] = "auto"
+else:
+    print("⚠️ No GPU/MPS detected, running on CPU (very slow)")
+    model_kwargs = {
+        "dtype": torch.float32,
+        "device_map": {"": "cpu"},
+        "low_cpu_mem_usage": True,
+    }
+
+# Load tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token

+# Load model safely
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
-    quantization_config=bnb_config,
-    device_map="auto"
+    **model_kwargs
 )
 model.config.use_cache = False
 model.config.pretraining_tp = 1

-tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-tokenizer.pad_token = tokenizer.eos_token
-
 # -------------------------
 # Attach LoRA adapters
 # -------------------------
 lora_config = LoraConfig(
     r=16,
     lora_alpha=32,
-    target_modules=["q_proj", "v_proj"],  # typical attention modules
+    target_modules=["q_proj", "v_proj"],
     lora_dropout=0.05,
     bias="none",
     task_type="CAUSAL_LM"
 )
-
 model = get_peft_model(model, lora_config)

 # -------------------------
@@ -76,7 +96,7 @@ training_args = SFTConfig(
     num_train_epochs=3,
     logging_steps=10,
     report_to="tensorboard",
-    bf16=False,  # disable bf16
+    bf16=False,
 )

 # -------------------------
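
The new loading path branches on the available backend before calling from_pretrained: MPS gets fp16 weights placed directly on the Metal device, CUDA gets 8-bit bitsandbytes quantization, and everything else falls back to full-precision CPU. A minimal standalone sketch of the same selection logic is below; it only builds the kwargs dictionary and prints it, so it can be run without downloading the 8B checkpoint. The helper name select_model_kwargs is illustrative and not part of this commit.

import torch

def select_model_kwargs():
    # Mirror the branching added in app.py: MPS first, then CUDA, then CPU fallback.
    if torch.backends.mps.is_available():
        # Apple Silicon: fp16 weights loaded directly onto the Metal device
        return {"dtype": torch.float16, "device_map": {"": "mps"}, "low_cpu_mem_usage": True}
    if torch.cuda.is_available():
        # NVIDIA GPU: 8-bit quantized weights via bitsandbytes
        from transformers import BitsAndBytesConfig
        return {
            "quantization_config": BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0),
            "device_map": "auto",
        }
    # No accelerator: full-precision CPU (very slow)
    return {"dtype": torch.float32, "device_map": {"": "cpu"}, "low_cpu_mem_usage": True}

if __name__ == "__main__":
    print(select_model_kwargs())
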
requirements.txt CHANGED
@@ -1,8 +1,9 @@
-torch>=2.1.0
-transformers>=4.35.0
-datasets>=2.15.0
+torch
+transformers
+datasets
 accelerate>=0.26.0
-trl>=0.7.0
-bitsandbytes>=0.41.0
-peft>=0.5.0
-tensorboard>=2.15.0
+trl
+bitsandbytes
+peft
+tensorboard
+huggingface_hub
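
With the version pins dropped, pip will resolve the latest release of each package at install time. A small sanity check (not part of the commit) can print what was actually installed:

import importlib.metadata as md

packages = ["torch", "transformers", "datasets", "accelerate", "trl",
            "bitsandbytes", "peft", "tensorboard", "huggingface_hub"]
for name in packages:
    try:
        print(f"{name}=={md.version(name)}")
    except md.PackageNotFoundError:
        print(f"{name} not installed")  # e.g. bitsandbytes may be unavailable on macOS
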