import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel


def EndpointHandler(path=""):
    """
    Handler factory for Hugging Face Inference Endpoints.

    Inference Endpoints looks for EndpointHandler in handler.py, calls it
    once with the repository path, and then invokes the returned object
    with each request payload.
    """
    class Handler:
        def __init__(self, path=""):
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.base_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            self.adapter_path = path or "./tinyllama-qa-news"
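
            # Load the fp16 base model, then attach the fine-tuned LoRA
            # adapter weights on top of it with PEFT.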
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_id,
                torch_dtype=torch.float16
            )
            self.model = PeftModel.from_pretrained(base_model, self.adapter_path)
            self.model.to(self.device)
            self.model.eval()
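
            # TinyLlama's tokenizer has no dedicated pad token, so reuse EOS
            # for padding.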
            self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "right"

        def generate_response(self, user_input: str):
            messages = [
                {
                    "role": "system",
                    "content": (
                        "You are an empathetic and understanding emotional-support "
                        "chatbot. You understand and empathize with the user's "
                        "feelings and thoughts, and offer appropriate advice."
                    )
                },
                {"role": "user", "content": user_input}
            ]
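
            # Render the conversation with TinyLlama's chat template; with
            # tokenize=False this returns a single prompt string rather than
            # token ids.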
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
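
            # Nucleus sampling (top-p) with a mild repetition penalty;
            # pad_token_id is set explicitly to avoid the missing-pad-token
            # warning transformers emits during generation.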
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.8,
                    top_p=0.95,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                    num_return_sequences=1
                )

            decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
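
            # generate() returns the prompt followed by the new tokens, so
            # keep only the text after the final assistant marker.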
            if "<|assistant|>" in decoded:
                response = decoded.split("<|assistant|>")[-1].strip()
            else:
                response = decoded.split(prompt)[-1].strip()

            if not response:
                response = ("Thank you for sharing your feelings. "
                            "Is there anything else you would like to talk about?")

            return response

        def __call__(self, data):
            if not data or "inputs" not in data:
                return {"error": "No input data provided."}

            user_input = data["inputs"]
            response = self.generate_response(user_input)
            return {"response": response}

    return Handler(path)
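

# Minimal local smoke test: a sketch that assumes the LoRA adapter files
# exist at ./tinyllama-qa-news (the sample input below is made up).
if __name__ == "__main__":
    handler = EndpointHandler("./tinyllama-qa-news")
    print(handler({"inputs": "I've been feeling anxious about work lately."}))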