from typing import Any, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging



# Root logger at DEBUG so the handler's messages show up in the endpoint logs
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the tokenizer and model from the repository path handed in by the endpoint
        logger.info("Loading model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(path or ".")
        logger.info("Tokenizer loaded.")

        # Some causal LM tokenizers ship without a pad token; fall back to EOS so padding works
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(path or ".")
        logger.info("Model loaded.")

        # Use a GPU when available and switch the model to inference mode
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data: JSON input with structure:
            {
                "inputs": "your text prompt here",
                "parameters": {
                    "max_new_tokens": 50,
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "do_sample": true
                }
            }

        Returns:
            A dict with "generated_text" (the first generated sequence) and
            "all_generations" (every sequence when num_return_sequences > 1).
        """
        # Get the input text and generation parameters from the request payload
        inputs = data.pop("inputs", data)
        logger.info("Received inputs: %s", inputs)
        parameters = data.pop("parameters", {})

        # Default generation parameters
        generation_config = {
            "max_new_tokens": parameters.get("max_new_tokens", 50),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
            "pad_token_id": self.tokenizer.eos_token_id,
            "num_return_sequences": parameters.get("num_return_sequences", 1)
        }

        # Tokenize the prompt and move the tensors to the model's device
        model_inputs = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)

        # Generate text
        with torch.no_grad():
            generated_ids = self.model.generate(
                model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
                **generation_config
            )

        # Decode and return generated text
        generated_texts = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )

        return {
            "generated_text": generated_texts[0],  # Return first generation if multiple
            "all_generations": generated_texts  # All generations if num_return_sequences > 1
        }
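

# A minimal local smoke test; not part of the Inference Endpoints contract.
# It assumes this handler sits in the model repository root (as it does when
# deployed), and the prompt and parameters below are purely illustrative.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    result = handler(
        {
            "inputs": "Once upon a time",
            "parameters": {"max_new_tokens": 20, "temperature": 0.7},
        }
    )
    print(result["generated_text"])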