lchakkei committed
Commit 08e5d46
Parent(s): f707439

Create llm_for_langchain.py

Files changed (1)
  1. llm_for_langchain.py +56 -0
llm_for_langchain.py ADDED
from typing import Any, List, Optional

import sys

import torch
from langchain.llms.base import LLM
from transformers import AutoTokenizer


class MistralLLM(LLM):
    """LangChain wrapper around a locally loaded Hugging Face causal LM."""

    max_token: int = 4000
    temperature: float = 0.1
    top_p: float = 0.95
    tokenizer: Any = None
    model: Any = None

    def __init__(self, model_name_or_path, bit4=True):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        if not bit4:
            # 8-bit quantized load; bfloat16 applies to the non-quantized modules.
            from transformers import AutoModelForCausalLM
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                load_in_8bit=True,
            )
        else:
            # 4-bit load with double quantization to cut memory use further.
            from transformers import AutoModelForCausalLM, BitsAndBytesConfig
            double_quant_config = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_use_double_quant=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                device_map="auto",
                quantization_config=double_quant_config,
            )
        self.model.eval()

        # torch.compile speeds up inference on PyTorch 2.x; it is unsupported on Windows.
        if torch.__version__ >= "2" and sys.platform != "win32":
            self.model = torch.compile(self.model)

    @property
    def _llm_type(self) -> str:
        return "Mistral"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        print("prompt:", prompt)  # debug: log the incoming prompt
        input_ids = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).input_ids.to("cuda")
        generate_input = {
            "input_ids": input_ids,
            "max_new_tokens": 1024,
            "do_sample": True,
            "top_k": 50,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "repetition_penalty": 1.2,
            "eos_token_id": self.tokenizer.eos_token_id,
            "bos_token_id": self.tokenizer.bos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        generate_ids = self.model.generate(**generate_input)
        # Drop the prompt tokens (and the trailing EOS) before decoding.
        generate_ids = [item[len(input_ids[0]):-1] for item in generate_ids]
        result_message = self.tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return result_message
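For reference, a minimal usage sketch of the wrapper inside a LangChain chain. This is not part of the commit: the checkpoint name and prompt are placeholders, and it assumes a CUDA GPU plus the classic pre-LCEL LangChain API implied by the langchain.llms.base import.

# Usage sketch -- assumptions: a CUDA GPU is available, and the checkpoint
# name below is a placeholder, not necessarily the model this repo targets.
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

from llm_for_langchain import MistralLLM

llm = MistralLLM("mistralai/Mistral-7B-Instruct-v0.2", bit4=True)  # 4-bit load

prompt = PromptTemplate.from_template("Answer briefly: {question}")
chain = LLMChain(llm=llm, prompt=prompt)
print(chain.run(question="What does 4-bit quantization trade off?"))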