tykiww committed on
Commit
3857382
1 Parent(s): bd67bb6

Create model.py

Browse files
Files changed (1) hide show
  1. connections/model.py +31 -0
connections/model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import transformers
3
+ import torch
4
+
5
+
6
class InferencePipeline:
    """Wrap a ``transformers`` text-generation pipeline built from a config.

    The configuration is read through a ``"model"`` section with the keys
    ``"model_name"``, ``"device_map"`` and ``"max_new_tokens"``.
    """

    def __init__(self, conf, api_key):
        """Store the configuration and token, then build the pipeline.

        Args:
            conf: Mapping with a ``"model"`` section (see class docstring).
            api_key: Value forwarded as ``token=`` to ``transformers.pipeline``.
        """
        self.conf = conf
        self.token = api_key
        self.pipeline = self.get_model()

    def get_model(self):
        """Build and return the text-generation pipeline.

        Returns:
            The object returned by ``transformers.pipeline``.
        """
        # Read settings from the instance: there is no module-level ``conf``,
        # so a bare ``conf`` lookup here would raise NameError.
        model_conf = self.conf["model"]

        return transformers.pipeline(
            "text-generation",
            model=model_conf["model_name"],
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map=model_conf["device_map"],
            token=self.token,
        )

    def infer(self, prompt):
        """Run generation for ``prompt`` and return the last generated item.

        Args:
            prompt: Input passed unchanged as the pipeline's first argument.

        Returns:
            ``outputs[0]["generated_text"][-1]`` from the pipeline result.
        """
        # Use the pipeline and config stored on the instance; the bare names
        # ``pipeline`` and ``conf`` are not defined in this scope.
        outputs = self.pipeline(
            prompt,
            max_new_tokens=self.conf["model"]["max_new_tokens"],
        )

        # NOTE(review): presumably ``prompt`` is a chat-style list of messages,
        # so ``[-1]`` picks the newest message. If a plain string is passed,
        # ``generated_text`` would be a string and ``[-1]`` its last character
        # -- confirm against callers.
        return outputs[0]["generated_text"][-1]