tykiww commited on
Commit
471d3b8
·
verified ·
1 Parent(s): 92daee4

Update connections/model_test.py

Browse files
Files changed (1) hide show
  1. connections/model_test.py +33 -0
connections/model_test.py CHANGED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #import torch
2
+ #from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3
+ #from peft import PeftConfig, PeftModel
4
+
5
+
6
+ from peft import AutoPeftModelForCausalLM
7
+ from transformers import AutoTokenizer
8
+
9
+
10
+
11
+ class InferencePipeline:
12
+ def __init__(self, conf, api_key):
13
+ self.conf = conf
14
+ self.token = api_key
15
+ self.model, self.tokenizer = self.get_model()
16
+
17
+ def get_model(self):
18
+
19
+ model = AutoPeftModelForCausalLM.from_pretrained(
20
+ self.conf["model"]["model_name"],
21
+ load_in_4bit = self.conf["model"]["load_in_4bit"],
22
+ )
23
+ tokenizer = AutoTokenizer.from_pretrained(self.path)
24
+
25
+ return model, tokenizer
26
+
27
+ def infer(self, prompt):
28
+ inputs = self.tokenizer([prompt], return_tensors = "pt").to("cuda")
29
+ outputs = model.generate(**inputs,
30
+ max_new_tokens = self.conf["model"]["max_new_tokens"],
31
+ use_cache = True)
32
+ outputs = tokenizer.batch_decode(outputs)
33
+ return outputs