crscardellino committed on
Commit b5e125d
1 Parent(s): 2dac16e

Basic chatbot model

Files changed (1)
  1. model.py +131 -0
model.py ADDED
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional, Union


class ChatBot:
    """
    Chatbot based on the notebook [How to generate text: using different
    decoding methods for language generation with
    Transformers](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)
    and the blog post [Create conversational agents using BLOOM:
    Part-1](https://medium.com/@fractal.ai/create-conversational-agents-using-bloom-part-1-63a66e6321c0).

    This code needs testing; it is not fit for a production model.

    It's a very basic chatbot that uses a causal language model from
    Transformers, driven by a PROMPT. An example of a basic PROMPT is given
    by the BASE_PROMPT attribute.

    Parameters
    ----------
    base_model : str | AutoModelForCausalLM
        A name (path in the Hugging Face Hub) for a model, or the model itself.
    tokenizer : AutoTokenizer | None
        Needed in case `base_model` is a given model; otherwise the tokenizer
        is loaded from the same path given by `base_model`.
    initial_prompt : str
        A prompt for the model. Should follow the example given in `BASE_PROMPT`.
    keep_context : bool
        Whether to accumulate the context as the chatbot is used.
    creative : bool
        Whether to generate text through sampling (with some very basic
        config) or with greedy decoding. Check the notebook "How to generate
        text" (link above) for more information.
    max_tokens : int
        Maximum number of tokens to generate per chat turn.
    human_identifier : str
        The string that identifies the human speaker in the prompt (e.g. HUMAN).
    bot_identifier : str
        The string that identifies the bot speaker in the prompt (e.g. EXPERT).
    """

    BASE_PROMPT = """
The following is a conversation with a movie EXPERT.
The EXPERT helps the HUMAN define their personal preferences and provides
multiple options to select from; it also helps in selecting the best option.
The EXPERT is conversational, optimistic, flexible, empathetic, creative and
human-like in generating responses.

HUMAN: Hello, how are you?
EXPERT: Fine, thanks. I am here to help you by recommending movies.
""".strip()

    def __init__(self,
                 base_model: Union[str, AutoModelForCausalLM],
                 tokenizer: Optional[AutoTokenizer] = None,
                 initial_prompt: Optional[str] = None,
                 keep_context: bool = False,
                 creative: bool = False,
                 max_tokens: int = 50,
                 human_identifier: str = "HUMAN",
                 bot_identifier: str = "EXPERT"):
        if isinstance(base_model, str):
            # Load both the model and its tokenizer from the Hub path
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model,
                low_cpu_mem_usage=True,
                torch_dtype='auto'
            )
            self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        else:
            # A preloaded model must come with its tokenizer
            assert tokenizer is not None, \
                "If the base model is given, the tokenizer should be given as well"
            self.model = base_model
            self.tokenizer = tokenizer

        self.initial_prompt = initial_prompt if initial_prompt is not None else self.BASE_PROMPT
        self.keep_context = keep_context
        self.context = ''
        self.creative = creative
        self.max_tokens = max_tokens
        self.human_identifier = human_identifier
        self.bot_identifier = bot_identifier

    def chat(self, input_text):
        """
        Generates a response from the prompt (and optionally the accumulated
        context): it appends `input_text` as if it were part of the HUMAN
        dialog (identified by `self.human_identifier`) and prompts the bot
        (identified by `self.bot_identifier`) for a response. As the bot
        might continue the conversation beyond the current turn, the output
        is trimmed so it only shows the first dialog line given by the bot,
        following the idea presented in the Medium blog post for creating
        conversational agents (link above).

        Parameters
        ----------
        input_text : str
            The question asked/phrase prompted by a human.

        Returns
        -------
        str
            The output given by the bot, trimmed for better control.
        """
        prompt = self.initial_prompt + '\n' + self.context
        prompt += f'{self.human_identifier}: {input_text}\n'
        prompt += f'{self.bot_identifier}: '

        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
        if self.creative:
            # Top-k/top-p (nucleus) sampling for more varied responses
            output = self.model.generate(
                input_ids,
                do_sample=True,
                max_length=input_ids.shape[1] + self.max_tokens,
                top_k=50,
                top_p=0.95,
                num_return_sequences=1
            )[0]
        else:
            # Greedy decoding
            output = self.model.generate(
                input_ids,
                max_length=input_ids.shape[1] + self.max_tokens
            )[0]

        decoded_output = self.tokenizer.decode(output, skip_special_tokens=True)
        trimmed_output = decoded_output[len(prompt):]
        # Keep only the bot's first turn: cut before any generated HUMAN line
        human_turn = trimmed_output.find(f'{self.human_identifier}:')
        if human_turn >= 0:
            trimmed_output = trimmed_output[:human_turn]

        if self.keep_context:
            # Store the full exchange so follow-up turns see the whole dialog
            self.context += f'{self.human_identifier}: {input_text}\n'
            self.context += f'{self.bot_identifier}: {trimmed_output.strip()}\n'

        return trimmed_output.strip()
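

# Example usage: a minimal sketch, assuming a causal LM checkpoint from the
# Hub ("bigscience/bloom-560m" is an illustrative choice, not one fixed by
# this commit; any causal LM checkpoint should work).
if __name__ == '__main__':
    chatbot = ChatBot('bigscience/bloom-560m', keep_context=True, creative=True)
    print(chatbot.chat('Can you recommend a sci-fi movie from the 90s?'))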