ToddGoldfarb committed
Commit 77885a5
1 Parent(s): aed0362

uploading model

README.md CHANGED
@@ -1,3 +1,149 @@
  ---
  license: openrail
+ datasets:
+ - allenai/soda
+ language:
+ - en
+ pipeline_tag: conversational
  ---
+ # What is Cadet-Tiny?
+
+ Inspired by Allen AI's **Cosmo-XL**, **Cadet-Tiny** is a _very small_ conversational model trained on the **SODA** dataset. **Cadet-Tiny** is intended for inference at the edge (on something as small as a 2GB RAM Raspberry Pi).
+
+ **Cadet-Tiny** was trained from Google's **t5-small** pretrained model and is, as a result, about 2% of the size of the **Cosmo-3B** model.
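+
+ If you want a rough sense of that footprint yourself, here is a minimal sketch that loads the checkpoint and counts its parameters; note that the ~3B figure for Cosmo-3B below is an assumption taken from its name, not something measured in this repository.
+
+ ```
+ # A minimal sketch (assumes torch and transformers are installed).
+ from transformers import AutoModelForSeq2SeqLM
+
+ model = AutoModelForSeq2SeqLM.from_pretrained("ToddGoldfarb/Cadet-Tiny")
+ n_params = sum(p.numel() for p in model.parameters())
+ print(f"{n_params / 1e6:.0f}M parameters, "
+       f"~{n_params * 4 / 1e9:.2f} GB in float32, "
+       f"~{100 * n_params / 3e9:.1f}% of Cosmo-3B's assumed ~3B parameters")
+ ```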
+
+ This is the first SEQ2SEQ NLP model I've ever made, and I'm very excited to share it here on HuggingFace! :)
+
+ If you have any questions, or any comments on improvements, please contact me at: **tcgoldfarb@gmail.com**
+
+ # Google Colab Link
+
+ Here is the link to the Google Colab notebook, where I walk through the process of training the model on the SODA public dataset from AI2:
+
+ https://colab.research.google.com/drive/1cx3Yujr_jGQkseqzXZW-2L0vEyEjds_s?usp=sharing
+
+ # Get Started With Cadet-Tiny
+
+ Use the code snippet below to get started with Cadet-Tiny!
+
+ ```
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ import colorful as cf
+
+ cf.use_true_colors()
+ cf.use_style('monokai')
+
+
+ class CadetTinyAgent:
+     def __init__(self):
+         print(cf.bold | cf.purple("Waking up Cadet-Tiny..."))
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.tokenizer = AutoTokenizer.from_pretrained("t5-small", model_max_length=512)
+         self.model = AutoModelForSeq2SeqLM.from_pretrained("ToddGoldfarb/Cadet-Tiny", low_cpu_mem_usage=True).to(self.device)
+         self.conversation_history = ""
+
+     def observe(self, observation):
+         self.conversation_history = self.conversation_history + observation
+         # Rough character-based safety net: once the history grows past 400
+         # characters, drop the oldest 112 characters so the model's 512-token
+         # input limit is not overrun.
+         if len(self.conversation_history) > 400:
+             self.conversation_history = self.conversation_history[112:]
+
+     def set_input(self, situation_narrative="", role_instruction=""):
+         # Build the model input in the SODA-style format:
+         # "dialogue: <situation> <SEP> <role instruction> <TURN> <history>"
+         input_text = "dialogue: "
+
+         if situation_narrative != "":
+             input_text = input_text + situation_narrative
+
+         if role_instruction != "":
+             input_text = input_text + " <SEP> " + role_instruction
+
+         input_text = input_text + " <TURN> " + self.conversation_history
+
+         # Uncomment the line below to see what is fed to the model.
+         # print(input_text)
+
+         return input_text
+
+     def generate(self, situation_narrative, role_instruction, user_response):
+         user_response = user_response + " <TURN> "
+         self.observe(user_response)
+
+         input_text = self.set_input(situation_narrative, role_instruction)
+
+         inputs = self.tokenizer([input_text], return_tensors="pt").to(self.device)
+
+         # I encourage you to change the hyperparameters of the model! Start by trying to modify the temperature.
+         outputs = self.model.generate(inputs["input_ids"], max_new_tokens=512, temperature=1, top_p=.95,
+                                       do_sample=True)
+         cadet_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+         added_turn = cadet_response + " <TURN> "
+         self.observe(added_turn)
+
+         return cadet_response
+
+     def reset_history(self):
+         # The history is a string everywhere else in this class, so reset it
+         # to an empty string rather than an empty list.
+         self.conversation_history = ""
+
+     def run(self):
+         def get_valid_input(prompt, default):
+             while True:
+                 user_input = input(prompt)
+                 if user_input in ["Y", "N", "y", "n"]:
+                     return user_input
+                 if user_input == "":
+                     return default
+
+         while True:
+             continue_chat = ""
+
+             # MODIFY THESE STRINGS TO YOUR LIKING :)
+             situation_narrative = "Imagine you are Cadet-Tiny talking to ???."
+             role_instruction = "You are Cadet-Tiny, and you are talking to ???."
+
+             self.chat(situation_narrative, role_instruction)
+             continue_chat = get_valid_input(cf.purple("Start a new conversation with new setup? [Y/N]:"), "Y")
+             if continue_chat in ["N", "n"]:
+                 break
+
+         print(cf.blue("CT: See you!"))
+
+     def chat(self, situation_narrative, role_instruction):
+         print(cf.green(
+             "Cadet-Tiny is running! Input [RESET] to reset the conversation history and [END] to end the conversation."))
+         while True:
+             user_input = input("You: ")
+             if user_input == "[RESET]":
+                 self.reset_history()
+                 print(cf.green("[Conversation history cleared. Chat with Cadet-Tiny!]"))
+                 continue
+             if user_input == "[END]":
+                 break
+             response = self.generate(situation_narrative, role_instruction, user_input)
+             print(cf.blue("CT: " + response))
+
+
+ def main():
+     print(cf.bold | cf.blue("LOADING MODEL"))
+
+     CadetTiny = CadetTinyAgent()
+     CadetTiny.run()
+
+
+ if __name__ == '__main__':
+     main()
+ ```
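+
+ If you just want a single response without the interactive loop above, here is a minimal single-turn sketch using the same "dialogue: ... <SEP> ... <TURN> ..." input format; the situation text and turn contents below are made-up examples for illustration, not values from this repository.
+
+ ```
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ tokenizer = AutoTokenizer.from_pretrained("t5-small", model_max_length=512)
+ model = AutoModelForSeq2SeqLM.from_pretrained("ToddGoldfarb/Cadet-Tiny").to(device)
+
+ # Format: "dialogue: <situation> <SEP> <role instruction> <TURN> <history>"
+ prompt = ("dialogue: Cadet-Tiny is chatting with a new friend. <SEP> "
+           "You are Cadet-Tiny. <TURN> Hi! How was your day? <TURN> ")
+
+ inputs = tokenizer([prompt], return_tensors="pt").to(device)
+ outputs = model.generate(inputs["input_ids"], max_new_tokens=128, top_p=0.95, do_sample=True)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```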
+
+ # Citations and Special Thanks
+
+ Special thanks to Hyunwoo Kim for discussing with me the best way to use the SODA dataset. If you haven't looked into their work on SODA, Prosocial-Dialog, or COSMO, I recommend you do so! And read the SODA paper as well; it is cited below.
+
140
+
141
+ ```
142
+ @article{kim2022soda,
143
+ title={SODA: Million-scale Dialogue Distillation with Social Commonsense Contextualization},
144
+ author={Hyunwoo Kim and Jack Hessel and Liwei Jiang and Peter West and Ximing Lu and Youngjae Yu and Pei Zhou and Ronan Le Bras and Malihe Alikhani and Gunhee Kim and Maarten Sap and Yejin Choi},
145
+ journal={ArXiv},
146
+ year={2022},
147
+ volume={abs/2212.10465}
148
+ }
149
+ ```
config.json ADDED
@@ -0,0 +1,60 @@
+ {
+   "_name_or_path": "t5-base",
+   "architectures": [
+     "T5ForConditionalGeneration"
+   ],
+   "d_ff": 3072,
+   "d_kv": 64,
+   "d_model": 768,
+   "decoder_start_token_id": 0,
+   "dense_act_fn": "relu",
+   "dropout_rate": 0.1,
+   "eos_token_id": 1,
+   "feed_forward_proj": "relu",
+   "initializer_factor": 1.0,
+   "is_encoder_decoder": true,
+   "is_gated_act": false,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "t5",
+   "n_positions": 512,
+   "num_decoder_layers": 12,
+   "num_heads": 12,
+   "num_layers": 12,
+   "output_past": true,
+   "pad_token_id": 0,
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 32,
+   "task_specific_params": {
+     "summarization": {
+       "early_stopping": true,
+       "length_penalty": 2.0,
+       "max_length": 200,
+       "min_length": 30,
+       "no_repeat_ngram_size": 3,
+       "num_beams": 4,
+       "prefix": "summarize: "
+     },
+     "translation_en_to_de": {
+       "early_stopping": true,
+       "max_length": 300,
+       "num_beams": 4,
+       "prefix": "translate English to German: "
+     },
+     "translation_en_to_fr": {
+       "early_stopping": true,
+       "max_length": 300,
+       "num_beams": 4,
+       "prefix": "translate English to French: "
+     },
+     "translation_en_to_ro": {
+       "early_stopping": true,
+       "max_length": 300,
+       "num_beams": 4,
+       "prefix": "translate English to Romanian: "
+     }
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.28.1",
+   "use_cache": true,
+   "vocab_size": 32128
+ }
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "decoder_start_token_id": 0,
+   "eos_token_id": 1,
+   "pad_token_id": 0,
+   "transformers_version": "4.28.1"
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:335213bee0978a4f7e2af43bdd7292484a4c56959030dd348e64083d1dd9796c
+ size 891702929
spiece.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
+ size 791656
tokenizer.json ADDED
The diff for this file is too large to render.
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 100, "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"], "model_max_length": 512, "name_or_path": "t5-small"}