tykiww commited on
Commit
426e1d8
·
verified ·
1 Parent(s): ce6a328

Create model.py

Browse files
Files changed (1) hide show
  1. connections/model.py +36 -0
connections/model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unsloth import FastLanguageModel
2
+
3
class InferencePipeline:
    """Wrap an Unsloth ``FastLanguageModel`` for simple prompt inference.

    Parameters
    ----------
    conf : dict
        Configuration mapping with a ``"model"`` section supplying
        ``model_name``, ``max_seq_length``, ``dtype``, ``load_in_4bit``
        and ``max_new_tokens``.
    api_key : str
        Hugging Face access token used when downloading the model.
    """

    def __init__(self, conf, api_key):
        self.conf = conf
        self.token = api_key
        # Load eagerly so the pipeline is ready to serve infer() calls.
        self.model, self.tokenizer = self.get_model()

    def get_model(self):
        """Load the pretrained model and tokenizer described by the config.

        Returns
        -------
        tuple
            ``(model, tokenizer)`` with the model switched to inference mode.
        """
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.conf["model"]["model_name"],
            max_seq_length=self.conf["model"]["max_seq_length"],
            dtype=self.conf["model"]["dtype"],
            load_in_4bit=self.conf["model"]["load_in_4bit"],
            token=self.token,
        )
        FastLanguageModel.for_inference(model)  # Enable native 2x faster inference
        return model, tokenizer

    def infer(self, prompt):
        """Generate text for *prompt* and return the decoded output batch.

        NOTE(review): moves inputs to ``"cuda"`` unconditionally — assumes a
        CUDA device is available; confirm before running on CPU-only hosts.
        """
        inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
        # BUG FIX: the original called the bare names ``model`` and
        # ``tokenizer`` (locals of get_model, not in scope here), which
        # raised NameError; use the instance attributes instead.
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.conf["model"]["max_new_tokens"],
            use_cache=True,
        )
        return self.tokenizer.batch_decode(outputs)
28
+
29
+
30
# Example usage (the constructor takes only conf and api_key; the prompt is
# supplied to infer(), not to __init__):
#
# pipeline = InferencePipeline(conf, api_key=keys["huggingface"])
# outputs = pipeline.infer(prompt)