mikeee committed on
Commit
89cb869
1 Parent(s): de222eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -4
app.py CHANGED
@@ -1,18 +1,45 @@
1
  import os
 
 
2
  # os.system("pip install --upgrade torch transformers sentencepiece scipy cpm_kernels accelerate bitsandbytes loguru")
3
- os.system("pip install transformers loguru")
4
 
5
  import gradio as gr
6
  from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
7
 
 
 
 
 
 
 
 
 
 
 
8
  tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b-int4", trust_remote_code=True)
 
 
 
 
9
  logger.debug("load")
10
- model = AutoModel.from_pretrained("THUDM/chatglm2-6b-int4", trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
11
  logger.debug("done load")
 
12
  # tokenizer = AutoTokenizer.from_pretrained("openchat/openchat_v2_w")
13
  # model = AutoModelForCausalLM.from_pretrained("openchat/openchat_v2_w", load_in_8bit_fp32_cpu_offload=True, load_in_8bit=True)
14
- model.half()
15
- model = model.eval()
16
 
17
  model_path = model.config._dict['model_name_or_path']
18
  logger.debug(f"{model_path=}")
 
1
  import os
2
+ import time
3
+
4
  # os.system("pip install --upgrade torch transformers sentencepiece scipy cpm_kernels accelerate bitsandbytes loguru")
5
+ os.system("pip install torch transformers sentencepiece loguru")
6
 
7
  import gradio as gr
8
  from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
9
 
10
+ # fix timezone in Linux
11
+ os.environ["TZ"] = "Asia/Shanghai"
12
+ try:
13
+ time.tzset() # type: ignore # pylint: disable=no-member
14
+ except Exception:
15
+ # Windows
16
+ logger.warning("Windows, cant run time.tzset()")
17
+
18
+ model_name = "THUDM/chatglm2-6b-int4" # 3.9G
19
+
20
  tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b-int4", trust_remote_code=True)
21
+
22
+ has_cuda = torch.cuda.is_available()
23
+ # has_cuda = False # force cpu
24
+
25
  logger.debug("load")
26
+ if has_cuda:
27
+ if model_name.endswith("int4"):
28
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
29
+ else:
30
+ model = (
31
+ AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().half()
32
+ )
33
+ else:
34
+ model = AutoModel.from_pretrained(
35
+ model_name, trust_remote_code=True
36
+ ).half() # .float() .half().float()
37
+
38
+ model = model.eval()
39
  logger.debug("done load")
40
+
41
  # tokenizer = AutoTokenizer.from_pretrained("openchat/openchat_v2_w")
42
  # model = AutoModelForCausalLM.from_pretrained("openchat/openchat_v2_w", load_in_8bit_fp32_cpu_offload=True, load_in_8bit=True)
 
 
43
 
44
  model_path = model.config._dict['model_name_or_path']
45
  logger.debug(f"{model_path=}")