File size: 3,455 Bytes
a41011b
 
bbee5ad
a41011b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d4f69c
bbee5ad
a41011b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d4f69c
bbee5ad
8d4f69c
 
 
bbee5ad
8d4f69c
bbee5ad
8d4f69c
 
 
 
 
 
 
 
 
bbee5ad
8d4f69c
bbee5ad
 
f896f78
bbee5ad
f896f78
bbee5ad
 
f896f78
 
 
 
 
 
bbee5ad
f896f78
 
a41011b
 
8d4f69c
bbee5ad
 
8d4f69c
bbee5ad
a41011b
f896f78
 
 
 
 
a41011b
8d4f69c
bbee5ad
8d4f69c
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import json

class EndpointHandler:
    """Inference-endpoint handler that asks a 4-bit quantized causal LM for the
    sentiment of a financial news text and returns the model's JSON answer.
    """

    DEFAULT_MODEL_NAME = "patrikpavlov/llama-finance-sentiment"
    # Used when the request does not set ``max_new_tokens`` itself.
    DEFAULT_MAX_NEW_TOKENS = 50

    def __init__(self, path="", model_name=DEFAULT_MODEL_NAME):
        """
        Initializes the model and tokenizer.

        Args:
            path: Path supplied by the serving runtime. Not used here; the
                model is loaded from ``model_name``.
            model_name: Hub id (or local path) of the model/tokenizer to load.
        """
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Reuse EOS as the pad token so padding/attention masks are defined for generation.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # NOTE(review): trust_remote_code=True executes Python code shipped with
        # the model repository; keep it only if the repo really needs custom code.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )

    def __call__(self, data: dict) -> list:
        """
        Handles an incoming request, runs inference, and returns the response.

        Args:
            data: Request payload. ``data["inputs"]`` is the news text to
                classify; optional ``data["parameters"]`` is a dict of keyword
                arguments forwarded to ``model.generate`` (``max_new_tokens``
                defaults to 50 when absent). ``data`` itself is not modified.

        Returns:
            A one-element list containing either the JSON object parsed from
            the model's output, or an ``{"error": ...}`` dict (with
            ``raw_output`` when the model's text could not be parsed).
        """
        inputs = data.get("inputs", "")
        if not inputs:
            return [{"error": "Input 'inputs' is required."}]

        parameters = data.get("parameters") or {}
        if not isinstance(parameters, dict):
            return [{"error": "Input 'parameters' must be an object."}]

        generated_text = self._generate(self._build_prompt(inputs), parameters)
        return self._parse_json_response(generated_text)

    @staticmethod
    def _build_prompt(inputs) -> str:
        """Return the instruction prompt asking for a JSON sentiment verdict on *inputs*."""
        # We give a very specific instruction and a schema for the model to follow.
        # This is a more reliable way to get JSON output than the 'response_format' parameter.
        # NOTE: the literal's leading indentation is sent to the model verbatim; it is
        # kept unchanged so the prompt stays exactly what the endpoint has been serving.
        return f"""
        Analyze the sentiment of the financial news text provided. You must respond with only a valid JSON object. Do not add any other text before or after the JSON.

        The JSON object must follow this exact schema:
        {{
          "sentiment": "string"
        }}

        The value for "sentiment" must be one of the following three strings: "Positive", "Negative", or "Neutral".

        Here is the financial news text to analyze:
        ---
        {inputs}
        ---
        """

    def _generate(self, prompt: str, parameters: dict) -> str:
        """Run *prompt* as a single user chat turn and return only the newly generated text.

        Args:
            prompt: Content of the user message.
            parameters: Extra ``generate`` keyword arguments; they override the
                defaults below and are not mutated.
        """
        messages = [{"role": "user", "content": prompt}]
        chat_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # If the rendered template already begins with BOS, letting the tokenizer
        # add special tokens again would produce a duplicated BOS.
        bos_token = self.tokenizer.bos_token
        add_special_tokens = not (bos_token and chat_prompt.startswith(bos_token))
        encoded = self.tokenizer(
            chat_prompt,
            return_tensors="pt",
            add_special_tokens=add_special_tokens,
        ).to(self.model.device)
        input_ids = encoded["input_ids"]

        generation_kwargs = {
            "max_new_tokens": self.DEFAULT_MAX_NEW_TOKENS,
            # pad == eos (see __init__); pass it explicitly together with the
            # attention mask so generate() does not have to guess either.
            "pad_token_id": self.tokenizer.pad_token_id,
            **parameters,
        }

        with torch.no_grad():
            output_tokens = self.model.generate(
                input_ids=input_ids,
                attention_mask=encoded.get("attention_mask"),
                **generation_kwargs,
            )

        # generate() returns prompt + continuation; keep only the continuation.
        newly_generated_tokens = output_tokens[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(newly_generated_tokens, skip_special_tokens=True)

    @staticmethod
    def _parse_json_response(generated_text: str) -> list:
        """Extract the first JSON object from *generated_text*.

        Returns:
            ``[obj]`` on success, otherwise
            ``[{"error": ..., "raw_output": generated_text}]``.
        """
        json_start = generated_text.find("{")
        if json_start == -1:
            error = ValueError("No JSON object found in the output.")
        else:
            try:
                # raw_decode parses one complete JSON value and ignores whatever
                # follows it, so trailing chatter (even containing '}') is harmless.
                json_output, _ = json.JSONDecoder().raw_decode(generated_text[json_start:])
                return [json_output]
            except json.JSONDecodeError as exc:
                error = exc
        return [{"error": f"Failed to parse JSON from model output: {error}", "raw_output": generated_text}]