jaymojnidar commited on
Commit
b349bb2
1 Parent(s): a0100cd

loading the model in CPU mode

Browse files
Files changed (1) hide show
  1. model.py +12 -4
model.py CHANGED
@@ -3,29 +3,37 @@ from threading import Thread
3
  from typing import Iterator
4
 
5
  import torch
6
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
  from huggingface_hub import login
8
 
9
  model_id = 'jaymojnidar/Llama-2-7b-chat-hf-sharded-bf16-5GBMAX'
10
 
11
  if not torch.cuda.is_available():
12
  tok = os.environ['HF_TOKEN']
 
 
 
 
 
 
 
13
  login(new_session=True,
14
  write_permission=False,
15
  token=tok
16
 
17
  #, token="<REDACTED-HF-TOKEN>"  # SECURITY: real access token removed — never commit credentials; revoke this token on huggingface.co
18
  )
19
-
 
20
  config = AutoConfig.from_pretrained(model_id,
21
  use_auth_token=True)
22
  config.pretraining_tp = 1
23
  model = AutoModelForCausalLM.from_pretrained(
24
  model_id,
25
  config=config,
 
26
  torch_dtype=torch.float16,
27
- load_in_4bit=True,
28
- device_map='auto',
29
  use_auth_token=True
30
  )
31
  else:
 
3
  from typing import Iterator
4
 
5
  import torch
6
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
7
  from huggingface_hub import login
8
 
9
  model_id = 'jaymojnidar/Llama-2-7b-chat-hf-sharded-bf16-5GBMAX'
10
 
11
  if not torch.cuda.is_available():
12
  tok = os.environ['HF_TOKEN']
13
+ device_map = {
14
+ "transformer.word_embeddings": 0,
15
+ "transformer.word_embeddings_layernorm": 0,
16
+ "lm_head": "cpu",
17
+ "transformer.h": 0,
18
+ "transformer.ln_f": 0,
19
+ }
20
  login(new_session=True,
21
  write_permission=False,
22
  token=tok
23
 
24
  #, token="<REDACTED-HF-TOKEN>"  # SECURITY: real access token removed — never commit credentials; revoke this token on huggingface.co
25
  )
26
+ quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
27
+
28
  config = AutoConfig.from_pretrained(model_id,
29
  use_auth_token=True)
30
  config.pretraining_tp = 1
31
  model = AutoModelForCausalLM.from_pretrained(
32
  model_id,
33
  config=config,
34
+ quantization_config=quantization_config,
35
  torch_dtype=torch.float16,
36
+ device_map=device_map,
 
37
  use_auth_token=True
38
  )
39
  else: