{ "cells": [ { "cell_type": "markdown", "id": "942fa22a-c776-4a44-bde9-75b7cb4202ba", "metadata": {}, "source": [ "## Outline\n", "\n", "1. We collect a dataset of examples in the form (user_question, answer_context, dialogue_history -> answer)\n", "2. We duplicate a small portion of the dataset and remove answer_context from the copies\n", "3. We augment answer_context with non-answer passages picked by a reasonably performing QA system, using variable ordering and a consistent number of answers\n", "4. We train the model for exact-match generation\n", "   - Also evaluate the exact-match ratio\n", "   - Separately evaluate with full-context questions" ] }, { "cell_type": "markdown", "id": "766c4c50-6e72-41b2-b6d7-1e4c3c309a68", "metadata": {}, "source": [ "### 1. Positive contexts collection" ] }, { "cell_type": "code", "execution_count": 1, "id": "33d57a85-c079-4cf1-b9ad-3b00ce916720", "metadata": {}, "outputs": [], "source": [ "import datasets" ] }, { "cell_type": "code", "execution_count": 2, "id": "0434a258-27ca-4cec-bb85-60673fea2b16", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration default-8d557d41fc795903\n", "Found cached dataset json (/home/xstefan3/.cache/huggingface/datasets/json/default-8d557d41fc795903/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "366db7856ce341a6854a08c244aa5db1", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00\n", "\n", "\n", "
dialogue_id | wikipedia_page_title | background | section_title | context | turn_ids | questions | followups | yesnos | answers | orig_answers
C_69758fcdfc1f46baba0e92c0f3b0919c_1 | Malayali | The Malayali people or Keralite people (also s... | Geographic distribution and population | According to the Indian census of 2001, there ... | [C_69758fcdfc1f46baba0e92c0f3b0919c_1_q#0, C_6... | [Where is Malayali located?, What other langua... | [2, 1, 1, 1, 1, 1, 1] | [2, 2, 2, 2, 2, 0, 2] | {'texts': [['30,803,747 speakers of Malayalam ... | {'texts': ['30,803,747 speakers of Malayalam i...
C_69758fcdfc1f46baba0e92c0f3b0919c_0 | Malayali | The Malayali people or Keralite people (also s... | Language and literature | Malayalam is the language spoken by the Malaya... | [C_69758fcdfc1f46baba0e92c0f3b0919c_0_q#0, C_6... | [what language do they speak?, Do they speak a... | [0, 0, 0, 0, 0, 0, 0] | [2, 2, 2, 2, 2, 2, 2] | {'texts': [['Malayalam is the language spoken ... | {'texts': ['Malayalam is the language spoken b...
\n", "" ], "text/plain": [ " wikipedia_page_title \\\n", "dialogue_id \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_1 Malayali \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_0 Malayali \n", "\n", " background \\\n", "dialogue_id \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_1 The Malayali people or Keralite people (also s... \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_0 The Malayali people or Keralite people (also s... \n", "\n", " section_title \\\n", "dialogue_id \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_1 Geographic distribution and population \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_0 Language and literature \n", "\n", " context \\\n", "dialogue_id \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_1 According to the Indian census of 2001, there ... \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_0 Malayalam is the language spoken by the Malaya... \n", "\n", " turn_ids \\\n", "dialogue_id \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_1 [C_69758fcdfc1f46baba0e92c0f3b0919c_1_q#0, C_6... \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_0 [C_69758fcdfc1f46baba0e92c0f3b0919c_0_q#0, C_6... \n", "\n", " questions \\\n", "dialogue_id \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_1 [Where is Malayali located?, What other langua... \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_0 [what language do they speak?, Do they speak a... \n", "\n", " followups \\\n", "dialogue_id \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_1 [2, 1, 1, 1, 1, 1, 1] \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_0 [0, 0, 0, 0, 0, 0, 0] \n", "\n", " yesnos \\\n", "dialogue_id \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_1 [2, 2, 2, 2, 2, 0, 2] \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_0 [2, 2, 2, 2, 2, 2, 2] \n", "\n", " answers \\\n", "dialogue_id \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_1 {'texts': [['30,803,747 speakers of Malayalam ... \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_0 {'texts': [['Malayalam is the language spoken ... \n", "\n", " orig_answers \n", "dialogue_id \n", "C_69758fcdfc1f46baba0e92c0f3b0919c_1 {'texts': ['30,803,747 speakers of Malayalam i... 
\n", "C_69758fcdfc1f46baba0e92c0f3b0919c_0 {'texts': ['Malayalam is the language spoken b... " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "quac_train_df = quac_train.to_pandas().set_index(\"dialogue_id\", drop=True)\n", "quac_train_df.head(2)" ] }, { "cell_type": "code", "execution_count": 7, "id": "01b2994d-3ba3-4ce0-9531-20e1858ee878", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([array(['what team did unitas play for',\n", " 'how many games did the colts win',\n", " 'who did they play in the playoffs', 'did they win the super bowl',\n", " 'who did they play in the super bowl', 'what were unitas stats'],\n", " dtype=object) ,\n", " {'texts': array([array(['The Colts'], dtype=object),\n", " array(['the Colts ran off 10 straight victories to finish with a 12-2 record.'],\n", " dtype=object) ,\n", " array(['Cleveland Browns'], dtype=object),\n", " array(['losing 27-0.'], dtype=object),\n", " array(['the Packers.'], dtype=object),\n", " array(['Gary Cuozzo also suffered a season-ending injury the following'],\n", " dtype=object) ],\n", " dtype=object), 'answer_starts': array([array([920], dtype=int32), array([142], dtype=int32),\n", " array([552], dtype=int32), array([604], dtype=int32),\n", " array([1487], dtype=int32), array([1292], dtype=int32)],\n", " dtype=object)} ],\n", " dtype=object)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "quac_train_df.loc['C_2ba58216460d43aa986fc0e897537239_0'][[\"questions\", \"answers\"]].values" ] }, { "cell_type": "code", "execution_count": 8, "id": "e680d1ec-ca04-452c-ad44-eae3b43559cc", "metadata": {}, "outputs": [], "source": [ "def answer_for_question(questions: list, answers: dict, question: str) -> str:\n", "    # `answers[\"texts\"]` holds one list of reference answers per turn; keep the first of each.\n", "    first_answers = [turn_answers[0] for turn_answers in answers[\"texts\"]]\n", "    assert question in questions\n", "    assert len(first_answers) == len(questions)\n", "\n", "    # Return the answer aligned with the first occurrence of `question`.\n", "    return next(a for q, a in zip(questions, first_answers) if q == question)" ] },
{ "cell_type": "code", "execution_count": 9, "id": "9a58cc80-a4b8-4b59-b5c3-e8b095e2c281", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "786095c0ea3e432dac7dd1912cb3832d", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/31526 [00:00= num_responses:\n", " break\n", "\n", " return unique_responses" ] }, { "cell_type": "code", "execution_count": 19, "id": "f3587fed-f244-45f0-af28-6e5e36ac15b9", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5318a11fc3484b188ef6c15a6b952d95", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/39407 [00:00