Spaces:
Sleeping
Sleeping
sergey-hovhannisyan
commited on
Commit
•
19df55a
1
Parent(s):
4838aca
final touch, finished model loading
Browse files- .gitignore +4 -0
- app.py +5 -7
- models/fine_tuned/.gitkeep +0 -0
- src/evaluate.py +37 -4
- src/finetune.ipynb +215 -137
.gitignore
CHANGED
@@ -127,3 +127,7 @@ dmypy.json
|
|
127 |
|
128 |
# Pyre type checker
|
129 |
.pyre/
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
# Pyre type checker
|
129 |
.pyre/
|
130 |
+
|
131 |
+
# Model training logs, checkpoints, etc.
|
132 |
+
tokenized_encodings/
|
133 |
+
models/
|
app.py
CHANGED
@@ -3,23 +3,21 @@ from src.evaluate import evaluate_prompt, model_list
|
|
3 |
|
4 |
st.title("Toxic Tweets")
|
5 |
|
|
|
6 |
st.info("This NLP machine learning project aims to predict the toxicity level of input tweets using fine-tuning techniques on a pre-trained language model. The project utilizes Docker containers for efficient deployment and management, while Hugging Face Spaces and Transformers provide the necessary libraries and tools for building and training the model. The model is trained on a large dataset of labeled toxic tweets to enable it to classify new input tweets as toxic or non-toxic. This project can help improve online safety by automatically flagging potentially harmful content.")
|
7 |
|
8 |
# variables defined
|
9 |
sentiment_model_names = model_list()
|
10 |
section1, section2 = st.columns(2)
|
11 |
|
12 |
-
#
|
13 |
def predict(model_name, prompt):
|
14 |
-
output = evaluate_prompt(model_name, prompt)
|
15 |
-
label = output["label"]
|
16 |
-
score = output["score"]
|
17 |
with section2:
|
18 |
-
st.
|
19 |
-
st.write("Score:", score)
|
20 |
st.success("Completed!")
|
21 |
|
22 |
-
|
23 |
with section1:
|
24 |
st.header("Input")
|
25 |
prompt = st.text_area("Prompt", "Eat Up Every Moment")
|
|
|
3 |
|
4 |
st.title("Toxic Tweets")
|
5 |
|
6 |
+
# description of the project
|
7 |
st.info("This NLP machine learning project aims to predict the toxicity level of input tweets using fine-tuning techniques on a pre-trained language model. The project utilizes Docker containers for efficient deployment and management, while Hugging Face Spaces and Transformers provide the necessary libraries and tools for building and training the model. The model is trained on a large dataset of labeled toxic tweets to enable it to classify new input tweets as toxic or non-toxic. This project can help improve online safety by automatically flagging potentially harmful content.")
|
8 |
|
9 |
# variables defined
|
10 |
sentiment_model_names = model_list()
|
11 |
section1, section2 = st.columns(2)
|
12 |
|
13 |
+
# function to predict the output
|
14 |
def predict(model_name, prompt):
|
15 |
+
output = evaluate_prompt(model_name, prompt)
|
|
|
|
|
16 |
with section2:
|
17 |
+
st.table(output)
|
|
|
18 |
st.success("Completed!")
|
19 |
|
20 |
+
# main code
|
21 |
with section1:
|
22 |
st.header("Input")
|
23 |
prompt = st.text_area("Prompt", "Eat Up Every Moment")
|
models/fine_tuned/.gitkeep
DELETED
File without changes
|
src/evaluate.py
CHANGED
@@ -1,14 +1,47 @@
|
|
1 |
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
|
|
|
|
|
|
|
|
|
2 |
|
|
|
3 |
def evaluate_prompt(model_name, prompt):
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
|
8 |
-
return
|
9 |
|
|
|
10 |
def model_list():
|
11 |
sentiment_models = [
|
|
|
12 |
"distilbert-base-uncased-finetuned-sst-2-english",
|
13 |
"bert-base-uncased",
|
14 |
"albert-base-v2"]
|
|
|
1 |
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
|
2 |
+
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
|
7 |
+
# Public function called from app.py to evaluate the prompt
|
8 |
def evaluate_prompt(model_name, prompt):
|
9 |
+
# Check if the model is fine-tuned-toxic-tweets
|
10 |
+
if model_name == "fine-tuned-toxic-tweets":
|
11 |
+
return eval_fine_tuned_toxic_tweets(model_name, prompt)
|
12 |
+
else: # If not, use the pipeline function
|
13 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
14 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
15 |
+
classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
|
16 |
+
return classifier(prompt)
|
17 |
+
|
18 |
+
# Private function to evaluate the prompt using the fine-tuned-toxic-tweets model
|
19 |
+
def eval_fine_tuned_toxic_tweets(model_name, prompt):
|
20 |
+
# Load the model and tokenizer
|
21 |
+
model = DistilBertForSequenceClassification.from_pretrained("sergey-hovhannisyan/fine-tuned-toxic-tweets")
|
22 |
+
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
|
23 |
+
encoded_text = tokenizer(prompt, truncation=True, padding=True, return_tensors='pt')
|
24 |
+
|
25 |
+
with torch.no_grad():
|
26 |
+
output = model(**encoded_text)
|
27 |
+
|
28 |
+
# Get the labels and scores
|
29 |
+
labels = np.array(["toxic", "severe", "obscene", "threat", "insult", "identity hate"])
|
30 |
+
scores = torch.sigmoid(output.logits)*100
|
31 |
+
scores = scores.numpy()
|
32 |
+
|
33 |
+
# Sort the scores in descending order
|
34 |
+
sort_idx = np.flip(np.argsort(scores))
|
35 |
+
labels = labels[sort_idx]
|
36 |
+
scores = scores[0][sort_idx]
|
37 |
|
38 |
+
# Return the labels and scores as a dataframe
|
39 |
+
return pd.DataFrame({"Label": labels[0].tolist(), "Score": scores[0].tolist()}).set_index("Label")
|
40 |
|
41 |
+
# List of models to choose from
|
42 |
def model_list():
|
43 |
sentiment_models = [
|
44 |
+
"fine-tuned-toxic-tweets",
|
45 |
"distilbert-base-uncased-finetuned-sst-2-english",
|
46 |
"bert-base-uncased",
|
47 |
"albert-base-v2"]
|
src/finetune.ipynb
CHANGED
@@ -1,140 +1,218 @@
|
|
1 |
{
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
},
|
11 |
-
|
12 |
-
|
13 |
-
"execution_count": null,
|
14 |
-
"metadata": {},
|
15 |
-
"outputs": [],
|
16 |
-
"source": [
|
17 |
-
"import pandas as pd\n",
|
18 |
-
"import torch\n",
|
19 |
-
"from torch.utils.data import Dataset \n",
|
20 |
-
"from sklearn.model_selection import train_test_split\n",
|
21 |
-
"from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification\n",
|
22 |
-
"from transformers import Trainer, TrainingArguments"
|
23 |
-
]
|
24 |
-
},
|
25 |
-
{
|
26 |
-
"cell_type": "code",
|
27 |
-
"execution_count": null,
|
28 |
-
"metadata": {},
|
29 |
-
"outputs": [],
|
30 |
-
"source": [
|
31 |
-
"# Getting training dataset\n",
|
32 |
-
"df = pd.read_csv(\"../data/raw/train.csv\")\n",
|
33 |
-
"# Comments as list of strings for training texts\n",
|
34 |
-
"texts = df[\"comment_text\"].tolist()\n",
|
35 |
-
"# Labels extracted from dataframe as list of lists\n",
|
36 |
-
"labels = df[[\"toxic\",\"severe_toxic\",\"obscene\",\"threat\",\"insult\",\"identity_hate\"]].values.tolist()\n",
|
37 |
-
"# Training set split into training and validation sets\n",
|
38 |
-
"train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2)"
|
39 |
-
]
|
40 |
-
},
|
41 |
-
{
|
42 |
-
"cell_type": "code",
|
43 |
-
"execution_count": null,
|
44 |
-
"metadata": {},
|
45 |
-
"outputs": [],
|
46 |
-
"source": [
|
47 |
-
"# Tokenizing training and validation sets\n",
|
48 |
-
"class ToxicTweetsDataset(Dataset):\n",
|
49 |
-
" def __init__(self, encodings, labels):\n",
|
50 |
-
" self.encodings = encodings\n",
|
51 |
-
" self.labels = labels\n",
|
52 |
-
"\n",
|
53 |
-
" def __len__(self):\n",
|
54 |
-
" return len(self.encodings['input_ids'])\n",
|
55 |
-
"\n",
|
56 |
-
" def __getitem__(self, index):\n",
|
57 |
-
" input_ids = self.encodings['input_ids'][index]\n",
|
58 |
-
" attention_mask = self.encodings['attention_mask'][index]\n",
|
59 |
-
" labels = torch.tensor(self.labels[index], dtype=torch.float32)\n",
|
60 |
-
"\n",
|
61 |
-
" return {'input_ids': input_ids,\n",
|
62 |
-
" 'attention_mask': attention_mask,\n",
|
63 |
-
" 'labels': labels}\n"
|
64 |
-
]
|
65 |
-
},
|
66 |
-
{
|
67 |
-
"cell_type": "code",
|
68 |
-
"execution_count": null,
|
69 |
-
"metadata": {},
|
70 |
-
"outputs": [],
|
71 |
-
"source": [
|
72 |
-
"# Choosing base pretrained model\n",
|
73 |
-
"model_name = \"distilbert-base-uncased\"\n",
|
74 |
-
"tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)\n",
|
75 |
-
"\n",
|
76 |
-
"# Tokenizing and encoding training and validation sets\n",
|
77 |
-
"train_encodings = tokenizer(train_texts, truncation=True, padding=True)\n",
|
78 |
-
"val_encodings = tokenizer(val_texts, truncation=True, padding=True)\n",
|
79 |
-
"\n",
|
80 |
-
"# Creating datasets for training and validation\n",
|
81 |
-
"train_dataset = ToxicTweetsDataset(train_encodings, train_labels)\n",
|
82 |
-
"val_dataset = ToxicTweetsDataset(val_encodings, val_labels)\n",
|
83 |
-
"\n",
|
84 |
-
"# Setting training arguments\n",
|
85 |
-
"training_args = TrainingArguments(\n",
|
86 |
-
" output_dir=\"../models/fine_tuned\",\n",
|
87 |
-
" num_train_epochs=2, \n",
|
88 |
-
" per_device_train_batch_size=16, \n",
|
89 |
-
" per_device_eval_batch_size=64, \n",
|
90 |
-
" warmup_steps=500,\n",
|
91 |
-
" learning_rate=5e-5,\n",
|
92 |
-
" weight_decay=0.01,\n",
|
93 |
-
" logging_dir=\"./logs\",\n",
|
94 |
-
" logging_steps=10,\n",
|
95 |
-
" save_strategy=\"epoch\",\n",
|
96 |
-
" save_total_limit=1\n",
|
97 |
-
")\n",
|
98 |
-
"\n",
|
99 |
-
"# If GPU is available, use it, otherwise use CPU\n",
|
100 |
-
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
101 |
-
"\n",
|
102 |
-
"# Creating model\n",
|
103 |
-
"model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=6)\n",
|
104 |
-
"\n",
|
105 |
-
"# Creating trainer\n",
|
106 |
-
"trainer = Trainer(\n",
|
107 |
-
" model=model, \n",
|
108 |
-
" args=training_args,\n",
|
109 |
-
" train_dataset=train_dataset,\n",
|
110 |
-
" eval_dataset=val_dataset\n",
|
111 |
-
")\n",
|
112 |
-
"\n",
|
113 |
-
"# Training model\n",
|
114 |
-
"trainer.train()"
|
115 |
-
]
|
116 |
-
}
|
117 |
-
],
|
118 |
-
"metadata": {
|
119 |
-
"kernelspec": {
|
120 |
-
"display_name": "Python 3",
|
121 |
-
"language": "python",
|
122 |
-
"name": "python3"
|
123 |
-
},
|
124 |
-
"language_info": {
|
125 |
-
"codemirror_mode": {
|
126 |
-
"name": "ipython",
|
127 |
-
"version": 3
|
128 |
-
},
|
129 |
-
"file_extension": ".py",
|
130 |
-
"mimetype": "text/x-python",
|
131 |
-
"name": "python",
|
132 |
-
"nbconvert_exporter": "python",
|
133 |
-
"pygments_lexer": "ipython3",
|
134 |
-
"version": "3.11.2"
|
135 |
-
},
|
136 |
-
"orig_nbformat": 4
|
137 |
-
},
|
138 |
-
"nbformat": 4,
|
139 |
-
"nbformat_minor": 2
|
140 |
}
|
|
|
1 |
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "pgTOocy70xHB"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"## ___Toxic Tweets Fine-Tuned Pretrained Transformer with Multi Head Classifier___"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"id": "uGpWnNct0xHC"
|
17 |
+
},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"import pandas as pd\n",
|
21 |
+
"import numpy as np\n",
|
22 |
+
"import torch\n",
|
23 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
24 |
+
"from sklearn.model_selection import train_test_split\n",
|
25 |
+
"from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification\n",
|
26 |
+
"from transformers import Trainer, TrainingArguments"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": null,
|
32 |
+
"metadata": {
|
33 |
+
"id": "6JovL-hI8JKQ"
|
34 |
+
},
|
35 |
+
"outputs": [],
|
36 |
+
"source": [
|
37 |
+
"# If GPU is available, use it, otherwise use CPU\n",
|
38 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"execution_count": null,
|
44 |
+
"metadata": {
|
45 |
+
"id": "DBhZv3Sc0xHD"
|
46 |
+
},
|
47 |
+
"outputs": [],
|
48 |
+
"source": [
|
49 |
+
"# Getting training dataset\n",
|
50 |
+
"df = pd.read_csv(\"../data/raw/train.csv\")\n",
|
51 |
+
"# Comments as list of strings for training texts\n",
|
52 |
+
"texts = df[\"comment_text\"].tolist()\n",
|
53 |
+
"# Labels extracted from dataframe as list of lists\n",
|
54 |
+
"labels = df[[\"toxic\",\"severe_toxic\",\"obscene\",\"threat\",\"insult\",\"identity_hate\"]].values.tolist()\n",
|
55 |
+
"# Training set split into training and validation sets\n",
|
56 |
+
"train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2)"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": null,
|
62 |
+
"metadata": {
|
63 |
+
"id": "Q0OqpFmZ0xHD"
|
64 |
+
},
|
65 |
+
"outputs": [],
|
66 |
+
"source": [
|
67 |
+
"# Tokenizing training and validation sets\n",
|
68 |
+
"class ToxicTweetsDataset(Dataset):\n",
|
69 |
+
" # Initialize the class variables\n",
|
70 |
+
" def __init__(self, encodings, labels):\n",
|
71 |
+
" self.encodings = encodings\n",
|
72 |
+
" self.labels = labels\n",
|
73 |
+
" # Returns the length of the dataset\n",
|
74 |
+
" def __len__(self):\n",
|
75 |
+
" return len(self.encodings['input_ids'])\n",
|
76 |
+
" # Returns a dictionary of the tokenized text, attention mask, and labels\n",
|
77 |
+
" def __getitem__(self, index):\n",
|
78 |
+
" input_ids = self.encodings['input_ids'][index]\n",
|
79 |
+
" attention_mask = self.encodings['attention_mask'][index]\n",
|
80 |
+
" labels = torch.tensor(self.labels[index], dtype=torch.float32)\n",
|
81 |
+
" return {'input_ids': input_ids,\n",
|
82 |
+
" 'attention_mask': attention_mask,\n",
|
83 |
+
" 'labels': labels}"
|
84 |
+
]
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"cell_type": "code",
|
88 |
+
"execution_count": null,
|
89 |
+
"metadata": {
|
90 |
+
"id": "zRQlGBz3Dwbs"
|
91 |
+
},
|
92 |
+
"outputs": [],
|
93 |
+
"source": [
|
94 |
+
"# Choosing base pretrained model\n",
|
95 |
+
"model_name = \"distilbert-base-uncased\"\n",
|
96 |
+
"tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": null,
|
102 |
+
"metadata": {
|
103 |
+
"id": "AEtsC8-Y0xHE"
|
104 |
+
},
|
105 |
+
"outputs": [],
|
106 |
+
"source": [
|
107 |
+
"# Tokenizing and encoding training and validation sets\n",
|
108 |
+
"train_encodings = tokenizer.batch_encode_plus(train_texts, truncation=True, padding=True, return_tensors='pt')\n",
|
109 |
+
"val_encodings = tokenizer.batch_encode_plus(val_texts, truncation=True, padding=True, return_tensors='pt')"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "code",
|
114 |
+
"execution_count": null,
|
115 |
+
"metadata": {
|
116 |
+
"id": "bjQ_ZuJxLnb8"
|
117 |
+
},
|
118 |
+
"outputs": [],
|
119 |
+
"source": [
|
120 |
+
"# Saving encoded and tokenized data to files\n",
|
121 |
+
"torch.save(train_encodings, '../data/tokenized_encodings/train_encodings.pt')\n",
|
122 |
+
"torch.save(val_encodings, '../data/tokenized_encodings/val_encodings.pt')"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"cell_type": "code",
|
127 |
+
"execution_count": null,
|
128 |
+
"metadata": {
|
129 |
+
"id": "lqmHgIXwNMJQ"
|
130 |
+
},
|
131 |
+
"outputs": [],
|
132 |
+
"source": [
|
133 |
+
"# Creating training and validation datasets\n",
|
134 |
+
"train_encodings = torch.load('../data/tokenized_encodings/train_encodings.pt').to(device)\n",
|
135 |
+
"val_encodings = torch.load('../data/tokenized_encodings/val_encodings.pt').to(device)"
|
136 |
+
]
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"cell_type": "code",
|
140 |
+
"execution_count": null,
|
141 |
+
"metadata": {
|
142 |
+
"id": "FmEatZLqDt6p"
|
143 |
+
},
|
144 |
+
"outputs": [],
|
145 |
+
"source": [
|
146 |
+
"# Creating datasets for training and validation\n",
|
147 |
+
"train_dataset = ToxicTweetsDataset(train_encodings, train_labels)\n",
|
148 |
+
"val_dataset = ToxicTweetsDataset(val_encodings, val_labels)\n",
|
149 |
+
"\n",
|
150 |
+
"# Creating model\n",
|
151 |
+
"model = DistilBertForSequenceClassification.from_pretrained(model_name, problem_type=\"multi_label_classification\", num_labels=6)\n",
|
152 |
+
"model.to(device)"
|
153 |
+
]
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"cell_type": "code",
|
157 |
+
"execution_count": null,
|
158 |
+
"metadata": {
|
159 |
+
"id": "8CIjI_Q75haw"
|
160 |
+
},
|
161 |
+
"outputs": [],
|
162 |
+
"source": [
|
163 |
+
"# Setting training arguments\n",
|
164 |
+
"training_args = TrainingArguments(\n",
|
165 |
+
" output_dir=\"../models/fine_tuned\",\n",
|
166 |
+
" num_train_epochs=2, \n",
|
167 |
+
" per_device_train_batch_size=16, \n",
|
168 |
+
" per_device_eval_batch_size=32, \n",
|
169 |
+
" warmup_steps=500,\n",
|
170 |
+
" learning_rate=5e-5,\n",
|
171 |
+
" weight_decay=0.01,\n",
|
172 |
+
" logging_dir=\"./logs\",\n",
|
173 |
+
" logging_steps=10,\n",
|
174 |
+
" save_strategy=\"epoch\",\n",
|
175 |
+
" save_total_limit=1,\n",
|
176 |
+
")\n",
|
177 |
+
"\n",
|
178 |
+
"# Creating trainer\n",
|
179 |
+
"trainer = Trainer(\n",
|
180 |
+
" model=model,\n",
|
181 |
+
" args=training_args,\n",
|
182 |
+
" train_dataset=train_dataset,\n",
|
183 |
+
" eval_dataset=val_dataset,\n",
|
184 |
+
")\n",
|
185 |
+
"\n",
|
186 |
+
"# Training model\n",
|
187 |
+
"trainer.train()"
|
188 |
+
]
|
189 |
+
}
|
190 |
+
],
|
191 |
+
"metadata": {
|
192 |
+
"accelerator": "GPU",
|
193 |
+
"colab": {
|
194 |
+
"provenance": []
|
195 |
+
},
|
196 |
+
"gpuClass": "standard",
|
197 |
+
"kernelspec": {
|
198 |
+
"display_name": "Python 3",
|
199 |
+
"language": "python",
|
200 |
+
"name": "python3"
|
201 |
+
},
|
202 |
+
"language_info": {
|
203 |
+
"codemirror_mode": {
|
204 |
+
"name": "ipython",
|
205 |
+
"version": 3
|
206 |
+
},
|
207 |
+
"file_extension": ".py",
|
208 |
+
"mimetype": "text/x-python",
|
209 |
+
"name": "python",
|
210 |
+
"nbconvert_exporter": "python",
|
211 |
+
"pygments_lexer": "ipython3",
|
212 |
+
"version": "3.11.2"
|
213 |
+
},
|
214 |
+
"orig_nbformat": 4
|
215 |
},
|
216 |
+
"nbformat": 4,
|
217 |
+
"nbformat_minor": 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
}
|