diff --git a/demo/kpi_answering/inference_demo/inference_kpi_answering.ipynb b/demo/kpi_answering/inference_demo/inference_kpi_answering.ipynb index 54c3dcd..69efb52 100644 --- a/demo/kpi_answering/inference_demo/inference_kpi_answering.ipynb +++ b/demo/kpi_answering/inference_demo/inference_kpi_answering.ipynb @@ -31,7 +31,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "question = \"How many programming languages does BLOOM support?\"\n", "context = \"BLOOM has 176 billion parameters and can generate text in 46 languages natural languages and 13 programming languages.\"" ] @@ -58,10 +57,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "import torch\n", - "from transformers import AutoModelForQuestionAnswering, AutoTokenizer" - ] + "source": [] }, { "cell_type": "code", @@ -69,7 +65,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "question = \"How many programming languages does BLOOM support?\"\n", "context = \"BLOOM has 176 billion parameters and can generate text in 46 languages natural languages and 13 programming languages.\"" ] diff --git a/demo/kpi_answering/training_demo/training_kpi_answering.ipynb b/demo/kpi_answering/training_demo/training_kpi_answering.ipynb index e7048af..5546f87 100644 --- a/demo/kpi_answering/training_demo/training_kpi_answering.ipynb +++ b/demo/kpi_answering/training_demo/training_kpi_answering.ipynb @@ -1,1633 +1,1640 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "WNlznhds5LGT", - "outputId": "d8fad1c7-0212-40f4-d6ec-e9c9a17d2e4e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.42.4)\n", - "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.20.0)\n", - "Requirement already satisfied: filelock in 
/usr/local/lib/python3.10/dist-packages (from transformers) (3.15.4)\n", - "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.23.5)\n", - "Requirement already satisfied: numpy<2.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", - "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", - "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", - "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.5.15)\n", - "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", - "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.3)\n", - "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", - "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.4)\n", - "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", - "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n", - "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", - "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n", - "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", - "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", - "Requirement already satisfied: fsspec<=2024.5.0,>=2023.1.0 in 
/usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets) (2024.5.0)\n", - "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.5)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", - "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.2.0)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n", - "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n", - "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", - "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.7)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.7.4)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", - 
"Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n" - ] - } - ], - "source": [ - "# Transformers installation\n", - "! pip install transformers datasets\n", - "# To install from source instead of the last release, comment the command above and uncomment the following one.\n", - "# ! pip install git+https://github.com/huggingface/transformers.git" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iIK0Hxk05LGz" - }, - "source": [ - "# Question answering" - ] + "cells": [ + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "WNlznhds5LGT", + "outputId": "d8fad1c7-0212-40f4-d6ec-e9c9a17d2e4e" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "id": "MvTwuLHoRSMX" - }, - "outputs": [], - "source": [ - "import pandas as pd\n", - "from datasets import Dataset, DatasetDict\n", - "from sklearn.model_selection import train_test_split" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.42.4)\n", + "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.20.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.15.4)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.23.5)\n", + "Requirement already satisfied: numpy<2.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", + "Requirement already satisfied: 
pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.5.15)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.3)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.4)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", + "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.5.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets) (2024.5.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.5)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.2.0)\n", + "Requirement already 
satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.7.4)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n" + ] + } + ], + "source": [ + "# Transformers installation\n", + "! pip install transformers datasets\n", + "# To install from source instead of the last release, comment the command above and uncomment the following one.\n", + "# ! 
pip install git+https://github.com/huggingface/transformers.git" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iIK0Hxk05LGz" + }, + "source": [ + "# Question answering" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "id": "MvTwuLHoRSMX" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from datasets import Dataset, DatasetDict\n", + "from sklearn.model_selection import train_test_split" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "7e4x6BE3RU18" + }, + "outputs": [], + "source": [ + "df = pd.read_csv(\"/content/output_curator.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "rU7-BLJ4RgdM" + }, + "outputs": [], + "source": [ + "df = df[[\"question\", \"context\", \"answer\"]]" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "a8TqSqezSalW", + "outputId": "65bb26de-e417-4a52-a3aa-1439d3d16f6e" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "7e4x6BE3RU18" - }, - "outputs": [], - "source": [ - "df =pd.read_csv(\"/content/output_curator.csv\")" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['question', 'context', 'answer'],\n", + " num_rows: 9\n", + " })\n", + " test: Dataset({\n", + " features: ['question', 'context', 'answer'],\n", + " num_rows: 3\n", + " })\n", + "})\n" + ] + } + ], + "source": [ + "# Split the DataFrame into train and test sets\n", + "train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)\n", + "train_df = train_df.reset_index(drop=True)\n", + "test_df = test_df.reset_index(drop=True)\n", + "\n", + "# Convert pandas DataFrames to Hugging Face Datasets\n", + "train_dataset = Dataset.from_pandas(train_df)\n", + "test_dataset = Dataset.from_pandas(test_df)\n", + "\n", + "# 
Create a DatasetDict\n", + "data = DatasetDict({\"train\": train_dataset, \"test\": test_dataset})\n", + "\n", + "# Verify the DatasetDict\n", + "print(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "tMbToS4oSnRR", + "outputId": "305c3311-9877-4f21-ce13-8d2f66596766" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "rU7-BLJ4RgdM" - }, - "outputs": [], - "source": [ - "df = df[['question','context','answer']]" + "data": { + "text/plain": [ + "{'question': 'What is the target year for climate commitment?',\n", + " 'context': 'We continue to work towards delivering on our Net Carbon Footprint ambition to cut the intensity of the greenhouse gas emissions of the energy products we sell by about 50% by 2050, and 20% by 2035 compared to our 2016 levels, in step with society as it moves towards meeting the goals of the Paris Agreement. In 2019, we set shorter-term targets for 2021 of 2-3% lower than our 2016 baseline Net Carbon Footprint. In early 2020, we set a Net Carbon Footprint target for 2022 of 3-4% lower than our 2016 baseline. We will continue to evolve our approach over time.',\n", + " 'answer': '2050'}" ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"train\"][0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8S6lOMW_5LHi" + }, + "source": [ + "There are several important fields here:\n", + "\n", + "- `answer`: the answer text, which the model is trained to locate within the context.\n", + "- `context`: background information from which the model needs to extract the answer.\n", + "- `question`: the question a model should answer." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MhC58izW5LHj" + }, + "source": [ + "## Preprocess" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "id": "F4pTf4Sq5LHl" + }, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "Ud3kncOGWCjs", + "outputId": "a3ae9608-561b-49cf-fb9a-d70064ec9122" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "a8TqSqezSalW", - "outputId": "65bb26de-e417-4a52-a3aa-1439d3d16f6e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "DatasetDict({\n", - " train: Dataset({\n", - " features: ['question', 'context', 'answer'],\n", - " num_rows: 9\n", - " })\n", - " test: Dataset({\n", - " features: ['question', 'context', 'answer'],\n", - " num_rows: 3\n", - " })\n", - "})\n" - ] - } - ], - "source": [ - "from datasets import Dataset\n", - "\n", - "# Split the DataFrame into train and test sets\n", - "train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)\n", - "train_df = train_df.reset_index(drop=True)\n", - "test_df = test_df.reset_index(drop=True)\n", - "\n", - "# Convert pandas DataFrames to Hugging Face Datasets\n", - "train_dataset = Dataset.from_pandas(train_df)\n", - "test_dataset = Dataset.from_pandas(test_df)\n", - "\n", - "# Create a DatasetDict\n", - "data = DatasetDict({\n", - " 'train': train_dataset,\n", - " 'test': test_dataset\n", - "})\n", - "\n", - "# Verify the DatasetDict\n", - "print(data)" + "data": { + "text/plain": [ + "{'question': 'What is the target year for climate commitment?',\n", + " 'context': 'We continue to work towards delivering on our Net Carbon Footprint 
ambition to cut the intensity of the greenhouse gas emissions of the energy products we sell by about 50% by 2050, and 20% by 2035 compared to our 2016 levels, in step with society as it moves towards meeting the goals of the Paris Agreement. In 2019, we set shorter-term targets for 2021 of 2-3% lower than our 2016 baseline Net Carbon Footprint. In early 2020, we set a Net Carbon Footprint target for 2022 of 3-4% lower than our 2016 baseline. We will continue to evolve our approach over time.',\n", + " 'answer': '2050'}" ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "example = data[\"train\"][0]\n", + "example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 255, + "referenced_widgets": [ + "0d2a3442fae34bc89a84a0c00291998a", + "a9437357189448aca39bb9692899ef58", + "f77b204c053348af9fa5102e7c3abbf5", + "88958c8e4f06467daf0fffc119ddf64b", + "9d292585afff4f29b62971c52cd332bc", + "2108e2ee7bd940d197dd72731cfc780b", + "70c8ef3a1abf4c58b8f728a08ca7a3d9", + "47f8e5f5b30545d1a45bf15aea2d0a7d", + "b424b0770d6948e6be0adf6ef9689e75", + "c8094cf32a23462389bdddef420eec6e", + "22a4b0abf5c1495b9bf72fa1e43275f6", + "effaa98a3c5343389900ddf51a36d999", + "c067abfbf57247edba79014eb13dfe7c", + "873f714bf3cd4340a0c15d8eb95db9fc", + "4e91e0b827fa49d2a9ea49870ceb8a83", + "b4688b7021bb4f7fa8e0d8913000de04", + "b9db5c66844d46eb817ae2910d3b12b1", + "40111b17548745bcab69602afaef9a39", + "74733b29596b4d48ada614f57053bf24", + "37e3aaf398464d809f56fac06a72b8f9", + "61764dc8ec9a4d339c6dd873dd7654cd", + "e3195a1f697546a988ea49622772413d" + ] }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "tMbToS4oSnRR", - "outputId": "305c3311-9877-4f21-ce13-8d2f66596766" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'question': 'What is the target year 
def _find_answer_token_span(offsets, sequence_ids, start_char, end_char):
    """Map a character span of the context onto token indices of an encoded pair.

    Args:
        offsets: Per-token ``(char_start, char_end)`` pairs for one encoded
            (question, context) pair.
        sequence_ids: Per-token sequence id for the same encoding; context
            tokens are the ones whose id equals ``1``.
        start_char: Index of the first answer character inside the context.
        end_char: Index one past the last answer character inside the context.

    Returns:
        ``(start_token, end_token)``, both inclusive and expressed as indices
        into the full encoded sequence, or ``(0, 0)`` when the answer is not
        completely covered by the context tokens present in the encoding
        (for example because truncation cut it off).
    """
    context_tokens = [idx for idx, seq_id in enumerate(sequence_ids) if seq_id == 1]
    if not context_tokens:
        return 0, 0
    first, last = context_tokens[0], context_tokens[-1]

    # The answer must lie entirely inside the surviving context window.
    if offsets[first][0] > start_char or offsets[last][1] < end_char:
        return 0, 0

    # Walk right past every token that starts at or before the answer start;
    # the token just before the stopping point contains the first answer char.
    start_token = first
    while start_token <= last and offsets[start_token][0] <= start_char:
        start_token += 1

    # Walk left past every token that ends at or after the answer end;
    # the token just after the stopping point contains the last answer char.
    end_token = last
    while end_token >= first and offsets[end_token][1] >= end_char:
        end_token -= 1

    return start_token - 1, end_token + 1


def preprocess_function(examples):
    """Tokenize a batch of QA examples and attach answer-span labels.

    The previous implementation derived the labels by counting the tokens of
    the context prefix in isolation. Those counts are relative to the context
    alone, whereas the model input is the encoded (question, context) pair in
    which the question tokens come first, so every label was shifted. An empty
    answer also produced an end label of ``-1``. Labels are now computed from
    the tokenizer's offset mapping so they index the sequence the model sees.

    Relies on the notebook-level ``tokenizer``; it must be a fast tokenizer,
    because offset mappings and ``sequence_ids`` are only provided by those.

    Args:
        examples: Batch mapping with parallel lists under the keys
            ``"question"``, ``"context"`` and ``"answer"``.

    Returns:
        The tokenizer output for the batch, extended with ``start_positions``
        and ``end_positions`` (inclusive token indices). Examples whose answer
        is missing, empty, not found in the context, or truncated away are
        labelled ``(0, 0)``, the same fallback the original used for answers
        it could not find. The temporary ``offset_mapping`` is removed.
    """
    questions = examples["question"]
    contexts = examples["context"]
    answers = examples["answer"]

    # Truncate only the context so the question is never shortened, and ask
    # for character offsets so the answer can be aligned to tokens.
    tokenized_inputs = tokenizer(
        questions,
        contexts,
        max_length=512,
        truncation="only_second",
        padding="max_length",
        return_offsets_mapping=True,
    )
    offset_mapping = tokenized_inputs.pop("offset_mapping")

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        # Missing values become "", other non-string values are compared by
        # their text form instead of raising inside str.find.
        answer = "" if answers[i] is None else str(answers[i])
        # First occurrence of the answer text, as in the original lookup.
        answer_start = contexts[i].find(answer) if answer else -1

        if answer_start == -1:
            start_positions.append(0)
            end_positions.append(0)
            continue

        start_token, end_token = _find_answer_token_span(
            offsets,
            tokenized_inputs.sequence_ids(i),
            answer_start,
            answer_start + len(answer),
        )
        start_positions.append(start_token)
        end_positions.append(end_token)

    tokenized_inputs["start_positions"] = start_positions
    tokenized_inputs["end_positions"] = end_positions

    return tokenized_inputs
here:\n", - "\n", - "- `answers`: the starting location of the answer token and the answer text.\n", - "- `context`: background information from which the model needs to extract the answer.\n", - "- `question`: the question a model should answer." + "text/plain": [ + "model.safetensors: 0%| | 0.00/268M [00:00\n", + " \n", + " \n", + " [3/3 01:43, Epoch 3/3]\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation Loss
1No log6.079496
2No log6.034035
3No log6.011786

" ], - "source": [ - "example = data['train'][0]\n", - "example" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 255, - "referenced_widgets": [ - "0d2a3442fae34bc89a84a0c00291998a", - "a9437357189448aca39bb9692899ef58", - "f77b204c053348af9fa5102e7c3abbf5", - "88958c8e4f06467daf0fffc119ddf64b", - "9d292585afff4f29b62971c52cd332bc", - "2108e2ee7bd940d197dd72731cfc780b", - "70c8ef3a1abf4c58b8f728a08ca7a3d9", - "47f8e5f5b30545d1a45bf15aea2d0a7d", - "b424b0770d6948e6be0adf6ef9689e75", - "c8094cf32a23462389bdddef420eec6e", - "22a4b0abf5c1495b9bf72fa1e43275f6", - "effaa98a3c5343389900ddf51a36d999", - "c067abfbf57247edba79014eb13dfe7c", - "873f714bf3cd4340a0c15d8eb95db9fc", - "4e91e0b827fa49d2a9ea49870ceb8a83", - "b4688b7021bb4f7fa8e0d8913000de04", - "b9db5c66844d46eb817ae2910d3b12b1", - "40111b17548745bcab69602afaef9a39", - "74733b29596b4d48ada614f57053bf24", - "37e3aaf398464d809f56fac06a72b8f9", - "61764dc8ec9a4d339c6dd873dd7654cd", - "e3195a1f697546a988ea49622772413d" - ] - }, - "id": "Gd-CEB70UNLv", - "outputId": "6e28bf7f-b5e0-48f5-a059-dacb50e2b35f" - }, - "outputs": [], - "source": [ - "from datasets import DatasetDict, Dataset\n", - "from transformers import AutoTokenizer\n", - "\n", - "\n", - "def preprocess_function(examples):\n", - " questions = examples['question']\n", - " contexts = examples['context']\n", - " answers = examples['answer']\n", - "\n", - " # Tokenize questions and contexts\n", - " tokenized_inputs = tokenizer(\n", - " questions,\n", - " contexts,\n", - " max_length=512,\n", - " truncation=True,\n", - " padding='max_length'\n", - " )\n", - "\n", - " # Initialize lists to hold start and end positions\n", - " start_positions = []\n", - " end_positions = []\n", - "\n", - " # Loop through each example\n", - " for i in range(len(questions)):\n", - " # Get the answer text\n", - " 
# ---------------------------------------------------------------------------
# Train
# ---------------------------------------------------------------------------
# Fine-tuning configuration: evaluation happens once per epoch, logs go to
# the "logs" directory every 10 steps, and nothing is pushed to the Hub.
training_args = TrainingArguments(
    output_dir="my_awesome_qa_model",
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="epoch",
    logging_dir="logs",
    logging_steps=10,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=processed_datasets["train"],
    eval_dataset=processed_datasets["test"],
)

trainer.train()

# ---------------------------------------------------------------------------
# Evaluate
# ---------------------------------------------------------------------------
# Report every metric the trainer computes on the held-out split.
eval_result = trainer.evaluate(processed_datasets["test"])
print("Evaluation results:")
for key, value in eval_result.items():
    print(f"{key}: {value}")

from transformers import Trainer
import numpy as np

# Run the model on the test split; the first two prediction arrays are read
# as the start and end logits respectively.
predictions = trainer.predict(processed_datasets["test"])
start_logits = predictions.predictions[0]
end_logits = predictions.predictions[1]

# Most likely start / end token per example.
predicted_starts = np.argmax(start_logits, axis=1)
predicted_ends = np.argmax(end_logits, axis=1)

# Gold labels stored on the processed dataset.
true_starts = np.array([row["start_positions"] for row in processed_datasets["test"]])
true_ends = np.array([row["end_positions"] for row in processed_datasets["test"]])

# An example counts as correct only when both boundaries match exactly.
exact_span_match = (predicted_starts == true_starts) & (predicted_ends == true_ends)
accuracy = np.mean(exact_span_match)
print("Accuracy:", accuracy)


def _decode_span(token_ids, first, last):
    """Turn the inclusive token range ``first..last`` of ``token_ids`` into text."""
    span_tokens = tokenizer.convert_ids_to_tokens(token_ids[first : last + 1])
    return tokenizer.convert_tokens_to_string(span_tokens)


# Show each test input next to its gold and predicted answer text.
for i, eva_data in enumerate(processed_datasets["test"]):
    input_ids = eva_data["input_ids"]
    true_start, true_end = true_starts[i], true_ends[i]
    predicted_start, predicted_end = predicted_starts[i], predicted_ends[i]

    input_text = tokenizer.decode(input_ids, skip_special_tokens=True)
    predicted_answer = _decode_span(input_ids, predicted_start, predicted_end)
    true_answer = _decode_span(input_ids, true_start, true_end)

    print(f"Input: {input_text}")
    print(f"True Answer: {true_answer}")
    print(f"Predicted Answer: {predicted_answer}")
    print()

# Persist the fine-tuned model together with its tokenizer.
model.save_pretrained("my_awesome_qa_model")
tokenizer.save_pretrained("my_awesome_qa_model")
"model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5108bc94c99d4dba99bdd9c45d016d75", + "placeholder": "​", + "style": "IPY_MODEL_e4fa22eb557a48509877f74e90b4cb4e", + "value": " 268M/268M [00:01<00:00, 151MB/s]" + } }, - { - "cell_type": "markdown", - "metadata": { - "id": "R--4vo3-5LHs" - }, - "source": [ - "## Train" - ] + "1c83eaddf31442abbe99d121dc64bd1f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f44e35808e98440aa6d24761019cac54", + "max": 267954768, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b09bc3b4fbd543ebbd62ec27afd67125", + "value": 267954768 + } }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 84, - "referenced_widgets": [ - "c447d42a7d4045e1b5a733d2b719436a", - "82676a7475f548f49f5baf2b61a31e28", - "1c83eaddf31442abbe99d121dc64bd1f", - "1bac75e991ef492eb149ab91952499a8", - "1e1a537f5a1f4e17b28dd511051eda98", - "8f4f499c054f4fe7bc2a313d76ab159f", - "cd46e978eda74dc494ccf0fea5d72ed8", - "f44e35808e98440aa6d24761019cac54", - "b09bc3b4fbd543ebbd62ec27afd67125", - 
"5108bc94c99d4dba99bdd9c45d016d75", - "e4fa22eb557a48509877f74e90b4cb4e" - ] - }, - "id": "3xpmsllx5LHs", - "outputId": "cff4c13d-73fc-4f2f-f539-b2bd2b99f568" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c447d42a7d4045e1b5a733d2b719436a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "model.safetensors: 0%| | 0.00/268M [00:00\n", - " \n", - " \n", - " [3/3 01:43, Epoch 3/3]\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
EpochTraining LossValidation Loss
1No log6.079496
2No log6.034035
3No log6.011786

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "TrainOutput(global_step=3, training_loss=6.016239166259766, metrics={'train_runtime': 156.7107, 'train_samples_per_second': 0.172, 'train_steps_per_second': 0.019, 'total_flos': 3527633700864.0, 'train_loss': 6.016239166259766, 'epoch': 3.0})" - ] - }, - "execution_count": 51, - "metadata": {}, - "output_type": "execute_result" - } + "2108e2ee7bd940d197dd72731cfc780b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "22a4b0abf5c1495b9bf72fa1e43275f6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": 
"DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "37e3aaf398464d809f56fac06a72b8f9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "40111b17548745bcab69602afaef9a39": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "47f8e5f5b30545d1a45bf15aea2d0a7d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + 
"justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4e91e0b827fa49d2a9ea49870ceb8a83": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_61764dc8ec9a4d339c6dd873dd7654cd", + "placeholder": "​", + "style": "IPY_MODEL_e3195a1f697546a988ea49622772413d", + "value": " 3/3 [00:00<00:00, 37.58 examples/s]" + } + }, + "5108bc94c99d4dba99bdd9c45d016d75": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + 
"margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "61764dc8ec9a4d339c6dd873dd7654cd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "70c8ef3a1abf4c58b8f728a08ca7a3d9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": 
"StyleView", + "description_width": "" + } + }, + "74733b29596b4d48ada614f57053bf24": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "82676a7475f548f49f5baf2b61a31e28": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8f4f499c054f4fe7bc2a313d76ab159f", + "placeholder": "​", + "style": "IPY_MODEL_cd46e978eda74dc494ccf0fea5d72ed8", + "value": "model.safetensors: 100%" + } + }, + "873f714bf3cd4340a0c15d8eb95db9fc": { 
+ "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_74733b29596b4d48ada614f57053bf24", + "max": 3, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_37e3aaf398464d809f56fac06a72b8f9", + "value": 3 + } + }, + "88958c8e4f06467daf0fffc119ddf64b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c8094cf32a23462389bdddef420eec6e", + "placeholder": "​", + "style": "IPY_MODEL_22a4b0abf5c1495b9bf72fa1e43275f6", + "value": " 9/9 [00:00<00:00, 69.91 examples/s]" + } + }, + "8f4f499c054f4fe7bc2a313d76ab159f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + 
"grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9d292585afff4f29b62971c52cd332bc": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a9437357189448aca39bb9692899ef58": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + 
"model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2108e2ee7bd940d197dd72731cfc780b", + "placeholder": "​", + "style": "IPY_MODEL_70c8ef3a1abf4c58b8f728a08ca7a3d9", + "value": "Map: 100%" + } + }, + "b09bc3b4fbd543ebbd62ec27afd67125": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b424b0770d6948e6be0adf6ef9689e75": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b4688b7021bb4f7fa8e0d8913000de04": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": 
null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b9db5c66844d46eb817ae2910d3b12b1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + 
"c067abfbf57247edba79014eb13dfe7c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b9db5c66844d46eb817ae2910d3b12b1", + "placeholder": "​", + "style": "IPY_MODEL_40111b17548745bcab69602afaef9a39", + "value": "Map: 100%" + } + }, + "c447d42a7d4045e1b5a733d2b719436a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_82676a7475f548f49f5baf2b61a31e28", + "IPY_MODEL_1c83eaddf31442abbe99d121dc64bd1f", + "IPY_MODEL_1bac75e991ef492eb149ab91952499a8" ], - "source": [ - "from transformers import TrainingArguments\n", - "\n", - "training_args = TrainingArguments(\n", - " output_dir=\"my_awesome_qa_model\",\n", - " evaluation_strategy=\"epoch\", # Evaluate at the end of each epoch\n", - " logging_dir=\"logs\", # Directory for logs\n", - " logging_steps=10, # Log every 10 steps\n", - " learning_rate=2e-5,\n", - " per_device_train_batch_size=16,\n", - " per_device_eval_batch_size=16,\n", - " num_train_epochs=3,\n", - " weight_decay=0.01,\n", - " push_to_hub=False,\n", - ")\n", - "\n", - "\n", - "trainer = Trainer(\n", - " model=model,\n", - " args=training_args,\n", - " train_dataset=processed_datasets[\"train\"],\n", - " eval_dataset=processed_datasets[\"test\"],\n", - " 
tokenizer=tokenizer,\n", - " data_collator=data_collator,\n", - ")\n", - "\n", - "trainer.train()" - ] + "layout": "IPY_MODEL_1e1a537f5a1f4e17b28dd511051eda98" + } }, - { - "cell_type": "markdown", - "metadata": { - "id": "4TEWIQfK5LHy" - }, - "source": [ - "## Evaluate" - ] + "c8094cf32a23462389bdddef420eec6e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "eval_result = trainer.evaluate(processed_datasets[\"test\"])\n", - "print(\"Evaluation results:\")\n", - "for key, value in eval_result.items():\n", - " print(f\"{key}: {value}\")\n" - ] + "cd46e978eda74dc494ccf0fea5d72ed8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + 
"_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import Trainer\n", - "import numpy as np\n", - "from sklearn.metrics import accuracy_score\n", - "\n", - "# Predict labels for the evaluation dataset\n", - "predictions = trainer.predict(processed_datasets[\"test\"])\n", - "start_logits = predictions.predictions[0] # Start logits\n", - "end_logits = predictions.predictions[1] # End logits\n", - "\n", - "# Convert logits to start and end positions\n", - "predicted_starts = np.argmax(start_logits, axis=1)\n", - "predicted_ends = np.argmax(end_logits, axis=1)\n", - "\n", - "# Extract true start and end positions from the dataset\n", - "true_starts = np.array([example[\"start_positions\"] for example in processed_datasets[\"test\"]])\n", - "true_ends = np.array([example[\"end_positions\"] for example in processed_datasets[\"test\"]])\n", - "\n", - "# Calculate accuracy (you might want a different metric depending on your needs)\n", - "accuracy = np.mean((predicted_starts == true_starts) & (predicted_ends == true_ends))\n", - "print(\"Accuracy:\", accuracy)\n", - "\n", - "# Print inputs along with predicted and true answer spans\n", - "for i in range(len(processed_datasets[\"test\"])):\n", - " eva_data = processed_datasets[\"test\"][i]\n", - " input_ids = eva_data[\"input_ids\"]\n", - " true_start = true_starts[i]\n", - " true_end = true_ends[i]\n", - " predicted_start = predicted_starts[i]\n", - " predicted_end = predicted_ends[i]\n", - " \n", - " input_text = tokenizer.decode(input_ids, skip_special_tokens=True)\n", - " predicted_answer = 
tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[predicted_start:predicted_end+1]))\n", - " true_answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[true_start:true_end+1]))\n", - " \n", - " print(f\"Input: {input_text}\")\n", - " print(f\"True Answer: {true_answer}\")\n", - " print(f\"Predicted Answer: {predicted_answer}\")\n", - " print()\n", - "\n", - "# Save the model and tokenizer\n", - "model.save_pretrained(\"my_awesome_qa_model\")\n", - "tokenizer.save_pretrained(\"my_awesome_qa_model\")\n" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] + "e3195a1f697546a988ea49622772413d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" + "e4fa22eb557a48509877f74e90b4cb4e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } }, - "language_info": { - "name": "python" + "effaa98a3c5343389900ddf51a36d999": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + 
"_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_c067abfbf57247edba79014eb13dfe7c", + "IPY_MODEL_873f714bf3cd4340a0c15d8eb95db9fc", + "IPY_MODEL_4e91e0b827fa49d2a9ea49870ceb8a83" + ], + "layout": "IPY_MODEL_b4688b7021bb4f7fa8e0d8913000de04" + } + }, + "f44e35808e98440aa6d24761019cac54": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "0d2a3442fae34bc89a84a0c00291998a": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - 
"_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_a9437357189448aca39bb9692899ef58", - "IPY_MODEL_f77b204c053348af9fa5102e7c3abbf5", - "IPY_MODEL_88958c8e4f06467daf0fffc119ddf64b" - ], - "layout": "IPY_MODEL_9d292585afff4f29b62971c52cd332bc" - } - }, - "1bac75e991ef492eb149ab91952499a8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_5108bc94c99d4dba99bdd9c45d016d75", - "placeholder": "​", - "style": "IPY_MODEL_e4fa22eb557a48509877f74e90b4cb4e", - "value": " 268M/268M [00:01<00:00, 151MB/s]" - } - }, - "1c83eaddf31442abbe99d121dc64bd1f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_f44e35808e98440aa6d24761019cac54", - "max": 267954768, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_b09bc3b4fbd543ebbd62ec27afd67125", - "value": 267954768 - } - }, - "1e1a537f5a1f4e17b28dd511051eda98": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - 
"_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2108e2ee7bd940d197dd72731cfc780b": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - 
"max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "22a4b0abf5c1495b9bf72fa1e43275f6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "37e3aaf398464d809f56fac06a72b8f9": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "40111b17548745bcab69602afaef9a39": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "47f8e5f5b30545d1a45bf15aea2d0a7d": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": 
null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "4e91e0b827fa49d2a9ea49870ceb8a83": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_61764dc8ec9a4d339c6dd873dd7654cd", - "placeholder": "​", - "style": "IPY_MODEL_e3195a1f697546a988ea49622772413d", - "value": " 3/3 [00:00<00:00, 37.58 examples/s]" - } - }, - "5108bc94c99d4dba99bdd9c45d016d75": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": 
"1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "61764dc8ec9a4d339c6dd873dd7654cd": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": 
null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "70c8ef3a1abf4c58b8f728a08ca7a3d9": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "74733b29596b4d48ada614f57053bf24": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "82676a7475f548f49f5baf2b61a31e28": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": 
{ - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_8f4f499c054f4fe7bc2a313d76ab159f", - "placeholder": "​", - "style": "IPY_MODEL_cd46e978eda74dc494ccf0fea5d72ed8", - "value": "model.safetensors: 100%" - } - }, - "873f714bf3cd4340a0c15d8eb95db9fc": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_74733b29596b4d48ada614f57053bf24", - "max": 3, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_37e3aaf398464d809f56fac06a72b8f9", - "value": 3 - } - }, - "88958c8e4f06467daf0fffc119ddf64b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_c8094cf32a23462389bdddef420eec6e", - "placeholder": "​", - "style": "IPY_MODEL_22a4b0abf5c1495b9bf72fa1e43275f6", - "value": " 9/9 [00:00<00:00, 69.91 examples/s]" - } - }, - "8f4f499c054f4fe7bc2a313d76ab159f": { - "model_module": "@jupyter-widgets/base", - 
"model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "9d292585afff4f29b62971c52cd332bc": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": 
null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a9437357189448aca39bb9692899ef58": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_2108e2ee7bd940d197dd72731cfc780b", - "placeholder": "​", - "style": "IPY_MODEL_70c8ef3a1abf4c58b8f728a08ca7a3d9", - "value": "Map: 100%" - } - }, - "b09bc3b4fbd543ebbd62ec27afd67125": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "b424b0770d6948e6be0adf6ef9689e75": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - 
"description_width": "" - } - }, - "b4688b7021bb4f7fa8e0d8913000de04": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "b9db5c66844d46eb817ae2910d3b12b1": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - 
"grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "c067abfbf57247edba79014eb13dfe7c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_b9db5c66844d46eb817ae2910d3b12b1", - "placeholder": "​", - "style": "IPY_MODEL_40111b17548745bcab69602afaef9a39", - "value": "Map: 100%" - } - }, - "c447d42a7d4045e1b5a733d2b719436a": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_82676a7475f548f49f5baf2b61a31e28", - "IPY_MODEL_1c83eaddf31442abbe99d121dc64bd1f", - "IPY_MODEL_1bac75e991ef492eb149ab91952499a8" - ], - "layout": "IPY_MODEL_1e1a537f5a1f4e17b28dd511051eda98" - } - }, - "c8094cf32a23462389bdddef420eec6e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": 
"LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "cd46e978eda74dc494ccf0fea5d72ed8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "e3195a1f697546a988ea49622772413d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": 
"StyleView", - "description_width": "" - } - }, - "e4fa22eb557a48509877f74e90b4cb4e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "effaa98a3c5343389900ddf51a36d999": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_c067abfbf57247edba79014eb13dfe7c", - "IPY_MODEL_873f714bf3cd4340a0c15d8eb95db9fc", - "IPY_MODEL_4e91e0b827fa49d2a9ea49870ceb8a83" - ], - "layout": "IPY_MODEL_b4688b7021bb4f7fa8e0d8913000de04" - } - }, - "f44e35808e98440aa6d24761019cac54": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - 
"grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f77b204c053348af9fa5102e7c3abbf5": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_47f8e5f5b30545d1a45bf15aea2d0a7d", - "max": 9, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_b424b0770d6948e6be0adf6ef9689e75", - "value": 9 - } - } - } + "f77b204c053348af9fa5102e7c3abbf5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_47f8e5f5b30545d1a45bf15aea2d0a7d", + "max": 9, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b424b0770d6948e6be0adf6ef9689e75", + "value": 9 + } } - }, - "nbformat": 4, - "nbformat_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff 
--git a/demo/relevance_detector/training_demo/train_sentence_transformer.ipynb b/demo/relevance_detector/training_demo/train_sentence_transformer.ipynb index f8bdc27..7178e93 100644 --- a/demo/relevance_detector/training_demo/train_sentence_transformer.ipynb +++ b/demo/relevance_detector/training_demo/train_sentence_transformer.ipynb @@ -100,7 +100,12 @@ " label = self.labels[idx]\n", "\n", " inputs = self.tokenizer(\n", - " question, context, truncation=True, padding=\"max_length\", max_length=self.max_length, return_tensors=\"pt\"\n", + " question,\n", + " context,\n", + " truncation=True,\n", + " padding=\"max_length\",\n", + " max_length=self.max_length,\n", + " return_tensors=\"pt\",\n", " )\n", "\n", " input_ids = inputs[\"input_ids\"].squeeze()\n", @@ -137,7 +142,9 @@ "source": [ "MODEL_NAME = \"sentence-transformers/all-MiniLM-L6-v2\"\n", "NUM_LABELS = 2\n", - "model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS)\n", + "model = AutoModelForSequenceClassification.from_pretrained(\n", + " MODEL_NAME, num_labels=NUM_LABELS\n", + ")\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)" ] @@ -161,10 +168,14 @@ "MAX_LENGTH = 512\n", "\n", "# Create training dataset\n", - "train_dataset = CustomDataset(tokenizer, train_df[\"question\"], train_df[\"context\"], train_df[\"label\"], MAX_LENGTH)\n", + "train_dataset = CustomDataset(\n", + " tokenizer, train_df[\"question\"], train_df[\"context\"], train_df[\"label\"], MAX_LENGTH\n", + ")\n", "\n", "# Create evaluation dataset\n", - "eval_dataset = CustomDataset(tokenizer, eval_df[\"question\"], eval_df[\"context\"], eval_df[\"label\"], MAX_LENGTH)" + "eval_dataset = CustomDataset(\n", + " tokenizer, eval_df[\"question\"], eval_df[\"context\"], eval_df[\"label\"], MAX_LENGTH\n", + ")" ] }, {