vietlethe committed on
Commit
6dc4993
1 Parent(s): 0a442ff

first commit

Files changed (9)
  1. README.md +76 -10
  2. app.py +318 -0
  3. config.py +1 -0
  4. dockerfile +14 -0
  5. download_dependencies.py +8 -0
  6. errant_verbose.json +86 -0
  7. finetuning_tinyllama.py +160 -0
  8. requirements.txt +9 -0
  9. utils.py +193 -0
README.md CHANGED
## Lightweight English Text Editing Assistant (t5nyllama)

This repository houses the source code for t5nyllama, a lightweight English text editing assistant that provides a simple and efficient way to enhance your writing.

**Hugging Face Spaces:**
https://huggingface.co/spaces/letheviet/t5nyllama

**How it Works:**

t5nyllama uses a two-step approach:

1. **Text Generation:** The core of the assistant is a TinyLlama model fine-tuned for text editing. It improves the flow and clarity of your text, making it more polished and engaging. However, TinyLlama is **relatively small and not particularly adept at complex grammar correction.**

2. **Grammar Correction:** To address this limitation, a powerful Flan-T5 model makes a second pass. It takes the output of the TinyLlama model, analyzes it for grammatical errors, and applies corrections, ensuring your final text is grammatically sound and ready for publication.
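The two-step flow above can be sketched in Python with stubs standing in for the two models. The stub functions and their toy string edits are illustrative only; the real app wires up TinyLlama through llama-cpp-python and Flan-T5 through a transformers pipeline, as app.py below shows.

```python
# Minimal sketch of the two-stage pipeline with stub models.

def edit_text(instruction: str, text: str) -> str:
    # Stage 1 (stub): the fine-tuned TinyLlama would rewrite `text`
    # according to `instruction` here.
    return text.replace("very good", "excellent")

def correct_grammar(text: str) -> str:
    # Stage 2 (stub): Flan-T5 would fix residual grammar errors here.
    return text.replace("a excellent", "an excellent")

def assistant(instruction: str, text: str) -> str:
    draft = edit_text(instruction, text)   # stage 1: text generation
    return correct_grammar(draft)          # stage 2: grammar correction

print(assistant("Improve this sentence", "This is a very good tool."))
# -> This is an excellent tool.
```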
**Key Features:**

* **Lightweight and Efficient:** The TinyLlama model is quantized to 4-bit precision, minimizing memory usage and computational demands, which makes it suitable for resource-constrained environments.
* **Focused on Text Improvement:** TinyLlama excels at refining the overall quality of your writing, making it more readable and engaging.
* **Enhanced Grammar Accuracy:** The Flan-T5 model provides a robust final check for grammatical errors, ensuring your text is free from mistakes.

**Design Principles:**

* **Local Application:** Prioritizes offline functionality, allowing you to edit text without an internet connection.
* **Lightweight Design:** Minimizes resource consumption, making the application suitable for a wide range of devices and systems.

## Installation

**1. Clone the Repository:**
```shell
git clone https://github.com/LETHEVIET/t5nyllama.git
```

**2. Install Dependencies:**
```shell
pip3 install -r requirements.txt
python3 -m spacy download en_core_web_sm
python3 download_dependencies.py
```

**3. Run the Application:**
```shell
python3 app.py
```

## Docker Deployment

**1. Build Docker Image:**
```shell
docker build . -t t5nyllama
```

**2. Run Docker Image:**
```shell
docker run -p 7860:7860 t5nyllama
```

## Fine-Tuning TinyLlama

The fine-tuning script follows the UnslothAI example for fine-tuning TinyLlama. Please install the dependencies from [unsloth](https://github.com/unslothai/unsloth) before running the script.

```shell
python finetuning_tinyllama.py
```

## References

* **Unsloth Fast Fine-Tuning LLM:** https://github.com/unslothai/unsloth
* **Dataset Card for CoEdIT: Text Editing via Instruction Tuning:** https://huggingface.co/datasets/grammarly/coedit
* **Grammar-Synthesis-Large: FLAN-T5:** https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis
* **ALLECS: A Lightweight Language Error Correction System:** https://github.com/nusnlp/ALLECS
* **Python Bindings for llama.cpp:** https://github.com/abetlen/llama-cpp-python
* **Gradio: Build Machine Learning Web Apps — in Python:** https://github.com/gradio-app/gradio

## Demo

[Include a GIF or screenshot demonstrating the application's functionality.]
app.py ADDED
import gradio as gr
import errant
import spacy
import os
import json
import nltk
from utils import get_random_prompt, instruction_prompts
from llama_cpp import Llama
from transformers import pipeline
import config

# Load necessary models and resources
nlp = spacy.load("en_core_web_sm")
annotator = errant.load('en', nlp)
errant_path = os.path.join(os.path.dirname("./"), 'errant_verbose.json')
errant_verbose = json.load(open(errant_path, "r"))
sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')

# Load text editor (TinyLlama)
text_editor = Llama(
    model_path="./texteditor-model/coedit-tinyllama-chat-bnb-4bit-unsloth.Q4_K_M.gguf",
    verbose=False
)
print("text editor is loaded!")

# Load grammar corrector (Flan-T5)
grammar_corrector = pipeline(
    'text2text-generation',
    'pszemraj/flan-t5-large-grammar-synthesis',
)
print("grammar corrector is loaded!")


def correcting_text(src: str) -> str:
    """
    Corrects grammatical errors in the given text using the grammar corrector model.

    Args:
        src: The text to be corrected.

    Returns:
        The grammatically corrected text.
    """
    # Split the input into sentences, remembering which line each came from.
    lines = src.split('\n')
    sentences = []
    line_idx = []
    for l_idx, line in enumerate(lines):
        if len(line) == 0:
            continue
        l_sents = sent_detector.tokenize(line)
        for sent in l_sents:
            sentences.append(sent)
            line_idx.append(l_idx)

    # Correct sentences in batches of config.BATCH_SIZE.
    num_iter = (len(sentences) + config.BATCH_SIZE - 1) // config.BATCH_SIZE
    final_outs = []
    out_lines = ["" for _ in lines]
    for i in range(num_iter):
        start = i * config.BATCH_SIZE
        end = min((i + 1) * config.BATCH_SIZE, len(sentences))
        final_outs += grammar_corrector(sentences[start:end], max_length=128, num_beams=5, early_stopping=True)

    # Reassemble corrected sentences into their original lines.
    for i in range(len(final_outs)):
        out_lines[line_idx[i]] += final_outs[i]["generated_text"] + " "

    return "\n".join(out_lines)


def annotate_text(src: str, tag: str, analyze: bool = True) -> list:
    """
    Annotates the text with edits based on the provided tag using the Errant library.
    Original code from: https://github.com/nusnlp/ALLECS

    Args:
        src: The source text.
        tag: The target (corrected) text.
        analyze: Whether to analyze and provide detailed information about edits.

    Returns:
        A list of (edit_text, edit_type) tuples representing the edits.
    """
    out = {"edits": []}
    out['source'] = src
    src_doc = annotator.parse(src)
    tag_doc = annotator.parse(tag)
    cur_edits = annotator.annotate(src_doc, tag_doc)

    for e in cur_edits:
        out["edits"].append((e.o_start, e.o_end, e.type, e.c_str))
    result = []
    last_pos = 0
    if analyze:
        tokens = out['source']
        if isinstance(tokens, str):
            tokens = tokens.split(' ')
        edits = out['edits']
        offset = 0
        for edit in edits:
            if isinstance(edit, dict):
                e_start = edit['start']
                e_end = edit['end']
                e_type = edit['type']
                e_rep = edit['cor']
            elif isinstance(edit, tuple):
                e_start = edit[0]
                e_end = edit[1]
                e_type = edit[2]
                e_rep = edit[3]
            else:
                raise ValueError("Data type {} is not supported.".format(type(edit)))

            e_rep = e_rep.strip()
            op_type = e_type[0]
            pos_type = e_type[2:]
            errant_info = errant_verbose[pos_type]
            title = errant_info["title"]

            result.append((' '.join(tokens[last_pos:e_start + offset]), None))

            ori_str = ' '.join(tokens[e_start + offset:e_end + offset]).strip()
            if pos_type == "ORTH":
                # Check if it is a casing issue.
                if ori_str.lower() == e_rep.lower():
                    if e_rep[0].isupper() and ori_str[0].islower():
                        msg = "<b>{ori}</b> should be capitalized."
                    elif e_rep[0].islower() and ori_str[0].isupper():
                        msg = "<b>{ori}</b> should not be capitalized."
                    else:
                        msg = "The casing of the word <b>{ori}</b> is wrong."
                # Otherwise it should be a spacing issue.
                else:
                    if len(ori_str) - 1 == len(e_rep):
                        msg = "The word <b>{ori}</b> should not be written separately."
                    elif len(ori_str) + 1 == len(e_rep):
                        msg = "The word <b>{ori}</b> should be separated into <b>{cor}</b>."
                    else:
                        msg = "The word <b>{ori}</b> has an orthography error."
            else:
                if op_type in errant_info:
                    msg = errant_info[op_type]
                else:
                    msg = errant_verbose["Default"][op_type]

            msg = '<p>' + msg.format(ori=ori_str, cor=e_rep) + '</p>'

            e_cor = e_rep.split()
            len_cor = len(e_cor)
            tokens[e_start + offset:e_end + offset] = e_cor
            last_pos = e_start + offset + len_cor
            offset = offset - (e_end - e_start) + len_cor
            result.append((e_rep, pos_type))
        result.append((' '.join(tokens[last_pos:]), None))
    print(result)
    return result


def choices2prompts() -> list:
    """
    Returns a list of available instructions for text editing.

    Returns:
        A list of instruction names.
    """
    return list(instruction_prompts.keys())


with gr.Blocks() as demo:

    def turn_off_legend(msg: str):
        """
        Hides the legend in the highlighted text component.

        Args:
            msg: The text input.

        Returns:
            A Gradio update object to hide the legend.
        """
        return gr.update(show_legend=False)

    def turn_on_legend(annotate: bool):
        """
        Shows the legend in the highlighted text component if annotate is True.

        Args:
            annotate: Whether to show annotations.

        Returns:
            A Gradio update object to show or hide the legend.
        """
        if annotate:
            return gr.update(show_legend=True)
        else:
            return gr.update(show_legend=False)

    def bot(task: str, text: str, post_check: bool, annotate: bool):
        """
        Processes the user input and returns the edited text along with annotations.

        Args:
            task: The chosen instruction for editing.
            text: The text to be edited.
            post_check: Whether to check for grammatical errors after text generation.
            annotate: Whether to show annotations.

        Yields:
            Tuples of (highlighted text, status message) to update the interface.
        """
        response = ""
        if task == "Grammar Error Correction":
            yield [("Processing ...", None)], "Checking Grammar ..."
            response = correcting_text(text)
        else:
            instruction = get_random_prompt(task)
            prompt = instruction + ": " + text
            print(prompt)
            output = text_editor.create_chat_completion(
                messages=[
                    {
                        "role": "system",
                        "content": "You are an English writing assistant, editing the text of user input and responding based on user instructions. Please do not provide explanations, but respond only with the edited text. Also, if the instruction is not provided, correct the grammar of the text. Finally, if the instruction is not for editing text, correct the grammar of the text.",
                    },
                    {"role": "user", "content": f"{prompt}"},
                ],
                temperature=0.0,
                stream=True,
            )

            response = ""
            for chunk in output:
                delta = chunk["choices"][0]["delta"]
                if "role" in delta:
                    pass
                elif "content" in delta:
                    response += delta['content']
                    res = [(response, None)]
                    print(res)
                    yield res, "Generating output ..."

        if post_check:
            yield [(response, None)], "Checking Grammar ..."
            response = correcting_text(response)

        print(response)

        if annotate:
            e_edit = annotate_text(text, response)
        else:
            e_edit = [(response, None)]

        yield e_edit, "Done."

    def handle_highlight_selection():
        """
        Handles the selection event of the highlighted text component.

        Currently a no-op placeholder.
        """
        return

    gr.Markdown("# English Text Editing Application using T5 and Tiny Llama")
    gr.Markdown("> source code: https://github.com/LETHEVIET/t5nyllama")
    with gr.Row() as row:
        with gr.Column(scale=1) as col1:
            instruction = gr.Dropdown(
                choices=choices2prompts(),
                value="Grammar Error Correction",
                multiselect=False,
                label="Choose your instruction",
                interactive=True,
                scale=0
            )

            with gr.Row() as row2:
                clear = gr.Button("Clear", scale=-1)
                submit = gr.Button("Submit", scale=-1)

            info_msg = gr.Textbox(
                label="Information",
                scale=1,
                lines=3,
                value="Therefore careful analysis of a product has to be made before select a solution for testing and implementation.",
            )

            post_check = gr.Checkbox(label="Check grammaticality after text generation.", value=True)
            annotate = gr.Checkbox(label="Highlight differences", value=True)
        with gr.Column(scale=2) as col2:
            msg = gr.Textbox(
                label="Input",
                scale=3,
                value="Therefore careful analysis of a product has to be made before select a solution for testing and implementation.",
            )

            result = gr.HighlightedText(
                label="Result",
                combine_adjacent=True,
                show_legend=False,
                scale=3
            )

            res_msg = gr.Textbox(
                scale=0,
                visible=False,
                label="Output",
            )

    msg.submit(turn_off_legend, msg, result).then(bot, [instruction, msg, post_check, annotate], [result, info_msg]).then(turn_on_legend, annotate, result)

    clear.click(lambda: None, None, result, queue=False)

    submit.click(turn_off_legend, msg, result).then(bot, [instruction, msg, post_check, annotate], [result, info_msg]).then(turn_on_legend, annotate, result)

    result.select(handle_highlight_selection, [], [])

if __name__ == "__main__":
    demo.launch(server_port=7860)
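The sentence batching in `correcting_text` above can be exercised without loading any model. This stdlib-only sketch reproduces the split/batch/reassemble flow; a trivial capitalizing function stands in for the Flan-T5 pipeline, and a simple `'. '` split stands in for the NLTK punkt tokenizer (both are simplifying assumptions).

```python
BATCH_SIZE = 4  # mirrors config.BATCH_SIZE

def fake_corrector(sentences):
    # Stand-in for the batched Flan-T5 pipeline call.
    return [{"generated_text": s.capitalize()} for s in sentences]

def correcting_text(src: str) -> str:
    # Split the input into sentences, remembering each sentence's line.
    lines = src.split('\n')
    sentences, line_idx = [], []
    for l_idx, line in enumerate(lines):
        if len(line) == 0:
            continue
        for sent in [s for s in line.split('. ') if s]:
            sentences.append(sent)
            line_idx.append(l_idx)

    # Ceiling division: number of batches needed to cover all sentences.
    num_iter = (len(sentences) + BATCH_SIZE - 1) // BATCH_SIZE
    final_outs = []
    out_lines = ["" for _ in lines]
    for i in range(num_iter):
        start = i * BATCH_SIZE
        end = min((i + 1) * BATCH_SIZE, len(sentences))
        final_outs += fake_corrector(sentences[start:end])

    # Reassemble corrected sentences into their original lines.
    for i in range(len(final_outs)):
        out_lines[line_idx[i]] += final_outs[i]["generated_text"] + " "
    return "\n".join(out_lines)

print(correcting_text("hello world. second sentence\n\nthird line"))
```

Empty lines survive the round trip because `out_lines` is sized from `lines`, not from the sentence list.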
config.py ADDED
BATCH_SIZE = 4
dockerfile ADDED
FROM python:3.10

WORKDIR /app

COPY . .

RUN apt-get update && apt-get -y upgrade
RUN apt-get install -y build-essential
RUN pip3 install --upgrade pip setuptools
RUN pip3 install -r requirements.txt
RUN python3 -m spacy download en_core_web_sm
RUN python3 download_dependencies.py

CMD ["python3", "app.py"]
download_dependencies.py ADDED
import gdown
import nltk

# Download the quantized TinyLlama text-editor model from Google Drive.
id = "1TnPssg0CkWQ_thuAH8cY3hdB2J18A0Kl"
output = "texteditor-model/coedit-tinyllama-chat-bnb-4bit-unsloth.Q4_K_M.gguf"
gdown.download(id=id, output=output)

# Download the NLTK sentence tokenizer used by app.py.
nltk.download('punkt')
errant_verbose.json ADDED
{
    "Default": {
        "M": "<b>{cor}</b> should be inserted here, considering the context.",
        "R": "<b>{cor}</b> is more appropriate than <b>{ori}</b> in this context.",
        "U": "<b>{ori}</b> is unnecessary/incorrect in this context."
    },
    "ADJ": {
        "title": "Adjective"
    },
    "ADJ:FORM": {
        "title": "Adjective Form"
    },
    "ADV": {
        "title": "Adverb"
    },
    "CONJ": {
        "title": "Conjunction"
    },
    "CONTR": {
        "title": "Contraction"
    },
    "DET": {
        "title": "Determiner"
    },
    "NOUN": {
        "title": "Noun"
    },
    "NOUN:POSS": {
        "title": "Possessive Noun"
    },
    "OTHER": {
        "title": ""
    },
    "PART": {
        "title": "Particle"
    },
    "PREP": {
        "title": "Preposition"
    },
    "PRON": {
        "title": "Pronoun"
    },
    "PUNCT": {
        "title": "Punctuation"
    },
    "VERB": {
        "title": "Verb"
    },
    "VERB:FORM": {
        "title": "Verb Form"
    },
    "VERB:TENSE": {
        "title": "Verb Tense"
    },
    "MORPH": {
        "title": "Morphology",
        "R": "The word form of <b>{cor}</b> is more appropriate than <b>{ori}</b> here."
    },
    "NOUN:INFL": {
        "title": "Noun Inflection",
        "R": "<b>{ori}</b> is wrong and should be written as <b>{cor}</b>."
    },
    "NOUN:NUM": {
        "title": "Noun Number",
        "R": "The noun number of <b>{ori}</b> is wrong and should be written as <b>{cor}</b>."
    },
    "ORTH": {
        "title": "Orthography"
    },
    "SPELL": {
        "title": "Spelling",
        "R": "<b>{ori}</b> is not the correct spelling of <b>{cor}</b>."
    },
    "VERB:INFL": {
        "title": "Verb Inflection",
        "R": "<b>{ori}</b> is wrong and should be written as <b>{cor}</b>."
    },
    "VERB:SVA": {
        "title": "Subject-Verb Agreement",
        "R": "The form of <b>{ori}</b> does not follow the subject-verb agreement."
    },
    "WO": {
        "title": "Word Order",
        "R": "The word order of '<b>{ori}</b>' is wrong."
    }
}
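app.py fills these HTML templates with `str.format`, falling back to the `"Default"` entry when a category has no template for the edit operation (`M`/`R`/`U`). That lookup-and-format step can be shown self-contained; the two-entry dict below is an excerpt of the file above.

```python
errant_verbose = {
    "Default": {"R": "<b>{cor}</b> is more appropriate than <b>{ori}</b> in this context."},
    "SPELL": {"title": "Spelling",
              "R": "<b>{ori}</b> is not the correct spelling of <b>{cor}</b>."},
}

def edit_message(op_type: str, pos_type: str, ori: str, cor: str) -> str:
    info = errant_verbose[pos_type]
    # Fall back to the "Default" template when this category lacks one,
    # mirroring the lookup in annotate_text.
    msg = info.get(op_type, errant_verbose["Default"][op_type])
    return "<p>" + msg.format(ori=ori, cor=cor) + "</p>"

print(edit_message("R", "SPELL", "recieve", "receive"))
# -> <p><b>recieve</b> is not the correct spelling of <b>receive</b>.</p>
```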
finetuning_tinyllama.py ADDED
from unsloth import FastLanguageModel
import torch

# Define model parameters
max_seq_length = 4096  # Choose any! Unsloth auto-supports RoPE scaling internally.
dtype = None  # None for auto detection. Float16 for Tesla T4/V100; bfloat16 for Ampere+.
load_in_4bit = True  # Use 4-bit quantization to reduce memory usage. Can be False.

# Load the model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/tinyllama-chat-bnb-4bit",  # "unsloth/tinyllama" for 16-bit loading
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    # token = "hf_...",  # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# Apply PEFT (Parameter-Efficient Fine-Tuning)
model = FastLanguageModel.get_peft_model(
    model,
    r=32,  # Choose any number > 0! Suggested: 8, 16, 32, 64, 128
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=32,
    lora_dropout=0,  # Currently only supports dropout = 0
    bias="none",  # Currently only supports bias = "none"
    use_gradient_checkpointing=False,  # Set to True if you run out of memory
    random_state=3407,
    use_rslora=False,  # Unsloth supports rank-stabilized LoRA
    loftq_config=None,  # And LoftQ
)

# Data preparation
import pandas as pd
from sklearn.model_selection import train_test_split
import datasets

# Load the dataset
train = datasets.load_dataset("grammarly/coedit", split="train").to_pandas()
val = datasets.load_dataset("grammarly/coedit", split="validation").to_pandas()

# Data cleaning and preparation
data = pd.concat([train, val])
data[['instruction', 'input']] = data['src'].str.split(': ', n=1, expand=True)
data = data.rename(columns={"tgt": "output"})
data = data.drop(columns=["_id", "src"])

# Stratify based on task for balanced splits
stratify_col = data['task']

# Split the data into train and test sets
train_df, test_df = train_test_split(
    data,
    test_size=0.2,
    random_state=42,
    stratify=stratify_col
)


def formatting_prompts_func(examples, tokenizer):
    """
    Formats the examples into the desired chat format for training.

    Args:
        examples: A dictionary of examples from the dataset.
        tokenizer: The tokenizer used for processing text.

    Returns:
        A dictionary containing the formatted text for each example.
    """
    instructions = examples["instruction"]
    inputs = examples["input"]
    outputs = examples["output"]
    texts = []
    for instruction, input, output in zip(instructions, inputs, outputs):
        message = [
            {"role": "user", "content": instruction + ": " + input},
            {"role": "assistant", "content": output},
        ]
        text = tokenizer.apply_chat_template(
            message, tokenize=False, add_generation_prompt=False)
        texts.append(text)
    return {"text": texts}


# Create datasets from pandas DataFrames
train_ds = datasets.Dataset.from_pandas(train_df)
test_ds = datasets.Dataset.from_pandas(test_df)

# Map the formatting function to the datasets for chat format conversion
train_ds = train_ds.map(formatting_prompts_func, fn_kwargs={"tokenizer": tokenizer}, batched=True)
test_ds = test_ds.map(formatting_prompts_func, fn_kwargs={"tokenizer": tokenizer}, batched=True)

print(train_ds[0]['text'])

# Fine-tuning with trl
from trl import SFTTrainer
from transformers import TrainingArguments

# Define training arguments
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=10,
    packing=False,  # Packing can make training ~5x faster for short sequences.
    args=TrainingArguments(
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        num_train_epochs=2,
        learning_rate=2e-4,
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=1,
        save_steps=100,
        save_total_limit=4,  # Limit the total number of checkpoints
        evaluation_strategy="steps",
        eval_steps=100,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        load_best_model_at_end=True,
        save_strategy="steps",
    ),
)

# Print GPU information
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

# Train the model
trainer_stats = trainer.train()

# Print memory usage statistics
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime'] / 60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

# Save the trained model and tokenizer
print("Saving model to local")
model.save_pretrained("coedit-tinyllama-chat-bnb-4bit")  # Local saving
tokenizer.save_pretrained("coedit-tinyllama-chat-bnb-4bit")

# Evaluate the model (optional)
# trainer.evaluate()
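The transformation `formatting_prompts_func` performs on each CoEdIT row can be sketched without loading a tokenizer. The Zephyr-style template below is a hypothetical stand-in; the script's actual formatting comes from `tokenizer.apply_chat_template`, whose exact special tokens depend on the model's chat template.

```python
def format_example(instruction: str, input_text: str, output: str) -> str:
    # Hypothetical chat template for illustration only; the real script
    # delegates this to tokenizer.apply_chat_template.
    user = instruction + ": " + input_text
    return f"<|user|>\n{user}</s>\n<|assistant|>\n{output}</s>\n"

print(format_example("Fix grammar",
                     "She go to school.",
                     "She goes to school."))
```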
requirements.txt ADDED
spacy
transformers
gradio
errant
nltk
llama-cpp-python
gdown
tensorflow
tf-keras
utils.py ADDED
import random

GEC = [
    "Fix grammar",
    "Fix grammar in this sentence",
    "Fix grammar in the sentence",
    "Fix grammar errors",
    "Fix grammatical errors",
    "Fix grammaticality",
    "Fix all grammatical errors",
    "Fix grammatical errors in this sentence",
    "Fix grammar errors in this sentence",
    "Fix grammatical mistakes in this sentence",
    "Fix grammaticality in this sentence",
    "Fix grammaticality of the sentence",
    "Fix disfluencies in the sentence",
    "Make the sentence grammatical",
    "Make the sentence fluent",
    "Fix errors in this text",
    "Update to remove grammar errors",
    "Remove all grammatical errors from this text",
    "Improve the grammar of this text",
    "Improve the grammaticality",
    "Improve the grammaticality of this text",
    "Improve the grammaticality of this sentence",
    "Grammar improvements",
    "Remove grammar mistakes",
    "Remove grammatical mistakes",
    "Fix the grammar mistakes",
    "Fix grammatical mistakes",
]
Clarify = [
    "Clarify the sentence",
    "Clarify this sentence",
    "Clarify this text",
    "Write a clearer version for the sentence",
    "Write a clarified version of the sentence",
    "Write a readable version of the sentence",
    "Write a better readable version of the sentence",
    "Rewrite the sentence more clearly",
    "Rewrite this sentence clearly",
    "Rewrite this sentence for clarity",
    "Rewrite this sentence for readability",
    "Improve this sentence for readability",
    "Make this sentence better readable",
    "Make this sentence more readable",
    "Make this sentence readable",
    "Make the sentence clear",
    "Make the sentence clearer",
    "Clarify",
    "Make the text more understandable",
    "Make this easier to read",
    "Clarification",
    "Change to clearer wording",
    "Clarify this paragraph",
    "Use clearer wording",
    "Simplify the sentence",
    "Simplify this sentence",
    "Simplify this text",
    "Write a simpler version for the sentence",
    "Rewrite the sentence to be simpler",
    "Rewrite this sentence in a simpler manner",
    "Rewrite this sentence for simplicity",
    "Rewrite this with simpler wording",
    "Make the sentence simple",
    "Make the sentence simpler",
    "Make this text less complex",
    "Make this simpler",
    "Simplify",
    "Simplification",
    "Change to simpler wording",
    "Simplify this paragraph",
    "Simplify this text",
    "Use simpler wording",
    "Make this easier to understand"
]
Coherence = [
    "Fix coherence",
    "Fix coherence in this sentence",
    "Fix coherence in the sentence",
    "Fix coherence in this text",
    "Fix coherence in the text",
    "Fix coherence errors",
    "Fix sentence flow",
    "Fix sentence transition",
    "Fix coherence errors in this sentence",
    "Fix coherence mistakes in this sentence",
    "Fix coherence in this sentence",
    "Fix coherence of the sentence",
    "Fix lack of coherence in the sentence",
    "Make the text more coherent",
    "Make the text coherent",
    "Make the text more cohesive, logically linked and consistent as a whole",
    "Make the text more cohesive",
    "Improve the cohesiveness of the text",
    "Make the text more logical",
    "Make the text more consistent",
    "Improve the consistency of the text",
    "Make the text clearer",
    "Improve the coherence of the text"
]
Formality_Style_Transfer = [
    "Formalize",
    "Improve formality",
    "Formalize the sentence",
    "Formalize this sentence",
    "Formalize the text",
    "Formalize this text",
    "Make this formal",
    "Make this more formal",
    "Make this sound more formal",
    "Make the sentence formal",
    "Make the sentence more formal",
    "Make the sentence sound more formal",
    "Write more formally",
    "Write less informally",
    "Rewrite more formally",
    "Write this more formally",
    "Rewrite this more formally",
    "Write in a formal manner",
    "Write in a more formal manner",
    "Rewrite in a more formal manner"
]
Neutralization = [
    "Remove POV",
    "Remove POVs",
    "Remove POV in this text",
    "Remove POVs in this text",
    "Neutralize this text",
    "Neutralize the text",
    "Neutralize this sentence",
    "Neutralize the sentence",
    "Make this more neutral",
    "Make this text more neutral",
    "Make this sentence more neutral",
    "Make this paragraph more neutral",
    "Remove unsourced opinions",
    "Remove unsourced opinions from this text",
    "Remove non-neutral POVs",
    "Remove non-neutral POV",
    "Remove non-neutral points of view",
    "Remove points of view",
    "Make this text less biased",
]
Paraphrase = [
    "Paraphrase the sentence",
    "Paraphrase this sentence",
    "Paraphrase this text",
    "Write a paraphrase for the sentence",
    "Write a paraphrased version of the sentence",
    "Rewrite the sentence with different wording",
    "Use different wording",
    "Rewrite this sentence",
    "Reword this sentence",
    "Rephrase this sentence",
    "Rewrite this text",
    "Reword this text",
    "Rephrase this text"
]


instruction_prompts = {
    "Grammar Error Correction": GEC,
    "Clarify": Clarify,
    "Coherence": Coherence,
    "Formality Style Transfer": Formality_Style_Transfer,
    "Neutralization": Neutralization,
    "Paraphrase": Paraphrase,
}


def get_prompt_list(instruction_type: str) -> list:
    """
    Returns a list of prompts for the given instruction type.

    Args:
        instruction_type: The type of instruction, e.g., "Grammar Error Correction".

    Returns:
        A list of prompts corresponding to the instruction type.
    """
    return instruction_prompts[instruction_type]


def get_random_prompt(instruction_type: str) -> str:
    """
    Returns a random prompt from the list of prompts for the given instruction type.

    Args:
        instruction_type: The type of instruction, e.g., "Grammar Error Correction".

    Returns:
        A random prompt from the list of prompts for the instruction type.
    """
    return random.choice(instruction_prompts[instruction_type])