File size: 2,602 Bytes
5954dba
 
 
 
 
0386732
5954dba
 
 
 
0386732
 
 
 
 
 
 
 
 
5954dba
 
 
 
0386732
5954dba
0386732
 
 
5954dba
 
 
 
0386732
5954dba
 
 
 
 
 
 
 
 
0386732
5954dba
 
 
 
0386732
 
 
 
 
 
 
 
5954dba
 
0386732
 
5954dba
 
 
 
 
 
 
 
 
 
0386732
5954dba
 
 
 
0386732
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import Any, Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


# Default persona used as the system message when a request does not
# provide one of its own.
SYSTEM_PROMPT = (
    "You are Axis, a private personal AI assistant. "
    "You are direct, efficient, and no-nonsense. "
    "You handle emails, manage calendars, remember everything users tell you, "
    "search the web, generate images, and answer questions. "
    "Your responses are concise and helpful. "
    "You never pretend to be human. "
    "Privacy is your core value."
)


class EndpointHandler:
    """Serve a causal language model as a chat assistant.

    The tokenizer, model and text-generation pipeline are loaded once from
    ``path`` at construction time; each call answers one request payload
    with a single generated reply.
    """

    def __init__(self, path=""):
        """Load the tokenizer, model and text-generation pipeline.

        Args:
            path: Model location passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
            use_fast=True,
        )

        # Some tokenizers ship without a pad token; reuse EOS so generation
        # always has one available.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

        self.model.eval()

        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate one assistant reply for a request payload.

        Args:
            data: Request body. ``data["inputs"]`` is either a list of chat
                messages (presumably ``{"role": ..., "content": ...}`` dicts,
                as accepted by the tokenizer's chat template) or any other
                value, which is stringified and sent as a single user turn.
                ``data["parameters"]`` optionally overrides generation
                settings; see ``_generation_kwargs``.

        Returns:
            ``[{"generated_text": reply}]``, where ``reply`` is only the newly
            generated text (``return_full_text=False``) with surrounding
            whitespace stripped.
        """
        inputs = data.get("inputs", "")
        # A JSON body with ``"parameters": null`` means "no overrides", not a crash.
        parameters = data.get("parameters") or {}

        messages = self._build_messages(inputs)
        prompt = self._build_prompt(messages)

        output = self.pipeline(
            prompt,
            return_full_text=False,
            pad_token_id=self.tokenizer.eos_token_id,
            **self._generation_kwargs(parameters),
        )

        response_text = output[0]["generated_text"].strip()

        return [{"generated_text": response_text}]

    @staticmethod
    def _build_messages(inputs: Any) -> List[Dict[str, Any]]:
        """Normalize ``inputs`` to a message list, prepending the default system turn if absent."""
        if isinstance(inputs, list):
            messages = inputs
        else:
            messages = [{"role": "user", "content": str(inputs)}]

        if not any(m.get("role") == "system" for m in messages):
            # Build a new list so the caller's ``inputs`` list is never mutated.
            messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages

        return messages

    def _build_prompt(self, messages: List[Dict[str, Any]]) -> str:
        """Render ``messages`` with the tokenizer's chat template, or a plain-text fallback."""
        try:
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            # Deliberately broad: a missing or incompatible chat template can
            # fail with different exception types depending on the tokenizer
            # and transformers version, and a plain transcript still works.
            return self._render_fallback_prompt(messages)

    @staticmethod
    def _render_fallback_prompt(messages: List[Dict[str, Any]]) -> str:
        """Render ``messages`` as a ``Role: content`` transcript ending in ``Assistant:``.

        System messages form a preamble separated from the dialogue by a blank
        line, so ``[system S, user X]`` renders as ``"S\\n\\nUser: X\\nAssistant:"``.
        That is the same text the single-string request path always produced.
        """
        system_parts = []
        turns = []
        for message in messages:
            role = str(message.get("role", "user"))
            content = str(message.get("content", ""))
            if role == "system":
                system_parts.append(content)
            else:
                turns.append(f"{role.capitalize()}: {content}")
        turns.append("Assistant:")

        dialogue = "\n".join(turns)
        if system_parts:
            return "\n".join(system_parts) + "\n\n" + dialogue
        return dialogue

    @staticmethod
    def _generation_kwargs(parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Translate request ``parameters`` into generation keyword arguments.

        Recognized keys (a missing key or ``None`` value uses the default):
        ``max_new_tokens`` (512), ``repetition_penalty`` (1.1),
        ``temperature`` (0.7), ``top_p`` (0.9), ``do_sample`` (True).

        A non-positive ``temperature`` or ``do_sample=False`` selects greedy
        decoding. In that case ``temperature``/``top_p`` are omitted, because
        sampling requires a strictly positive temperature.
        """

        def pick(key, default, cast):
            value = parameters.get(key)
            return default if value is None else cast(value)

        kwargs: Dict[str, Any] = {
            "max_new_tokens": pick("max_new_tokens", 512, int),
            "repetition_penalty": pick("repetition_penalty", 1.1, float),
        }

        temperature = pick("temperature", 0.7, float)
        if pick("do_sample", True, bool) and temperature > 0:
            kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_p=pick("top_p", 0.9, float),
            )
        else:
            kwargs["do_sample"] = False

        return kwargs