{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## NUBIA: A SoTA evaluation metric for text generation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import and initialize the Nubia class from nubia.py (wherever you cloned the repo)\n", "See our [colab notebook](https://colab.research.google.com/drive/1_K8pOB8fRRnkBPwlcmvUNHgCr4ur8rFg) as an example\n", "\n", "Note: The first time you initialize the class it will download the pretrained models from the S3 bucket, this could take a while depending on your internet connection." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from nubia import Nubia" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loading archive file pretrained/roBERTa_STS\n", "| [input] dictionary: 50265 types\n", "loading archive file pretrained/roBERTa_MNLI\n", "| dictionary: 50264 types\n" ] } ], "source": [ "nubia = Nubia()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### You're now ready to start evaluating! \n", "\n", "`nubia.score` takes 7 parameters: `(ref, hyp, verbose=False, get_features=False, six_dim=False, aggregator=\"agg_one\")`\n", "\n", "`ref` and `hyp` are the strings nubia will compare. \n", "\n", "Setting `get_features` to `True` will return a dictionary with additional features (semantic relation, contradiction, irrelevancy, logical agreement, and grammaticality) aside from the nubia score. `Verbose=True` prints all the features.\n", "\n", "`six_dim = True` will use a six dimensional \n", "\n", "`aggregator` is set to `agg_one` by default, but you may choose to try `agg_two` which is" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "0.9905961979783401" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nubia.score(\"The dinner was delicious.\", \"It was a tasty dinner.\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.601254306957781" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nubia.score(\"The dinner was delicious.\", \"The dinner did not taste good.\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Semantic relation: 1.4594818353652954/5.0\n", "Percent chance of contradiction: 99.90345239639282%\n", "Percent chance of irrelevancy or new information: 0.06429857457987964%\n", "Percent chance of logical agreement: 0.03225349937565625%\n", "\n", "NUBIA score: 0.601254306957781/1.0\n" ] }, { "data": { "text/plain": [ "{'nubia_score': 0.601254306957781,\n", " 'features': {'semantic_relation': 1.4594818353652954,\n", " 'contradiction': 99.90345239639282,\n", " 'irrelevancy': 0.06429857457987964,\n", " 'logical_agreement': 0.03225349937565625,\n", " 'grammar_ref': 5.1724853515625,\n", " 'grammar_hyp': 4.905452728271484}}" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nubia.score(\"The dinner was delicious.\", \"The dinner did not taste good.\", verbose=True, get_features=True)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", 
"mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }