import re
from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class EndpointHandler:
    """Inference handler that wraps a causal language model and its tokenizer.

    Original author's note: custom handler for the Jingzong/APAN5560
    fine-tuned GPT-2 model, matching the training/inference format from
    GPT2RoleplayModel.

    A request is turned into a ``"User: ...\\nAssistant:"`` prompt, the model
    samples a continuation, and the continuation is cleaned (marker strings
    removed, quotes stripped, shortened) before being returned.
    """

    # Marker strings removed verbatim from generated text.
    _SPECIAL_TOKENS = (
        "<s>", "</s>",
        "<|user|>", "<|assistant|>",
        "<user>", "</user>",
        "<assistant>", "</assistant>",
        "<sub>", "</sub>",
        "<|endoftext|>",
    )

    # Compiled once at class creation instead of on every request.
    _WHITESPACE_RE = re.compile(r"\s+")
    _SENTENCE_BOUNDARY_RE = re.compile(r"(?<=[.!?])\s+")
    _ELLIPSIS = "..."

    def __init__(self, path=""):
        """Load the tokenizer and model found at ``path`` and switch to eval mode.

        Args:
            path: Location passed unchanged to both ``from_pretrained`` calls.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.model.eval()

        # ``generate`` is given a pad token id below; reuse EOS when the
        # tokenizer does not define one.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    @staticmethod
    def _strip_special_tokens(text: str) -> str:
        """Return ``text`` with every known marker string removed."""
        for token in EndpointHandler._SPECIAL_TOKENS:
            text = text.replace(token, "")
        return text

    @staticmethod
    def _shorten(text: str, max_chars: int = 220) -> str:
        """Keep at most two sentences and never exceed ``max_chars`` characters.

        Whitespace runs (including newlines) are collapsed to single spaces.
        If the first two sentences are still too long, the text is cut at a
        word boundary and ``"..."`` is appended; the ellipsis counts towards
        the limit.

        Args:
            text: Text to shorten.
            max_chars: Maximum length of the returned string.

        Returns:
            The shortened text, with ``len(result) <= max_chars`` for any
            non-negative ``max_chars``.
        """
        text = EndpointHandler._WHITESPACE_RE.sub(" ", text).strip()

        # re.split always yields at least one element, so no empty-list guard.
        sentences = EndpointHandler._SENTENCE_BOUNDARY_RE.split(text)
        short = " ".join(sentences[:2])
        if len(short) <= max_chars:
            return short

        # Reserve room for the ellipsis so the limit really is a hard one.
        budget = max_chars - len(EndpointHandler._ELLIPSIS)
        if budget <= 0:
            # Too small to fit an ellipsis: plain truncation.
            return short[:max(max_chars, 0)]

        # Drop the trailing (possibly partial) word rather than cut mid-word.
        clipped = short[:budget].rsplit(" ", 1)[0].rstrip()
        return clipped + EndpointHandler._ELLIPSIS

    def _clean_answer(self, raw_answer: str) -> str:
        """Strip markers and surrounding quotes from ``raw_answer``, then shorten it."""
        text = self._strip_special_tokens(raw_answer)
        text = text.strip().strip('"').strip("'")
        return self._shorten(text)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Process one inference request.

        Expected input format:
            {
                "inputs": "Hello, how are you?",
                "parameters": {
                    "max_new_tokens": 40,
                    "temperature": 0.8,
                    "top_p": 0.9
                }
            }

        ``parameters`` may be missing or ``None``; ``repetition_penalty`` is
        also honoured (default 1.1).

        Args:
            data: Request payload as described above.

        Returns:
            A one-element list ``[{"generated_text": <cleaned reply>}]``.
        """
        inputs = data.get("inputs", "")
        # ``"parameters": null`` arrives as None; treat it like a missing key.
        parameters = data.get("parameters") or {}

        max_new_tokens = parameters.get("max_new_tokens", 40)
        temperature = parameters.get("temperature", 0.8)
        top_p = parameters.get("top_p", 0.9)
        repetition_penalty = parameters.get("repetition_penalty", 1.1)

        prompt = f"User: {inputs}\nAssistant:"

        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )
        prompt_token_count = encoded["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Separate prompt from reply at the token level. Slicing the decoded
        # string by ``len(prompt)`` is fragile because decoding is not
        # guaranteed to reproduce the prompt text character-for-character.
        generated_ids = outputs[0][prompt_token_count:]
        raw_answer = self.tokenizer.decode(generated_ids, skip_special_tokens=False)

        return [{"generated_text": self._clean_answer(raw_answer)}]