from typing import Dict, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = "/repository"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Loading tokenizer from {path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # StarCoder2 FIXES
        # if self.tokenizer.pad_token is None:
        #     self.tokenizer.pad_token = self.tokenizer.eos_token
        # self.tokenizer.padding_side = "left"  # Critical for code completion

        # Basic tokenizer fixes only
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"Loading model from {path} on device: {self.device}...")
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,  # ✅ back to float16 from bfloat16
            trust_remote_code=True,
            device_map="auto",
            low_cpu_mem_usage=True,
            # attn_implementation="flash_attention_2"  # ✅ Faster + stable
        )
        self.model.eval()
        print("✅ Model loaded successfully!")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {}) or {}

        if not isinstance(inputs, str) or not inputs.strip():
            return {"generated_text": ""}

        # # ✅ StarCoder2: Add code context prefix
        # prompt = f"<fim_prefix>{inputs}<fim_suffix><fim_middle>"

        gen_kwargs = {
            "max_new_tokens": min(parameters.get("max_new_tokens", 256), 512),  # Cap for stability
            "temperature": parameters.get("temperature", 0.3),
            "top_p": parameters.get("top_p", 0.95),
            "top_k": parameters.get("top_k", 50),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),  # Slightly higher
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        print(f"Generating with parameters: {gen_kwargs}")
        # print(f"Prompt length: {len(prompt)} | Gen params: {gen_kwargs}")

        # StarCoder2 tokenization
        inputs = inputs.strip()
        tokenized = self.tokenizer(
            # prompt,
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
            padding=True,
        ).to(self.device)

        with torch.no_grad():
            # Generate ONLY new tokens (not full sequence)
            outputs = self.model.generate(
                input_ids=tokenized.input_ids,
                attention_mask=tokenized.attention_mask,
                use_cache=True,
                **gen_kwargs,
            )

        # Extract ONLY newly generated tokens
        new_tokens = outputs[0][len(tokenized.input_ids[0]):]
        generated = self.tokenizer.decode(
            new_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        # generated = generated.replace("<fim_middle>", "").replace("<fim_suffix>", "").strip()

        return {"generated_text": generated.strip()}
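

# Minimal local smoke-test sketch (not part of the endpoint contract). It assumes a model
# checkpoint is available at "/repository" (or whatever local path you substitute) and simply
# calls the handler the way Hugging Face Inference Endpoints would, with an "inputs" string
# and optional "parameters" dict. The prompt and parameter values below are illustrative only.
if __name__ == "__main__":
    handler = EndpointHandler(path="/repository")  # adjust the path for local testing
    payload = {
        "inputs": "def fibonacci(n):",
        "parameters": {"max_new_tokens": 64, "temperature": 0.2},
    }
    result = handler(payload)
    print(result["generated_text"])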