nicoleathy committed
Commit d071d21 · verified · 1 Parent(s): 387046f

Delete competition/Gemma-2-9b.ipynb

Files changed (1)
  1. competition/Gemma-2-9b.ipynb +0 -132
competition/Gemma-2-9b.ipynb DELETED
@@ -1,132 +0,0 @@
- {
-  "cells": [
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer\n",
-     "from datasets import Dataset\n",
-     "import pandas as pd\n",
-     "from sklearn.model_selection import train_test_split\n",
-     "\n",
-     "# Load the dataset\n",
-     "file_path = 'train_en.csv'\n",
-     "dataset = pd.read_csv(file_path)\n",
-     "\n",
-     "# Map labels to expected responses\n",
-     "label_mapping = {\n",
-     "    \"Yes\": 0,\n",
-     "    \"No\": 1,\n",
-     "    \"It doesn't matter\": 2,\n",
-     "    \"Unimportant\": 2, # Assuming \"unimportant\" is synonymous with \"It doesn't matter\"\n",
-     "    \"Incorrect questioning\": 3,\n",
-     "    \"Correct answers\": 4\n",
-     "}\n",
-     "\n",
-     "# Apply label mapping\n",
-     "dataset['label'] = dataset['label'].map(label_mapping)\n",
-     "\n",
-     "# Handle NaN values: Drop rows where label is NaN\n",
-     "dataset = dataset.dropna(subset=['label'])\n",
-     "\n",
-     "# Ensure labels are integers\n",
-     "dataset['label'] = dataset['label'].astype(int)\n",
-     "\n",
-     "# Split the dataset into training and validation sets\n",
-     "train_df, val_df = train_test_split(dataset, test_size=0.2, random_state=42)\n",
-     "\n",
-     "# Convert the dataframes to datasets\n",
-     "train_dataset = Dataset.from_pandas(train_df)\n",
-     "val_dataset = Dataset.from_pandas(val_df)\n",
-     "\n",
-     "# Load the tokenizer and model\n",
-     "model_name = \"google/gemma-2-9b\"\n",
-     "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
-     "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=5)\n",
-     "\n",
-     "# Tokenize the data\n",
-     "def tokenize_function(examples):\n",
-     "    return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128)\n",
-     "\n",
-     "train_dataset = train_dataset.map(tokenize_function, batched=True)\n",
-     "val_dataset = val_dataset.map(tokenize_function, batched=True)\n",
-     "\n",
-     "# Set the format for PyTorch\n",
-     "train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\n",
-     "val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\n",
-     "\n",
-     "# Define training arguments\n",
-     "training_args = TrainingArguments(\n",
-     "    output_dir='./results',\n",
-     "    evaluation_strategy='epoch',\n",
-     "    learning_rate=2e-5,\n",
-     "    per_device_train_batch_size=8,\n",
-     "    per_device_eval_batch_size=8,\n",
-     "    num_train_epochs=3,\n",
-     "    weight_decay=0.01,\n",
-     ")\n",
-     "\n",
-     "# Initialize the Trainer\n",
-     "trainer = Trainer(\n",
-     "    model=model,\n",
-     "    args=training_args,\n",
-     "    train_dataset=train_dataset,\n",
-     "    eval_dataset=val_dataset,\n",
-     ")\n",
-     "\n",
-     "# Train the model\n",
-     "trainer.train()\n",
-     "\n",
-     "# Save the model\n",
-     "model.save_pretrained('trained_gemma_model')\n",
-     "tokenizer.save_pretrained('trained_gemma_model')\n",
-     "\n",
-     "# Evaluate the model\n",
-     "trainer.evaluate()"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "# Load the trained model and tokenizer\n",
-     "model = AutoModelForSequenceClassification.from_pretrained('trained_gemma_model')\n",
-     "tokenizer = AutoTokenizer.from_pretrained('trained_gemma_model')\n",
-     "\n",
-     "# Function to make predictions\n",
-     "def predict(texts):\n",
-     "    inputs = tokenizer(texts, return_tensors=\"pt\", truncation=True, padding='max_length', max_length=128)\n",
-     "    outputs = model(**inputs)\n",
-     "    predictions = outputs.logits.argmax(dim=-1).tolist()\n",
-     "    return predictions\n",
-     "\n",
-     "# Apply the predictions to the dataset\n",
-     "dataset['predicted_label'] = predict(dataset['text'].tolist())\n",
-     "\n",
-     "# Map the predicted labels back to the response texts\n",
-     "reverse_label_mapping = {v: k for k, v in label_mapping.items()}\n",
-     "dataset['predicted_label'] = dataset['predicted_label'].map(reverse_label_mapping)\n",
-     "\n",
-     "# Save the results\n",
-     "dataset.to_csv('gemma-2-9b_predicted_results.csv', index=False)"
-    ]
-   }
-  ],
-  "metadata": {
-   "kernelspec": {
-    "display_name": "base",
-    "language": "python",
-    "name": "python3"
-   },
-   "language_info": {
-    "name": "python",
-    "version": "3.11.0"
-   }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 2
- }
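
A note on the first deleted cell: the label map collapses "Unimportant" into the same class as "It doesn't matter", so the mapped labels span exactly five classes (0-4), matching the num_labels=5 passed to from_pretrained. A quick sanity check along these lines (illustrative; the assert is not part of the deleted notebook):

# Illustrative sanity check (not from the deleted notebook): the mapped labels
# must cover exactly the num_labels classes the classification head is built with.
label_mapping = {
    "Yes": 0,
    "No": 1,
    "It doesn't matter": 2,
    "Unimportant": 2,  # collapsed into the same class as "It doesn't matter"
    "Incorrect questioning": 3,
    "Correct answers": 4,
}
num_labels = 5
assert set(label_mapping.values()) == set(range(num_labels))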
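
Also worth flagging in the second deleted cell: predict() tokenizes and scores the entire dataset in a single forward pass with gradients enabled, which is likely to exhaust memory for a 9B-parameter model. A minimal batched sketch of the same idea, assuming only torch and transformers plus the trained_gemma_model checkpoint the notebook saves (predict_batched and batch_size are illustrative names, not from the deleted code):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained('trained_gemma_model')
tokenizer = AutoTokenizer.from_pretrained('trained_gemma_model')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device).eval()

def predict_batched(texts, batch_size=32):
    # Score the texts in fixed-size batches instead of one giant forward pass.
    predictions = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(batch, return_tensors='pt', truncation=True,
                           padding='max_length', max_length=128).to(device)
        with torch.no_grad():  # inference only; no gradient buffers needed
            logits = model(**inputs).logits
        predictions.extend(logits.argmax(dim=-1).tolist())
    return predictions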