hackhomer76 committed on
Commit
8ea6169
0 Parent(s):

Add inference endpoint handler

Browse files
Files changed (2) hide show
  1. config.json +11 -0
  2. handler.py +122 -0
config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "task_type": "text-generation",
3
+ "framework": "pytorch",
4
+ "requirements": [
5
+ "torch>=2.0.0",
6
+ "transformers>=4.30.0",
7
+ "opencc>=1.1.1",
8
+ "jieba>=0.42.1"
9
+ ],
10
+ "handler": "handler:EndpointHandler"
11
+ }
handler.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import json
5
+ from typing import Dict, Any
6
+ import numpy as np
7
+ from opencc import OpenCC
8
+ import jieba
9
+ import re
10
+
11
class EndpointHandler:
    """Inference endpoint handler for a ChatGLM3-based roleplay chatbot.

    Wraps tokenizer/model loading, prompt construction, text generation, and
    Traditional-Chinese post-processing for the "Frieren" (芙莉蓮) persona.
    Lifecycle: construct, then call ``initialize`` once before any
    ``inference`` call.
    """

    def __init__(self):
        # Model artifacts are loaded lazily in initialize(); until then the
        # handler cannot serve requests (inference() guards against this).
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Simplified -> Traditional Chinese converter, used in _process_response.
        self.converter = OpenCC('s2t')

    def initialize(self, context):
        """Load the tokenizer and model onto the selected device.

        Args:
            context: Framework-provided context object (unused here, but part
                of the handler contract).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            "THUDM/chatglm3-6b-base",
            trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            "THUDM/chatglm3-6b-base",
            trust_remote_code=True,
            # Half precision only on GPU; fp16 on CPU is slow/poorly supported.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(self.device)
        self.model.eval()

    def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize the incoming request payload.

        Accepts either ``{"inputs": ...}`` or a bare payload. A non-dict
        inputs value is wrapped as ``{"message": value}``.

        Args:
            data: Raw request payload (mutated: "inputs" is popped if present).

        Returns:
            A dict with at least a "message" key expected by inference().
        """
        inputs = data.pop("inputs", data)
        if not isinstance(inputs, dict):
            inputs = {"message": inputs}
        return inputs

    def inference(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a persona response for the given message/context.

        Args:
            inputs: Dict with optional "message" and "context" strings.

        Returns:
            ``{"response": <post-processed reply string>}``

        Raises:
            RuntimeError: If initialize() has not been called yet.
        """
        # Fail fast with a clear error instead of an opaque AttributeError
        # deep inside the tokenizer call when initialize() was skipped.
        if self.model is None or self.tokenizer is None:
            raise RuntimeError(
                "EndpointHandler.initialize() must be called before inference()"
            )

        message = inputs.get("message", "")
        context = inputs.get("context", "")

        # Build the full persona prompt.
        prompt = self._build_prompt(context, message)

        # Tokenize. Named `encoded` (not `inputs`) so the request dict
        # parameter is not shadowed.
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=True,
            truncation=True,
            max_length=2048
        ).to(self.device)

        # Generate the reply.
        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=256,
                # NOTE(review): do_sample=True combined with num_beams=4
                # performs beam-search multinomial sampling, not plain
                # nucleus sampling — confirm this combination is intended.
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.2,
                num_beams=4,
                early_stopping=True
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The prompt ends with "芙莉蓮:", so everything after the last
        # occurrence of that marker is the newly generated reply.
        response = response.split("芙莉蓮:")[-1].strip()

        # Post-process (s2t conversion, whitespace cleanup, final particle).
        response = self._process_response(response)

        return {"response": response}

    def _build_prompt(self, context: str, query: str) -> str:
        """Build the persona prompt embedding retrieved context and the query."""
        return f"""你是芙莉蓮,需要遵守以下規則回答:

1. 身份設定:
- 千年精靈魔法師
- 態度溫柔但帶著些許嘲諷
- 說話優雅且有距離感

2. 重要關係:
- 弗蘭梅是我的師傅
- 費倫是我的學生
- 欣梅爾是我的摯友
- 海塔是我的故友

3. 回答規則:
- 使用繁體中文
- 必須提供具體詳細的內容
- 保持回答的連貫性和完整性

相關資訊:
{context}

用戶:{query}
芙莉蓮:"""

    def _process_response(self, response: str) -> str:
        """Clean up a raw model reply.

        Converts to Traditional Chinese, strips all whitespace, and appends a
        sentence-final particle when the reply lacks terminal punctuation.
        Returns a canned apology for empty/blank replies.
        """
        if not response or not response.strip():
            return "抱歉,我現在有點恍神,請你再問一次好嗎?"

        # Convert Simplified -> Traditional Chinese.
        response = self.converter.convert(response)

        # Remove ALL whitespace: Chinese text does not use spaces, and model
        # output sometimes contains stray spaces/newlines.
        response = re.sub(r'\s+', '', response)
        if not response.endswith(('。', '!', '?', '~', '呢', '啊', '吶')):
            response += '呢。'

        return response

    def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Pass-through post-processing hook required by the handler contract."""
        return data