{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"dockerImageVersionId":30747,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"import pandas as pd\n\nsplits = {'train': 'data/train-00000-of-00001.parquet', \n 'validation_matched': 'data/validation_matched-00000-of-00001.parquet', \n 'validation_mismatched': 'data/validation_mismatched-00000-of-00001.parquet'}\n \ndf = pd.read_parquet(\"hf://datasets/nyu-mll/multi_nli/\" + splits[\"train\"])","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2024-07-20T13:04:43.497190Z","iopub.execute_input":"2024-07-20T13:04:43.497536Z","iopub.status.idle":"2024-07-20T13:04:51.222716Z","shell.execute_reply.started":"2024-07-20T13:04:43.497506Z","shell.execute_reply":"2024-07-20T13:04:51.221916Z"},"trusted":true},"execution_count":1,"outputs":[]},{"cell_type":"code","source":"df = df[['label', 'premise', 'hypothesis']].iloc[:13000]\ndf","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:04:52.891990Z","iopub.execute_input":"2024-07-20T13:04:52.892933Z","iopub.status.idle":"2024-07-20T13:04:53.013716Z","shell.execute_reply.started":"2024-07-20T13:04:52.892902Z","shell.execute_reply":"2024-07-20T13:04:53.012747Z"},"trusted":true},"execution_count":2,"outputs":[{"execution_count":2,"output_type":"execute_result","data":{"text/plain":" label premise \\\n0 1 Conceptually cream skimming has two basic dime... \n1 0 you know during the season and i guess at at y... \n2 0 One of our number will carry out your instruct... \n3 0 How do you know? All this is their information... \n4 1 yeah i tell you what though if you go price so... \n... ... ... \n12995 1 right you have to question you have to wonder ... \n12996 2 Reviewers may not be familiar with the charact... \n12997 1 yeah it was Twins was good too because when i... \n12998 0 The Jews are Neanderthals. \n12999 1 25--to get a copy of my book legally from my W... \n\n hypothesis \n0 Product and geography are what make cream skim... \n1 You lose the things to the following level if ... \n2 A member of my team will execute your orders w... \n3 This information belongs to them. \n4 The tennis shoes have a range of prices. \n... ... \n12995 I would not mind living on an island to find out. \n12996 Typically, reviewers are fully aware of an eme... \n12997 Twins was the best movie I saw last year. \n12998 Jewish people are like Neanderthals. \n12999 My book is free on my site. \n\n[13000 rows x 3 columns]","text/html":"
"},"metadata":{}}]},{"cell_type":"code","source":"import torch\nfrom torch.utils.data import Dataset, TensorDataset, DataLoader\nfrom torch.nn.utils.rnn import pad_sequence\nimport pickle\nimport os\nfrom transformers import BertTokenizer\nfrom sklearn.model_selection import train_test_split","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:04:53.472075Z","iopub.execute_input":"2024-07-20T13:04:53.472883Z","iopub.status.idle":"2024-07-20T13:04:57.918248Z","shell.execute_reply.started":"2024-07-20T13:04:53.472853Z","shell.execute_reply":"2024-07-20T13:04:57.917380Z"},"trusted":true},"execution_count":3,"outputs":[]},{"cell_type":"code","source":"class MNLIDataBert(Dataset):\n\n def __init__(self, train_df, val_df):\n\n self.train_df = train_df\n self.val_df = val_df\n\n self.base_path = '/content/'\n self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) # Using a pre-trained BERT tokenizer to encode sentences\n self.train_data = None\n self.val_data = None\n self.init_data()\n\n def init_data(self):\n self.train_data = self.load_data(self.train_df)\n self.val_data = self.load_data(self.val_df)\n\n def load_data(self, df):\n MAX_LEN = 512\n token_ids = []\n mask_ids = []\n seg_ids = []\n y = []\n\n premise_list = df['premise'].to_list()\n hypothesis_list = df['hypothesis'].to_list()\n label_list = df['label'].to_list()\n\n for (premise, hypothesis, label) in zip(premise_list, hypothesis_list, label_list):\n premise_id = self.tokenizer.encode(premise, add_special_tokens = False)\n hypothesis_id = self.tokenizer.encode(hypothesis, add_special_tokens = False)\n pair_token_ids = [self.tokenizer.cls_token_id] + premise_id + [self.tokenizer.sep_token_id] + hypothesis_id + [self.tokenizer.sep_token_id]\n premise_len = len(premise_id)\n hypothesis_len = len(hypothesis_id)\n\n segment_ids = torch.tensor([0] * (premise_len + 2) + [1] * (hypothesis_len + 1)) # premise and hypothesis \n attention_mask_ids = torch.tensor([1] * (premise_len + hypothesis_len + 3)) # mask padded values\n\n token_ids.append(torch.tensor(pair_token_ids))\n seg_ids.append(segment_ids)\n mask_ids.append(attention_mask_ids)\n y.append(label)\n \n token_ids = pad_sequence(token_ids, batch_first=True)\n mask_ids = pad_sequence(mask_ids, batch_first=True)\n seg_ids = pad_sequence(seg_ids, batch_first=True)\n y = torch.tensor(y)\n dataset = TensorDataset(token_ids, mask_ids, seg_ids, y)\n print(len(dataset))\n return dataset\n\n def get_data_loaders(self, batch_size=32, shuffle=True):\n train_loader = DataLoader(\n self.train_data,\n shuffle=shuffle,\n batch_size=batch_size\n )\n\n val_loader = DataLoader(\n self.val_data,\n shuffle=shuffle,\n batch_size=batch_size\n )\n\n return train_loader, val_loader","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:04:57.919963Z","iopub.execute_input":"2024-07-20T13:04:57.920822Z","iopub.status.idle":"2024-07-20T13:04:57.935843Z","shell.execute_reply.started":"2024-07-20T13:04:57.920788Z","shell.execute_reply":"2024-07-20T13:04:57.934836Z"},"trusted":true},"execution_count":4,"outputs":[]},{"cell_type":"code","source":"train_df, val_df = train_test_split(df, test_size=0.2, 
random_state=42)","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:04:57.937187Z","iopub.execute_input":"2024-07-20T13:04:57.937623Z","iopub.status.idle":"2024-07-20T13:04:57.953101Z","shell.execute_reply.started":"2024-07-20T13:04:57.937589Z","shell.execute_reply":"2024-07-20T13:04:57.952066Z"},"trusted":true},"execution_count":5,"outputs":[]},{"cell_type":"code","source":"val_df","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:04:57.955406Z","iopub.execute_input":"2024-07-20T13:04:57.955788Z","iopub.status.idle":"2024-07-20T13:04:57.967894Z","shell.execute_reply.started":"2024-07-20T13:04:57.955763Z","shell.execute_reply":"2024-07-20T13:04:57.966921Z"},"trusted":true},"execution_count":6,"outputs":[{"execution_count":6,"output_type":"execute_result","data":{"text/plain":" label premise \\\n3615 2 An ambitious plan for a hexagonally based chur... \n2536 1 for for city use and \n5397 0 The really valuable estate cannot be touched b... \n9982 0 isn't that the truth it's funny in fact it's i... \n1498 0 Most drivers will be able to point out the Bok... \n... ... ... \n11872 1 As the road rises, the rugged countryside beco... \n9264 0 The monastery rests in a fertile valley and is... \n7277 2 Since everyone who matters presumably knows al... \n3752 2 so what type of restaurant do you like \n6292 2 right that that's actually the part that that ... \n\n hypothesis \n3615 The complex plan of the hospital, came to noth... \n2536 Only the city can use it. \n5397 The death tax is unable to reach the most impo... \n9982 I love that music from my childhood has return... \n1498 The Bok House is transformed into a restaurant. \n... ... \n11872 The hillsides are full of ferns and trees. \n9264 In a fertile valley surrounded by plane and pi... \n7277 People who matter no nothing about who backs t... \n3752 You don't eat at restaurants at all, right? \n6292 I don't believe muslims are hated by Israelis. \n\n[2600 rows x 3 columns]","text/html":"
"},"metadata":{}}]},{"cell_type":"code","source":"mnli_dataset = MNLIDataBert(train_df, val_df)","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:04:57.969018Z","iopub.execute_input":"2024-07-20T13:04:57.969361Z","iopub.status.idle":"2024-07-20T13:05:17.499250Z","shell.execute_reply.started":"2024-07-20T13:04:57.969331Z","shell.execute_reply":"2024-07-20T13:05:17.498209Z"},"trusted":true},"execution_count":7,"outputs":[{"output_type":"display_data","data":{"text/plain":"tokenizer_config.json: 0%| | 0.00/48.0 [00:002}:{:0>2}:{:05.2f}\".format(int(hours),int(minutes),seconds))","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:05:21.782196Z","iopub.execute_input":"2024-07-20T13:05:21.783344Z","iopub.status.idle":"2024-07-20T13:05:21.797911Z","shell.execute_reply.started":"2024-07-20T13:05:21.783317Z","shell.execute_reply":"2024-07-20T13:05:21.797059Z"},"trusted":true},"execution_count":12,"outputs":[]},{"cell_type":"code","source":"train(model, train_loader, val_loader, optimizer)","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:05:21.799438Z","iopub.execute_input":"2024-07-20T13:05:21.800060Z","iopub.status.idle":"2024-07-20T13:21:51.983474Z","shell.execute_reply.started":"2024-07-20T13:05:21.800009Z","shell.execute_reply":"2024-07-20T13:21:51.982480Z"},"trusted":true},"execution_count":13,"outputs":[{"name":"stdout","text":"Epoch 1: train_loss: 0.8012 train_acc: 0.6405 | val_loss: 0.6349 val_acc: 0.7367\n00:08:12.58\nEpoch 2: train_loss: 0.4223 train_acc: 0.8425 | val_loss: 0.6711 val_acc: 0.7416\n00:08:17.60\n","output_type":"stream"}]},{"cell_type":"code","source":"import torch\nfrom transformers import BertTokenizer\nimport torch.nn.functional as F\n\nmodel.eval()\n\n# Load the tokenizer\ntokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)\n\n# Function to predict entailment for a single premise-hypothesis pair\ndef predict_entailment(premise, hypothesis):\n # Tokenize and encode the inputs\n premise_id = tokenizer.encode(premise, add_special_tokens=False)\n hypothesis_id = tokenizer.encode(hypothesis, add_special_tokens=False)\n pair_token_ids = [tokenizer.cls_token_id] + premise_id + [tokenizer.sep_token_id] + hypothesis_id + [tokenizer.sep_token_id]\n \n segment_ids = torch.tensor([0] * (len(premise_id) + 2) + [1] * (len(hypothesis_id) + 1)).unsqueeze(0) # Add batch dimension\n attention_mask_ids = torch.tensor([1] * (len(premise_id) + len(hypothesis_id) + 3)).unsqueeze(0) # Add batch dimension\n token_ids = torch.tensor(pair_token_ids).unsqueeze(0) # Add batch dimension\n \n # Move to device\n token_ids = token_ids.to(device)\n segment_ids = segment_ids.to(device)\n attention_mask_ids = attention_mask_ids.to(device)\n \n # Run the model\n with torch.no_grad():\n outputs = model(token_ids, token_type_ids=segment_ids, attention_mask=attention_mask_ids)\n logits = outputs.logits\n \n # Apply softmax to get probabilities\n probs = F.softmax(logits, dim=1)\n \n # Get the predicted label\n predicted_label = torch.argmax(probs, dim=1).item()\n \n return predicted_label, probs","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:21:51.984900Z","iopub.execute_input":"2024-07-20T13:21:51.985354Z","iopub.status.idle":"2024-07-20T13:21:52.134564Z","shell.execute_reply.started":"2024-07-20T13:21:51.985319Z","shell.execute_reply":"2024-07-20T13:21:52.133707Z"},"trusted":true},"execution_count":14,"outputs":[]},{"cell_type":"code","source":"label_map = {0: 'Entailment', 1: 'Neutral', 2: 
'Contradiction'}","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:21:52.136154Z","iopub.execute_input":"2024-07-20T13:21:52.136726Z","iopub.status.idle":"2024-07-20T13:21:52.141308Z","shell.execute_reply.started":"2024-07-20T13:21:52.136692Z","shell.execute_reply":"2024-07-20T13:21:52.140368Z"},"trusted":true},"execution_count":15,"outputs":[]},{"cell_type":"code","source":"# Example premises and hypotheses\npremises = [\n \"A man is playing a guitar.\",\n \"Laura likes to go to restaurants every weekend.\",\n \"Messi is a professional football player.\"\n]\n\nhypotheses = [\n \"A person is making music.\",\n \"Laura doesn't eat at restaurants at all.\",\n \"Akash is doing his homework.\"\n]","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:21:52.142479Z","iopub.execute_input":"2024-07-20T13:21:52.142762Z","iopub.status.idle":"2024-07-20T13:21:52.150447Z","shell.execute_reply.started":"2024-07-20T13:21:52.142738Z","shell.execute_reply":"2024-07-20T13:21:52.149469Z"},"trusted":true},"execution_count":16,"outputs":[]},{"cell_type":"code","source":"\n# Predict entailment for each pair\nfor premise, hypothesis in zip(premises, hypotheses):\n label, probs = predict_entailment(premise, hypothesis)\n print(f\"Premise: {premise}\")\n print(f\"Hypothesis: {hypothesis}\")\n print(f\"Predicted label: {label_map[label]}\")\n print(f\"Probabilities: {probs}\")\n print('-'*80)","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:21:52.151433Z","iopub.execute_input":"2024-07-20T13:21:52.151688Z","iopub.status.idle":"2024-07-20T13:21:52.340741Z","shell.execute_reply.started":"2024-07-20T13:21:52.151665Z","shell.execute_reply":"2024-07-20T13:21:52.339852Z"},"trusted":true},"execution_count":17,"outputs":[{"name":"stdout","text":"Premise: A man is playing a guitar.\nHypothesis: A person is making music.\nPredicted label: Entailment\nProbabilities: tensor([[0.9668, 0.0200, 0.0132]], device='cuda:0')\n--------------------------------------------------------------------------------\nPremise: Laura likes to go to restaurants every weekend.\nHypothesis: Laura doesn't eat at restaurants at all.\nPredicted label: Contradiction\nProbabilities: tensor([[0.0016, 0.0022, 0.9962]], device='cuda:0')\n--------------------------------------------------------------------------------\nPremise: Messi is a professional football player.\nHypothesis: Akash is doing his homework.\nPredicted label: Neutral\nProbabilities: tensor([[0.0153, 0.6406, 0.3441]], device='cuda:0')\n--------------------------------------------------------------------------------\n","output_type":"stream"}]},{"cell_type":"code","source":"model_path = \"./ema_task_model\"\ntokenizer_path = \"./ema_task_tokenizer\"\n\n# Save the model and tokenizer\nmodel.save_pretrained(model_path)\ntokenizer.save_pretrained(tokenizer_path)","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:21:52.344190Z","iopub.execute_input":"2024-07-20T13:21:52.344492Z","iopub.status.idle":"2024-07-20T13:21:53.357043Z","shell.execute_reply.started":"2024-07-20T13:21:52.344467Z","shell.execute_reply":"2024-07-20T13:21:53.356070Z"},"trusted":true},"execution_count":18,"outputs":[{"execution_count":18,"output_type":"execute_result","data":{"text/plain":"('./ema_task_tokenizer/tokenizer_config.json',\n './ema_task_tokenizer/special_tokens_map.json',\n './ema_task_tokenizer/vocab.txt',\n './ema_task_tokenizer/added_tokens.json')"},"metadata":{}}]},{"cell_type":"code","source":"!zip -r ema_task_model.zip ema_task_model\n!zip -r ema_task_tokenizer.zip 
ema_task_tokenizer","metadata":{"execution":{"iopub.status.busy":"2024-07-20T13:21:53.358178Z","iopub.execute_input":"2024-07-20T13:21:53.358464Z","iopub.status.idle":"2024-07-20T13:22:18.652819Z","shell.execute_reply.started":"2024-07-20T13:21:53.358439Z","shell.execute_reply":"2024-07-20T13:22:18.651602Z"},"trusted":true},"execution_count":19,"outputs":[{"name":"stdout","text":" adding: ema_task_model/ (stored 0%)\n adding: ema_task_model/config.json (deflated 51%)\n adding: ema_task_model/model.safetensors (deflated 7%)\n adding: ema_task_tokenizer/ (stored 0%)\n adding: ema_task_tokenizer/vocab.txt (deflated 53%)\n adding: ema_task_tokenizer/special_tokens_map.json (deflated 42%)\n adding: ema_task_tokenizer/tokenizer_config.json (deflated 75%)\n","output_type":"stream"}]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]}