File size: 805 Bytes
3857382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38f32e9
3857382
38f32e9
3857382
 
 
 
 
 
 
be0b5f6
3857382
38f32e9
3857382
590a47c
 
3857382
590a47c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

import transformers
import torch


class InferencePipeline:
    """Thin wrapper around a Hugging Face ``transformers`` text-generation pipeline.

    Stores the configuration dict and auth token, and builds the underlying
    pipeline once at construction time.
    """

    def __init__(self, conf, api_key):
        # conf is a nested dict; this class reads conf["model"]["model_name"],
        # conf["model"]["device_map"], and conf["model"]["max_new_tokens"].
        self.conf = conf
        self.token = api_key
        self.pipeline = self.get_model()

    def get_model(self):
        """Build and return the text-generation pipeline.

        Loads the configured model in bfloat16, placed per the configured
        device map, authenticating with the stored token.
        """
        model_cfg = self.conf["model"]
        return transformers.pipeline(
            "text-generation",
            model=model_cfg["model_name"],
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map=model_cfg["device_map"],
            token=self.token,
        )

    def infer(self, prompt):
        """Generate from *prompt* and return the final message's text.

        NOTE(review): assumes chat-style pipeline output, i.e.
        ``generated_text`` is a list of message dicts each with a
        ``content`` key — confirm against the model in use.
        """
        generations = self.pipeline(
            prompt,
            max_new_tokens=self.conf["model"]["max_new_tokens"],
        )
        last_message = generations[0]["generated_text"][-1]
        return last_message["content"]