shariar076 committed on
Commit d8188c4 · verified · 1 Parent(s): 73f8750

Create bn_llm_wrapper

Files changed (1)
  1. bn_llm_wrapper +58 -0
bn_llm_wrapper ADDED
@@ -0,0 +1,58 @@
+ import os
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
+
+ model_path = os.environ.get("HF_REPO_ID")
+ access_token = os.environ.get("HF_TOKEN")
+
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, token=access_token)
+
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     # load_in_8bit=use_8_bit,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=getattr(torch, "bfloat16"),
+     bnb_4bit_use_double_quant=True,
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(model_path, token=access_token,
+                                              quantization_config=bnb_config,
+                                              torch_dtype=torch.float16,
+                                              # attn_implementation="flash_attention_2",
+                                              device_map='auto')
+
+ if torch.cuda.is_available():
+     device = "cuda"
+ else:
+     device = "cpu"
+
+ def generate(
+         question,
+         context=None,
+         temperature=0.7,
+         top_p=0.7,
+         top_k=40,
+         num_beams=4,
+         max_new_tokens=256,):
+     prompt = f"### CONTEXT:\n{context}\n\n### QUESTION:\n{question}\n\n### ANSWER:"
+     inputs = tokenizer(prompt, return_tensors="pt")
+     input_ids = inputs["input_ids"].to(device)
+     generation_config = GenerationConfig(
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         num_beams=num_beams,
+     )
+     # with torch.autocast("cuda"):
+     with torch.no_grad():
+         generation_output = model.generate(
+             input_ids=input_ids,
+             generation_config=generation_config,
+             return_dict_in_generate=True,
+             output_scores=True,
+             max_new_tokens=max_new_tokens,
+         )
+     seq = generation_output.sequences[0]
+     output = tokenizer.decode(seq)
+     return output
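For reference, a minimal usage sketch of the wrapper added in this commit. It assumes the file is saved as bn_llm_wrapper.py on the import path and that HF_REPO_ID and HF_TOKEN are set in the environment before the import, since the tokenizer and model are loaded at import time. The repo id, token value, and the question/context strings below are placeholders, not part of the commit.

import os

# Placeholder values; replace with a real Hugging Face repo id and access token
# before importing the wrapper, because the model loads at import time.
os.environ.setdefault("HF_REPO_ID", "user/model-repo")
os.environ.setdefault("HF_TOKEN", "hf_xxx")

import bn_llm_wrapper  # assumes the file is saved as bn_llm_wrapper.py

# Ask a question with optional supporting context; generate() returns the
# decoded output, including the prompt and the text after "### ANSWER:".
answer = bn_llm_wrapper.generate(
    question="What is the capital of Bangladesh?",
    context="Dhaka is the capital and largest city of Bangladesh.",
    max_new_tokens=64,
)
print(answer)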