Ashishkr committed on
Commit 3697a24
1 Parent(s): b608f8b

Update model.py

Files changed (1): model.py +31 -11
model.py CHANGED
@@ -1,25 +1,45 @@
from threading import Thread
from typing import Iterator
-
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
+ import transformers
+ from torch import cuda, bfloat16
+ from peft import PeftModel, PeftConfig

token = os.environ.get("HF_API_TOKEN")

- model_id = 'Ashishkr/llama2_medical_consultation'
+ base_model_id = 'meta-llama/Llama-2-7b-chat-hf'

- from peft import PeftModel, PeftConfig
- from transformers import AutoModelForCausalLM
- from transformers import AutoTokenizer
- import torch
+ device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
+
+ bnb_config = transformers.BitsAndBytesConfig(
+     llm_int8_enable_fp32_cpu_offload=True
+ )
+
+ model_config = transformers.AutoConfig.from_pretrained(
+     base_model_id,
+     use_auth_token=token
+ )
+
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+     base_model_id,
+     trust_remote_code=True,
+     config=model_config,
+     quantization_config=bnb_config,
+     device_map='auto',
+     use_auth_token=token
+ )
+
+ config = PeftConfig.from_pretrained("Ashishkr/llama-2-medical-consultation")
+ model = PeftModel.from_pretrained(model, "Ashishkr/llama-2-medical-consultation").to(device)

- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.eval()

- config = PeftConfig.from_pretrained("Ashishkr/llama2_medical_consultation")
- model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", use_auth_token = token)
- model = PeftModel.from_pretrained(model, "Ashishkr/llama2_medical_consultation").to(device)
- tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_auth_token = token)


def get_prompt(message: str, chat_history: list[tuple[str, str]],
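
For reference, the streaming imports kept at the top of the file (Thread, Iterator, TextIteratorStreamer) pair with the model and tokenizer loaded above. Below is a minimal, illustrative sketch of that pattern, not code from this commit: the run helper name, the prompt handling, and the generation settings (max_new_tokens) are assumptions.

# Illustrative sketch (not part of this commit): stream text from the
# PEFT-wrapped `model` and `tokenizer` defined above.
def run(prompt: str) -> Iterator[str]:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # model.generate blocks, so it runs on a worker thread while the
    # streamer yields decoded text chunks on the calling thread.
    generate_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=256)  # assumed setting
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    yield from streamer

One design note on the commit itself: the BitsAndBytesConfig sets only llm_int8_enable_fp32_cpu_offload=True and does not set load_in_8bit=True or load_in_4bit=True; quantized loading is normally enabled by one of those flags, so as written this config may not actually quantize the weights.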