from typing import Dict, Any
import logging

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel
import torch


LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
device = "cuda" if torch.cuda.is_available() else "cpu"


class EndpointHandler:
    def __init__(self, path=""):
        # Resolve the base model from the PEFT adapter config stored at `path`.
        config = PeftConfig.from_pretrained(path)
        # Load the base model in 8-bit (requires bitsandbytes) and shard it across available devices.
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Attach the LoRA adapter weights on top of the base model.
        self.model = PeftModel.from_pretrained(model, path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data (Dict): The payload with the text prompt and generation parameters.

        Returns:
            Dict: The generated text under the "generated_text" key.
        """
        LOGGER.info(f"Received data: {data}")
        # Get inputs
        prompt = data.pop("inputs", None)
        parameters = data.pop("parameters", None)
        if prompt is None:
            raise ValueError('Missing "inputs" key in the request payload.')
        # Preprocess
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        # Forward
        LOGGER.info(f"Start generation.")
        if parameters is not None:
            output = self.model.generate(input_ids=input_ids, **parameters)
        else:
            output = self.model.generate(input_ids=input_ids)
        # Postprocess
        prediction = self.tokenizer.decode(output[0])
        LOGGER.info(f"Generated text: {prediction}")
        return {"generated_text": prediction}