shiyemin2 committed on
Commit
042afcb
1 Parent(s): 2a79bc9

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +17 -6
model.py CHANGED
@@ -4,14 +4,25 @@ from typing import Iterator
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
 
7
- model_id = 'LinkSoul/Chinese-Llama-2-7b'
 
 
 
8
 
9
  if torch.cuda.is_available():
10
- model = AutoModelForCausalLM.from_pretrained(
11
- model_id,
12
- torch_dtype=torch.float16,
13
- device_map='auto'
14
- )
 
 
 
 
 
 
 
 
15
  else:
16
  model = None
17
  tokenizer = AutoTokenizer.from_pretrained(model_id)
 
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
 
7
+ # Original version
8
+ # model_id = "LinkSoul/Chinese-Llama-2-7b"
9
+ # 4 bit version
10
+ model_id = "LinkSoul/Chinese-Llama-2-7b-4bit"
11
 
12
  if torch.cuda.is_available():
13
+ if model_id.endswith("4bit"):
14
+ model = AutoModelForCausalLM.from_pretrained(
15
+ model_id,
16
+ load_in_4bit=True,
17
+ local_files_only=True,
18
+ torch_dtype=torch.float16
19
+ )
20
+ else:
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ model_id,
23
+ torch_dtype=torch.float16,
24
+ device_map='auto'
25
+ )
26
  else:
27
  model = None
28
  tokenizer = AutoTokenizer.from_pretrained(model_id)