Ryan Kim commited on
Commit
548d450
β€’
1 Parent(s): fda5a48

final push

Browse files
Files changed (3) hide show
  1. README.md +35 -1
  2. src/main.py +1 -6
  3. src/val.ipynb +1 -0
README.md CHANGED
@@ -13,6 +13,8 @@ pinned: false
13
 
14
  ## Milestone 4
15
 
 
 
16
  ### **Code Breakdown**
17
 
18
  The USPTO application is divided into several directories. Overall, the important files are present in the application as such:
@@ -24,9 +26,10 @@ The USPTO application is divided into several directories. Overall, the importan
24
  - src/
25
  - main.py
26
  - train.ipynb
 
27
  ````
28
 
29
- Both `train.json` and `val.json` contain the original USPTO data, sized down to contain only the relevant data from each recorded patent and split between training and validation data. The validation data `val.json` is used in the online USPTO application as a set of pre-set patents that a user can select when using the USPTO patent prediction function.
30
 
31
  The primary code back-end is stored in `main.py`, which runs the application on the HuggingFace space UI. The application uses **Streamlit** to render UI elements on the screen. All models run off of Transformers and Tokenizers from **HuggingFace**.
32
 
@@ -348,6 +351,37 @@ print("=== TRAINING CLAIMS ===")
348
  Train(train_claims_loader,upsto_claims_model_path, num_train_epochs=10)
349
  ````
350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  ---
352
 
353
  ## Milestone 3
13
 
14
  ## Milestone 4
15
 
16
+ [![Project Video](https://img.youtube.com/vi/csSGBwIE7nk/0.jpg)](https://youtu.be/csSGBwIE7nk "Project Video")
17
+
18
  ### **Code Breakdown**
19
 
20
  The USPTO application is divided into several directories. Overall, the important files are present in the application as such:
26
  - src/
27
  - main.py
28
  - train.ipynb
29
+ - val.ipynb
30
  ````
31
 
32
+ Both `train.json` and `val.json` contain the original USPTO data, sized down to contain only the relevant data from each recorded patent and split between training and validation data. The validation data `val.json` is used in the online USPTO application as a set of pre-set patents that a user can select when using the USPTO patent prediction function. That, and the `val.ipynb` file was used to validate the model's accuracy.
33
 
34
  The primary code back-end is stored in `main.py`, which runs the application on the HuggingFace space UI. The application uses **Streamlit** to render UI elements on the screen. All models run off of Transformers and Tokenizers from **HuggingFace**.
35
 
351
  Train(train_claims_loader,upsto_claims_model_path, num_train_epochs=10)
352
  ````
353
 
354
+ ### Evaluating the Models
355
+
356
+ #### Sentiment Analysis
357
+
358
+ There isn't an effective way to validate the sentiment analysis models, as they are publicly available models and it is unknown what data they were explicitly trained on. Therefore, evaluation will rely on anecdotal testing.
359
+
360
+ The sentiment models that appear to work the best are the [cardiffnlp/twitter-roberta-base-sentiment](https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment) and - [siebert/sentiment-roberta-large-english](https://huggingface.co/siebert/sentiment-roberta-large-english) models, with few caveats. These two models generally perform very well at detecting sentiment in mid to long expressions. However the [siebert/sentiment-roberta-large-english](https://huggingface.co/siebert/sentiment-roberta-large-english) model tends to suffer when expressions are shorter and less complex in lexicon. Even the [cardiffnlp/twitter-roberta-base-sentiment](https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment) model suffers from time to time if not enough context has been provided.
361
+
362
+ The model that performed the worst is the [finiteautomata/beto-sentiment-analysis](https://huggingface.co/finiteautomata/beto-sentiment-analysis). This model seems to have the worst time trying to interpret meaning from sentences, even with strongly worded language such as "hate". For example, the expression _"I hate you" returns a **NEUTRAL** response with 99.6% confidence, which differs from the [cardiffnlp/twitter-roberta-base-sentiment](https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment) and - [siebert/sentiment-roberta-large-english](https://huggingface.co/siebert/sentiment-roberta-large-english) models (**NEGATIVE**: ~96.5% - ~99.9% accuracy respectively). It appears that the [finiteautomata/beto-sentiment-analysis](https://huggingface.co/finiteautomata/beto-sentiment-analysis) gets confused when not enough context is provided. The expression "I hate you because you hurt my family" manages to return a **NEGATIVE** label, but with a mere 87.7% confidence.
363
+
364
+ The unique model is the [bhadresh-savani/distilbert-base-uncased-emotion](https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion) model, which instead gives 6 general emotions as opposed to a binary **NEGATIVE** or **POSITIVE** rating:
365
+
366
+ * Sadness
367
+ * Joy
368
+ * Love
369
+ * Anger
370
+ * Fear
371
+ * Surprise
372
+
373
+ This one offers more nuance to each evaluation, but while the general distinction between "positive" and "negative" emotions is often accurate, the specific emotion applied to the statement is sometimes confusing. For example, the expression "I love you because you saved my family" is labeled as **JOY** instead of **LOVE**.
374
+
375
+ Overall, all models are able to perform to some level of success, but the context of each input needs to be made clear in the wording for the models to produce accurate responses. These models seem to suffer when not enough context is provided, sometimes blatantly giving incorrect responses to situations that are too simple.
376
+
377
+ #### Patent Acceptance Prediction
378
+
379
+ With access to labeled validation data, the fine-tuned models can be ranked in accuracy. Sample code that does so is provided in `src/val.ipynb` and evaluates 1000 random data samples out of 4000+ samples inside of `src/val.ipynb`.
380
+
381
+ Overall, both the fine-tuned Abstract model and the Claims model seem to perform very similarly, producing accuracy rates of 72.89% and 72.8% accuracy respectively out of 1000 random samples. The aggregated softmax labeling, assuming equal weighting between the two models, produces 76.2% accuracy. Depending on who you ask, this performance can be discouraging or encouraging.
382
+
383
+ The 76% accuracy rating being higher than the individual fine-tuned models is interesting, as it implies that the accuracy between the two models is not consistent; there are times when one of the models might outperform the other in certain situations. If the accuracy remained the same across all 1000 random samples, then the aggregated prediction accuracy rate would be within the same range as the individual fine-tuned models.
384
+
385
  ---
386
 
387
  ## Milestone 3
src/main.py CHANGED
@@ -33,8 +33,6 @@ class ModelImplementation(object):
33
  self.tokenizer = tokenizer_func.from_pretrained(self.tokenizer_model_name)
34
  self.classifier = pipeline_func(model=self.model, tokenizer=self.tokenizer, padding=True, truncation=True, **classifier_args)
35
  self.parser = parser_func
36
-
37
- self.history = []
38
 
39
  def predict(self, val):
40
  result = self.classifier(val)
@@ -97,10 +95,7 @@ def emotion_model_change():
97
  classifier_args={ "task" : "sentiment-analysis" },
98
  placeholders=["@AmericanAir just landed - 3hours Late Flight - and now we need to wait TWENTY MORE MINUTES for a gate! I have patience but none for incompetence."]
99
  )
100
-
101
- if "page" not in st.session_state:
102
- st.session_state.page = "home"
103
-
104
  if "emotion_model_name" not in st.session_state:
105
  st.session_state.emotion_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
106
  emotion_model_change()
33
  self.tokenizer = tokenizer_func.from_pretrained(self.tokenizer_model_name)
34
  self.classifier = pipeline_func(model=self.model, tokenizer=self.tokenizer, padding=True, truncation=True, **classifier_args)
35
  self.parser = parser_func
 
 
36
 
37
  def predict(self, val):
38
  result = self.classifier(val)
95
  classifier_args={ "task" : "sentiment-analysis" },
96
  placeholders=["@AmericanAir just landed - 3hours Late Flight - and now we need to wait TWENTY MORE MINUTES for a gate! I have patience but none for incompetence."]
97
  )
98
+
 
 
 
99
  if "emotion_model_name" not in st.session_state:
100
  st.session_state.emotion_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
101
  emotion_model_change()
src/val.ipynb ADDED
@@ -0,0 +1 @@
 
1
+ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyPbIO5QK/V8keB7h6h+8Ju2"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":22,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ePuwhQ7QyzUW","executionInfo":{"status":"ok","timestamp":1682571700367,"user_tz":240,"elapsed":29378,"user":{"displayName":"Ryan Kim","userId":"18356277368138721144"}},"outputId":"9c939d4a-7622-4c48-ba58-b83162400692"},"outputs":[{"output_type":"stream","name":"stdout","text":["Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Requirement already satisfied: datasets in /usr/local/lib/python3.9/dist-packages (2.11.0)\n","Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from datasets) (23.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from datasets) (6.0)\n","Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.9/dist-packages (from datasets) (4.65.0)\n","Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from datasets) (1.5.3)\n","Requirement already satisfied: xxhash in /usr/local/lib/python3.9/dist-packages (from datasets) (3.2.0)\n","Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.9/dist-packages (from datasets) (0.18.0)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from datasets) (1.22.4)\n","Requirement already satisfied: huggingface-hub<1.0.0,>=0.11.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (0.14.1)\n","Requirement already satisfied: multiprocess in /usr/local/lib/python3.9/dist-packages (from datasets) (0.70.14)\n","Requirement already satisfied: dill<0.3.7,>=0.3.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (0.3.6)\n","Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (9.0.0)\n","Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.9/dist-packages (from datasets) (2023.4.0)\n","Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (2.27.1)\n","Requirement already satisfied: aiohttp in /usr/local/lib/python3.9/dist-packages (from datasets) (3.8.4)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (23.1.0)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (1.3.3)\n","Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (2.0.12)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (6.0.4)\n","Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (4.0.2)\n","Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (1.3.1)\n","Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (1.9.2)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0.0,>=0.11.0->datasets) (3.12.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0.0,>=0.11.0->datasets) (4.5.0)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets) (2022.12.7)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets) (3.4)\n","Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets) (2.8.2)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets) (2022.7.1)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Requirement already satisfied: streamlit in /usr/local/lib/python3.9/dist-packages (1.21.0)\n","Requirement already satisfied: packaging>=14.1 in /usr/local/lib/python3.9/dist-packages (from streamlit) (23.1)\n","Requirement already satisfied: toml in /usr/local/lib/python3.9/dist-packages (from streamlit) (0.10.2)\n","Requirement already satisfied: tzlocal>=1.1 in /usr/local/lib/python3.9/dist-packages (from streamlit) (4.3)\n","Requirement already satisfied: protobuf<4,>=3.12 in /usr/local/lib/python3.9/dist-packages (from streamlit) (3.20.3)\n","Requirement already satisfied: importlib-metadata>=1.4 in /usr/local/lib/python3.9/dist-packages (from streamlit) (6.6.0)\n","Requirement already satisfied: rich>=10.11.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (13.3.4)\n","Requirement already satisfied: python-dateutil in /usr/local/lib/python3.9/dist-packages (from streamlit) (2.8.2)\n","Requirement already satisfied: pympler>=0.9 in /usr/local/lib/python3.9/dist-packages (from streamlit) (1.0.1)\n","Requirement already satisfied: pandas<2,>=0.25 in /usr/local/lib/python3.9/dist-packages (from streamlit) (1.5.3)\n","Requirement already satisfied: typing-extensions>=3.10.0.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (4.5.0)\n","Requirement already satisfied: validators>=0.2 in /usr/local/lib/python3.9/dist-packages (from streamlit) (0.20.0)\n","Requirement already satisfied: blinker>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (1.6.2)\n","Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (8.4.0)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.9/dist-packages (from streamlit) (1.22.4)\n","Requirement already satisfied: tornado>=6.0.3 in /usr/local/lib/python3.9/dist-packages (from streamlit) (6.2)\n","Requirement already satisfied: watchdog in /usr/local/lib/python3.9/dist-packages (from streamlit) (3.0.0)\n","Requirement already satisfied: gitpython!=3.1.19 in /usr/local/lib/python3.9/dist-packages (from streamlit) (3.1.31)\n","Requirement already satisfied: pydeck>=0.1.dev5 in /usr/local/lib/python3.9/dist-packages (from streamlit) (0.8.1b0)\n","Requirement already satisfied: requests>=2.4 in /usr/local/lib/python3.9/dist-packages (from streamlit) (2.27.1)\n","Requirement already satisfied: cachetools>=4.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (5.3.0)\n","Requirement already satisfied: altair<5,>=3.2.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (4.2.2)\n","Requirement already satisfied: pyarrow>=4.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (9.0.0)\n","Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.9/dist-packages (from streamlit) (8.1.3)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from altair<5,>=3.2.0->streamlit) (3.1.2)\n","Requirement already satisfied: toolz in /usr/local/lib/python3.9/dist-packages (from altair<5,>=3.2.0->streamlit) (0.12.0)\n","Requirement already satisfied: entrypoints in /usr/local/lib/python3.9/dist-packages (from altair<5,>=3.2.0->streamlit) (0.4)\n","Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.9/dist-packages (from altair<5,>=3.2.0->streamlit) (4.3.3)\n","Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.9/dist-packages (from gitpython!=3.1.19->streamlit) (4.0.10)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=1.4->streamlit) (3.15.0)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas<2,>=0.25->streamlit) (2022.7.1)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/dist-packages (from python-dateutil->streamlit) (1.16.0)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/dist-packages (from requests>=2.4->streamlit) (2.0.12)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests>=2.4->streamlit) (2022.12.7)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests>=2.4->streamlit) (1.26.15)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests>=2.4->streamlit) (3.4)\n","Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.9/dist-packages (from rich>=10.11.0->streamlit) (2.14.0)\n","Requirement already satisfied: markdown-it-py<3.0.0,>=2.2.0 in /usr/local/lib/python3.9/dist-packages (from rich>=10.11.0->streamlit) (2.2.0)\n","Requirement already satisfied: pytz-deprecation-shim in /usr/local/lib/python3.9/dist-packages (from tzlocal>=1.1->streamlit) (0.1.0.post0)\n","Requirement already satisfied: decorator>=3.4.0 in /usr/local/lib/python3.9/dist-packages (from validators>=0.2->streamlit) (4.4.2)\n","Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.9/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.19->streamlit) (5.0.0)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/dist-packages (from jinja2->altair<5,>=3.2.0->streamlit) (2.1.2)\n","Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.9/dist-packages (from jsonschema>=3.0->altair<5,>=3.2.0->streamlit) (23.1.0)\n","Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.9/dist-packages (from jsonschema>=3.0->altair<5,>=3.2.0->streamlit) (0.19.3)\n","Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.9/dist-packages (from markdown-it-py<3.0.0,>=2.2.0->rich>=10.11.0->streamlit) (0.1.2)\n","Requirement already satisfied: tzdata in /usr/local/lib/python3.9/dist-packages (from pytz-deprecation-shim->tzlocal>=1.1->streamlit) (2023.3)\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Requirement already satisfied: transformers in /usr/local/lib/python3.9/dist-packages (4.28.1)\n","Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.9/dist-packages (from transformers) (0.13.3)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from transformers) (3.12.0)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from transformers) (23.1)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.9/dist-packages (from transformers) (2022.10.31)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from transformers) (1.22.4)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.9/dist-packages (from transformers) (0.14.1)\n","Requirement already satisfied: requests in /usr/local/lib/python3.9/dist-packages (from transformers) (2.27.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.9/dist-packages (from transformers) (4.65.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.5.0)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (2023.4.0)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (1.26.15)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (2.0.12)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (2022.12.7)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (3.4)\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Requirement already satisfied: tqdm in /usr/local/lib/python3.9/dist-packages (4.65.0)\n"]}],"source":["!pip install datasets\n","!pip install streamlit\n","!pip install transformers\n","!pip install tqdm"]},{"cell_type":"code","source":["from datasets import load_dataset\n","import pandas as pd\n","import numpy as np\n","import os\n","import json\n","import torch\n","import sys\n","from tqdm import tqdm\n","\n","import streamlit as st\n","from transformers import TextClassificationPipeline, pipeline\n","from transformers import AutoTokenizer, AutoModelForSequenceClassification, DistilBertTokenizerFast, DistilBertForSequenceClassification"],"metadata":{"id":"xqhKMsNVzBtY","executionInfo":{"status":"ok","timestamp":1682571793784,"user_tz":240,"elapsed":3,"user":{"displayName":"Ryan Kim","userId":"18356277368138721144"}}},"execution_count":27,"outputs":[]},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/gdrive')"],"metadata":{"id":"4E_xZUUwzGJm","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1682570070672,"user_tz":240,"elapsed":23530,"user":{"displayName":"Ryan Kim","userId":"18356277368138721144"}},"outputId":"a6fbb01a-caeb-4dc5-bef1-837c5dce202f"},"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/gdrive\n"]}]},{"cell_type":"code","source":["abstract_model = TextClassificationPipeline(\n"," model = DistilBertForSequenceClassification.from_pretrained('rk2546/uspto-patents-abstracts'),\n"," tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased'),\n"," padding = True, \n"," truncation = True,\n"," return_all_scores = True\n",")\n","\n","claim_model = TextClassificationPipeline(\n"," model = DistilBertForSequenceClassification.from_pretrained('rk2546/uspto-patents-claims'),\n"," tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased'),\n"," padding = True, \n"," truncation = True,\n"," return_all_scores = True\n",")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Mj3hQGRU90bA","executionInfo":{"status":"ok","timestamp":1682573368942,"user_tz":240,"elapsed":7417,"user":{"displayName":"Ryan Kim","userId":"18356277368138721144"}},"outputId":"05cc8f93-1c72-4880-ae76-8d132d500c5f"},"execution_count":39,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.9/dist-packages/transformers/pipelines/text_classification.py:104: UserWarning: `return_all_scores` is now deprecated, if want a similar funcionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n"," warnings.warn(\n"]}]},{"cell_type":"code","source":["path_to_valData = \"./gdrive/MyDrive/AI [Spring 2023]/cs-gy-6613-project-rk2546/val.json\"\n","f = open(path_to_valData)\n","valData = json.load(f)\n","f.close()"],"metadata":{"id":"0oimA5tO9c1G","executionInfo":{"status":"ok","timestamp":1682570188049,"user_tz":240,"elapsed":1507,"user":{"displayName":"Ryan Kim","userId":"18356277368138721144"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["# We track the successes of abstracts, claims, and both combined\n","abstract_successes = 0\n","claim_successes = 0\n","aggregate_successes = 0\n","total_num = len(valData['labels'])\n","\n","# By default, we weigh the claims more highly than abstracts\n","claim_weight = 0.5\n","abstract_weight = 0.5\n","\n","# To randomize the data, we generate random indices \n","index_perms = np.random.permutation(total_num)\n","labels = []\n","abstracts = []\n","claims = []\n","# We generate up to 500 samples to validate against\n","new_total_num = min(1000,len(index_perms))\n","for i in range(new_total_num):\n"," labels.append(valData['labels'][index_perms[i]])\n"," abstracts.append(valData['abstracts'][index_perms[i]])\n"," claims.append(valData['claims'][index_perms[i]])\n","\n","# Now we validate\n","for i in tqdm(range(new_total_num)):\n"," label = labels[i]\n"," abstract = abstracts[i]\n"," claim = claims[i]\n","\n"," abstract_response = abstract_model(abstract)[0]\n"," claim_response = claim_model(claim)[0]\n"," aggregate_response = [\n"," {'label':'REJECTED','score':abstract_response[0]['score']*abstract_weight + claim_response[0]['score']*claim_weight},\n"," {'label':'ACCEPTED','score':abstract_response[1]['score']*abstract_weight + claim_response[1]['score']*claim_weight}\n"," ]\n","\n"," abstract_sorted = sorted(abstract_response, key=lambda d: d['score'], reverse=True) \n"," claim_sorted = sorted(claim_response, key=lambda d: d['score'], reverse=True)\n"," aggregate_sorted = sorted(aggregate_response, key=lambda d: d['score'], reverse=True) \n","\n"," if abstract_sorted[0]['label'] == 'LABEL_1' and label == 1:\n"," abstract_successes += 1\n"," elif abstract_sorted[0]['label'] == 'LABEL_0' and label == 0:\n"," abstract_successes += 1\n"," \n"," if claim_sorted[0]['label'] == 'LABEL_1' and label == 1:\n"," claim_successes += 1\n"," elif claim_sorted[0]['label'] == 'LABEL_0' and label == 0:\n"," claim_successes += 1\n"," \n"," if aggregate_sorted[0]['label'] == 'ACCEPTED' and label == 1:\n"," aggregate_successes += 1\n"," elif aggregate_sorted[0]['label'] == 'REJECTED' and label == 0:\n"," aggregate_successes += 1\n","\n"," # At 10% intervals, we print the current results\n"," if i > 0 and i % (new_total_num * 0.1) == 0:\n"," print(f\"\\nAbs: {abstract_successes}/{i} | Cl: {claim_successes}/{i} | Agg: {aggregate_successes}/{i}\")\n","\n","# Calculate final accuracy\n","abstract_accuracy = abstract_successes / new_total_num\n","claim_accuracy = claim_successes / new_total_num\n","aggregate_accuracy = aggregate_successes / new_total_num\n","\n","# Display accuracy\n","print(\"\\n\")\n","print(f\"Abstract Model Accuracy: {abstract_accuracy * 100}%\")\n","print(f\"Claim Model Accuracy: {claim_accuracy * 100}%\")\n","print(f\"Aggregated Model Accuracy: {aggregate_accuracy * 100}%\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"FLE-9qlw9qW7","executionInfo":{"status":"ok","timestamp":1682577092672,"user_tz":240,"elapsed":1356393,"user":{"displayName":"Ryan Kim","userId":"18356277368138721144"}},"outputId":"fe0fb5ed-b075-4e4d-c616-6dbac5148a75"},"execution_count":48,"outputs":[{"output_type":"stream","name":"stderr","text":[" 10%|β–ˆ | 101/1000 [02:25<22:03, 1.47s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 70/100 | Cl: 73/100 | Agg: 73/100\n"]},{"output_type":"stream","name":"stderr","text":[" 20%|β–ˆβ–ˆ | 201/1000 [04:38<21:25, 1.61s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 148/200 | Cl: 155/200 | Agg: 155/200\n"]},{"output_type":"stream","name":"stderr","text":[" 30%|β–ˆβ–ˆβ–ˆ | 301/1000 [06:53<13:59, 1.20s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 220/300 | Cl: 224/300 | Agg: 234/300\n"]},{"output_type":"stream","name":"stderr","text":[" 40%|β–ˆβ–ˆβ–ˆβ–ˆ | 401/1000 [09:08<11:16, 1.13s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 295/400 | Cl: 293/400 | Agg: 308/400\n"]},{"output_type":"stream","name":"stderr","text":[" 50%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 501/1000 [11:24<10:34, 1.27s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 362/500 | Cl: 365/500 | Agg: 383/500\n"]},{"output_type":"stream","name":"stderr","text":[" 60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 601/1000 [13:37<10:44, 1.61s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 443/600 | Cl: 440/600 | Agg: 462/600\n"]},{"output_type":"stream","name":"stderr","text":[" 70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 701/1000 [15:54<06:52, 1.38s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 523/700 | Cl: 517/700 | Agg: 546/700\n"]},{"output_type":"stream","name":"stderr","text":[" 80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 801/1000 [18:07<03:42, 1.12s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 601/800 | Cl: 591/800 | Agg: 626/800\n"]},{"output_type":"stream","name":"stderr","text":[" 90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 901/1000 [20:24<01:56, 1.18s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","Abs: 670/900 | Cl: 666/900 | Agg: 703/900\n"]},{"output_type":"stream","name":"stderr","text":["100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1000/1000 [22:36<00:00, 1.36s/it]"]},{"output_type":"stream","name":"stdout","text":["\n","\n","Abstract Model Accuracy: 72.89999999999999%\n","Claim Model Accuracy: 72.8%\n","Aggregated Model Accuracy: 76.2%\n"]},{"output_type":"stream","name":"stderr","text":["\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"Enwp7rw___5t"},"execution_count":null,"outputs":[]}]}