MrD05 commited on
Commit
897e575
1 Parent(s): fa97f57

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +68 -0
handler.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
2
+ from langchain.llms import HuggingFacePipeline
3
+ from langchain import PromptTemplate, LLMChain
4
+
5
# Prompt template filled per-request by EndpointHandler.__call__ (via LLMChain).
# <START>/<END> bracket the persona/history section; the trailing
# "{char_name}: " primes the model to answer in character.
# NOTE(review): looks like a Pygmalion-style chat format — confirm against the
# target model's card before changing the layout.
template = """{char_name}'s Persona: {char_persona}
<START>
{chat_history}
{char_name}: {char_greeting}
<END>
{user_name}: {user_input}
{char_name}: """
12
+
13
class EndpointHandler():
    """Inference-endpoint handler: loads a causal LM and serves one-turn
    character-chat completions through a LangChain LLMChain."""

    def __init__(self, path=""):
        """Load tokenizer + model from `path` and build the prompt chain.

        Args:
            path: local directory / hub id passed to `from_pretrained`.
        """
        tokenizer = AutoTokenizer.from_pretrained(path)
        # 8-bit quantized load, sharded across available devices.
        model = AutoModelForCausalLM.from_pretrained(
            path, load_in_8bit=True, device_map="auto"
        )
        local_llm = HuggingFacePipeline(
            pipeline=pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_length=2048,
                # FIX: without do_sample=True the pipeline decodes greedily and
                # temperature / top_p / top_k below are silently ignored.
                do_sample=True,
                temperature=0.5,
                top_p=0.9,
                top_k=0,
                repetition_penalty=1.1,
                # 50256 is the GPT-2 EOS id — assumes a GPT-2-style vocab;
                # TODO(review): confirm for the actual model at `path`.
                pad_token_id=50256,
                num_return_sequences=1,
            )
        )
        prompt_template = PromptTemplate(
            template=template,
            input_variables=[
                "user_input",
                "user_name",
                "char_name",
                "char_persona",
                "char_greeting",
                "chat_history",
            ],
            validate_template=True,
        )
        self.llm_engine = LLMChain(llm=local_llm, prompt=prompt_template)

    def __call__(self, data):
        """Handle one request.

        Args:
            data: request payload; the template variables are read from
                data["inputs"] (or from `data` itself if no "inputs" key).

        Returns:
            {"inputs": ..., "text": first line of the generation} on success,
            {"inputs": ..., "error": str(exc)} on any failure.
        """
        # NOTE: pop() mutates the caller's dict when "inputs" is present.
        inputs = data.pop("inputs", data)
        try:
            response = self.llm_engine.predict(
                user_input=inputs["user_input"],
                user_name=inputs["user_name"],
                char_name=inputs["char_name"],
                char_persona=inputs["char_persona"],
                char_greeting=inputs["char_greeting"],
                chat_history=inputs["chat_history"],
            ).split("\n", 1)[0]  # keep only the character's first reply line
            return {
                "inputs": inputs,
                "text": response,
            }
        except Exception as e:
            # Serving boundary: report the failure to the caller instead of
            # raising (missing keys surface here as KeyError text).
            return {
                "inputs": inputs,
                "error": str(e),
            }