import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
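
# Illustrative request/response shape for this handler (an assumption inferred from the
# logic below, not taken from any endpoint documentation; the sample headline is made up):
#
#   request:  {"inputs": "Acme Corp shares fell 12% after earnings.", "parameters": {"max_new_tokens": 50}}
#   response: [{"sentiment": "Negative"}]
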
class EndpointHandler:
    def __init__(self, path=""):
        """
        Initializes the model and tokenizer.
        """
        model_name = "patrikpavlov/llama-finance-sentiment"

        # Quantize the weights to 4-bit NF4 to reduce memory usage
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token  # Important for generation

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )

    def __call__(self, data: dict) -> list:
        """
        Handles an incoming request, runs inference, and returns the response.
        """
        inputs = data.pop("inputs", "")
        if not inputs:
            return [{"error": "Input 'inputs' is required."}]

        parameters = data.pop("parameters", {"max_new_tokens": 50})

        # --- NEW PROMPT STRATEGY ---
        # We give a very specific instruction and a schema for the model to follow.
        # This is a more reliable way to get JSON output than the 'response_format' parameter.
        prompt = f"""
Analyze the sentiment of the financial news text provided. You must respond with only a valid JSON object. Do not add any other text before or after the JSON.
The JSON object must follow this exact schema:
{{
"sentiment": "string"
}}
The value for "sentiment" must be one of the following three strings: "Positive", "Negative", or "Neutral".
Here is the financial news text to analyze:
---
{inputs}
---
"""

        messages = [
            {"role": "user", "content": prompt}
        ]

        chat_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        input_ids = self.tokenizer(
            chat_prompt,
            return_tensors="pt"
        ).input_ids.to(self.model.device)

        with torch.no_grad():
            # Generate the text without the failing 'response_format' argument
            output_tokens = self.model.generate(
                input_ids,
                **parameters
            )

        newly_generated_tokens = output_tokens[0][len(input_ids[0]):]
        generated_text = self.tokenizer.decode(
            newly_generated_tokens,
            skip_special_tokens=True
        )
        # Clean up and parse the generated text to find the JSON
        try:
            # Find the start and end of the JSON object
            json_start = generated_text.find('{')
            json_end = generated_text.rfind('}') + 1
            # rfind() returns -1 when '}' is missing, so json_end would be 0 in that case
            if json_start != -1 and json_end > json_start:
                json_string = generated_text[json_start:json_end]
                json_output = json.loads(json_string)
                return [json_output]
            else:
                raise ValueError("No JSON object found in the output.")
        except (json.JSONDecodeError, ValueError) as e:
            return [{"error": f"Failed to parse JSON from model output: {e}", "raw_output": generated_text}]