{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import RobertaModel, RobertaTokenizer" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Define the model architecture\n", "class RobertaClass(torch.nn.Module):\n", " def __init__(self):\n", " super(RobertaClass, self).__init__()\n", " self.l1 = RobertaModel.from_pretrained(\"roberta-base\")\n", " self.pre_classifier = torch.nn.Linear(768, 768)\n", " self.dropout = torch.nn.Dropout(0.3)\n", " self.classifier = torch.nn.Linear(768, 5) # Assuming 5 classes\n", "\n", " def forward(self, input_ids, attention_mask, token_type_ids):\n", " output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n", " hidden_state = output_1[0]\n", " pooler = hidden_state[:, 0]\n", " pooler = self.pre_classifier(pooler)\n", " pooler = torch.nn.ReLU()(pooler)\n", " pooler = self.dropout(pooler)\n", " output = self.classifier(pooler)\n", " return output" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "# Instantiate the model\n", "model = RobertaClass()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the state dictionary\n", "model.load_state_dict(torch.load(r'C:\\Users\\Jash\\OneDrive\\Desktop\\finalmodel2\\pytorch_roberta_sentiment.bin', map_location=torch.device('cpu')))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Define your inference function to accept multiple inputs\n", "def predict(texts):\n", " tokenizer = RobertaTokenizer.from_pretrained('roberta-base')\n", " inputs = tokenizer.batch_encode_plus(\n", " texts,\n", " add_special_tokens=True,\n", " max_length=256,\n", " pad_to_max_length=True,\n", " return_token_type_ids=True,\n", " return_tensors='pt',\n", " truncation=True\n", " )\n", " with torch.no_grad():\n", " outputs = model(inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids'])\n", " _, predicted = torch.max(outputs, 1)\n", " return predicted.tolist()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\raju_\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\tokenization_utils_base.py:2619: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted Labels: [0, 1, 1]\n"
     ]
    }
   ],
   "source": [
    "# Test the model with multiple inputs\n",
    "texts_to_predict = [\"check us out\", \"hurry up only 10 pieces left\", \"Congrats, you have won a 50 million dollars, sign up to redeem\"]\n",
    "predicted_labels = predict(texts_to_predict)\n",
    "print(\"Predicted Labels:\", predicted_labels)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}