crscardellino
committed on
Commit b5e125d
Parent(s): 2dac16e
Basic chatbot model
model.py
ADDED
@@ -0,0 +1,131 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional, Union


class ChatBot:
    """
    Chatbot based on the notebook [How to generate text: using different
    decoding methods for language generation with
    Transformers](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)
    and the blog post [Create conversational agents using BLOOM:
    Part-1](https://medium.com/@fractal.ai/create-conversational-agents-using-bloom-part-1-63a66e6321c0).

    This code needs testing; it is not fit for production use.

    It's a very basic chatbot that uses causal language models from
    Transformers, given a PROMPT.

    An example of a basic PROMPT is given by the BASE_PROMPT attribute.

    Parameters
    ----------
    base_model : str | AutoModelForCausalLM
        A name (path in the Hugging Face Hub) for a model, or the model itself.
    tokenizer : AutoTokenizer | None
        Required when base_model is a model instance; otherwise the tokenizer
        is loaded from the same base_model path.
    initial_prompt : str
        A prompt for the model. Should follow the example given in `BASE_PROMPT`.
    keep_context : bool
        Whether to accumulate the context as the chatbot is used.
    creative : bool
        Whether to generate text through sampling (with some very basic
        configuration) or through greedy decoding. Check the notebook "How to
        generate text" (link above) for more information.
    max_tokens : int
        Max number of tokens to generate in the chat.
    human_identifier : str
        The string that identifies the human speaker in the prompt (e.g. HUMAN).
    bot_identifier : str
        The string that identifies the bot speaker in the prompt (e.g. EXPERT).
    """

    BASE_PROMPT = """
The following is a conversation with a movie EXPERT.
The EXPERT helps the HUMAN define their personal preferences and provides
multiple options to select from; it also helps in selecting the best option.
The EXPERT is conversational, optimistic, flexible, empathetic, creative and
human-like in generating responses.

HUMAN: Hello, how are you?
EXPERT: Fine, thanks. I am here to help you by recommending movies.
""".strip()


    def __init__(self,
                 base_model: Union[str, AutoModelForCausalLM],
                 tokenizer: Optional[AutoTokenizer] = None,
                 initial_prompt: Optional[str] = None,
                 keep_context: bool = False,
                 creative: bool = False,
                 max_tokens: int = 50,
                 human_identifier: str = "HUMAN",
                 bot_identifier: str = "EXPERT"):
        if isinstance(base_model, str):
            # Load both model and tokenizer from the Hub (or a local path)
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model,
                low_cpu_mem_usage=True,
                torch_dtype='auto'
            )
            self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        else:
            # A preloaded model must come with its tokenizer
            assert tokenizer is not None, \
                "If the base model is given, the tokenizer should be given as well"
            self.model = base_model
            self.tokenizer = tokenizer

        self.initial_prompt = initial_prompt if initial_prompt is not None else self.BASE_PROMPT
        self.keep_context = keep_context
        self.context = ''
        self.creative = creative
        self.max_tokens = max_tokens
        self.human_identifier = human_identifier
        self.bot_identifier = bot_identifier

    def chat(self, input_text):
        """
        Generates a response from the prompt (and optionally the context),
        appending `input_text` as if it were part of the HUMAN dialog
        (identified by `self.human_identifier`), and prompting the bot
        (identified by `self.bot_identifier`) for a response. As the bot might
        continue the conversation beyond its turn, the output is trimmed so it
        only shows the first dialog given by the bot, following the idea
        presented in the Medium blog post for creating conversational agents
        (link above).

        Parameters
        ----------
        input_text : str
            The question asked/phrase prompted by a human.

        Returns
        -------
        str
            The output given by the bot, trimmed for better control.
        """
        # The initial prompt is stripped, so a newline must separate it from
        # the first HUMAN turn
        prompt = self.initial_prompt + '\n' + self.context
        prompt += f'{self.human_identifier}: {input_text}\n'
        prompt += f'{self.bot_identifier}: '

        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
        if self.creative:
            # Sampling with top-k and nucleus (top-p) filtering
            output = self.model.generate(
                input_ids,
                do_sample=True,
                max_length=input_ids.shape[1] + self.max_tokens,
                top_k=50,
                top_p=0.95,
                num_return_sequences=1
            )[0]
        else:
            # Plain greedy decoding
            output = self.model.generate(
                input_ids,
                max_length=input_ids.shape[1] + self.max_tokens
            )[0]

        decoded_output = self.tokenizer.decode(output, skip_special_tokens=True)
        trimmed_output = decoded_output[len(prompt):]
        # Keep only the bot's first turn; if the model never starts another
        # HUMAN turn, `find` returns -1 and the output is kept whole
        cut = trimmed_output.find(f'{self.human_identifier}:')
        if cut != -1:
            trimmed_output = trimmed_output[:cut]

        if self.keep_context:
            # Record the full exchange so follow-up prompts stay well formed
            self.context += f'{self.human_identifier}: {input_text}\n'
            self.context += f'{self.bot_identifier}: {trimmed_output.strip()}\n'

        return trimmed_output.strip()
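
For reference, a minimal usage sketch under stated assumptions: `distilgpt2` is only an example checkpoint (any causal LM on the Hub should work), and the questions are placeholders.

from model import ChatBot

# Load model and tokenizer by Hub name; keep_context accumulates the dialog
bot = ChatBot('distilgpt2', keep_context=True, max_tokens=50)

print(bot.chat('Can you recommend a science fiction movie?'))
# With keep_context=True the previous exchange stays in the prompt,
# so follow-up questions can refer back to it
print(bot.chat('Is it suitable for kids?'))

With creative=True the same calls sample with top-k/top-p filtering instead of decoding greedily, so repeated runs can give different recommendations.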