from typing import Dict, List, Any
import transformers
import torch

MAX_TOKENS = 1024

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Build the text-generation pipeline once at startup; the model is
        # loaded in bfloat16 and sharded across the available devices.
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        :param data:
            inputs: messages in the chat format expected by the model
            parameters: optional generation parameters forwarded to the pipeline
        :return: the raw pipeline output (a list of dicts containing "generated_text")
        """
        inputs = data.pop("inputs")

        # Pop "parameters" with a default so a missing key does not raise KeyError,
        # and parenthesize the walrus expression so "is not None" tests the popped
        # value rather than binding a boolean to `parameters`.
        if (parameters := data.pop("parameters", None)) is not None:
            outputs = self.pipeline(inputs, **parameters)
        else:
            outputs = self.pipeline(inputs, max_new_tokens=MAX_TOKENS)

        return outputs
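

# A minimal local smoke test (not part of the original handler): this sketch
# assumes the model weights can be downloaded and that the payload follows the
# Hugging Face Inference Endpoints convention of {"inputs": ..., "parameters": ...}.
# The chat-style message list and generation parameters below are illustrative.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": [
            {"role": "user", "content": "Tell me a short fact about Singapore."}
        ],
        "parameters": {"max_new_tokens": 128, "do_sample": False},
    }
    print(handler(payload))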