{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "DGtuJlOHoeR7", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DGtuJlOHoeR7", "outputId": "a428f842-645e-486e-a359-51c99c5cb1bf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting transformers\n", " Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)\n", "\u001b[K |████████████████████████████████| 5.5 MB 9.1 MB/s \n", "\u001b[?25hRequirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.64.1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.8.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2022.6.2)\n", "Collecting huggingface-hub<1.0,>=0.10.0\n", " Downloading huggingface_hub-0.11.0-py3-none-any.whl (182 kB)\n", "\u001b[K |████████████████████████████████| 182 kB 62.4 MB/s \n", "\u001b[?25hCollecting tokenizers!=0.11.3,<0.14,>=0.11.1\n", " Downloading tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)\n", "\u001b[K |████████████████████████████████| 7.6 MB 52.0 MB/s \n", "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.21.6)\n", "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.13.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n", "Requirement already satisfied: 
typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.10.0->transformers) (4.1.1)\n", "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.9)\n", "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.10.0)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2022.9.24)\n", "Installing collected packages: tokenizers, huggingface-hub, transformers\n", "Successfully installed huggingface-hub-0.11.0 tokenizers-0.13.2 transformers-4.24.0\n" ] } ], "source": [ "!pip install transformers" ] }, { "cell_type": "code", "execution_count": null, "id": "5AmRPydso3Id", "metadata": { "id": "5AmRPydso3Id" }, "outputs": [], "source": [ "from transformers import AutoTokenizer, TFAutoModel" ] }, { "cell_type": "code", "execution_count": null, "id": "67SqVHrro3Wd", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 193, "referenced_widgets": [ "f07b2752f6394b638f441abb111e93e6", "25cc16b8d2c04213a33aa7b7fc51c02e", "872fd9f0c86d4f7399b24b6234d30326", "ab98d6bdc3c44097837b57d68ed89f84", "a8abbd6c0aa642ea971de0354445d0f6", "b7e2ed5256384def91cdf5a687a7459d", "b55aec6e6a554532954d751f5af6d917", "98950ddfbd4e48eca132bd709cee6762", "fa85e893966046399e87cc6b292eb47d", "9626e8c5ef854eb68afc9b627ed10dfc", "8c8cd3d51f9d4d28ad2eafc2abec6576", "6f8fb8284fc649118c0ba53d049e5c56", 
"32ba13799e5140ef9c58149b40d86c2f", "423fe2f31065474c992d39f158757e8a", "bfd1e9a341e844729652573bee473b90", "d2c94f525d9f4537bb86209aeacbb713", "f07e4555320b4d9d9e4ef8d66b15d05f", "744aaa73d0fa473da416e443e9ec2084", "61ebaa62bb7e42098f2ea3cd95b2977e", "94213341c0e14639bea3f607bfe4ddd9", "3727962fa5b9484bb10504c2c59419f2", "2f5a7c7b350a41b5826ac3b1ced445b2" ] }, "id": "67SqVHrro3Wd", "outputId": "171a3629-3b0f-46ff-bf09-76112afd1e86" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f07b2752f6394b638f441abb111e93e6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/385 [00:00" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clin_bert" ] }, { "cell_type": "code", "execution_count": null, "id": "63f425f7", "metadata": { "id": "63f425f7" }, "outputs": [], "source": [ "df = pd.read_csv('case_training.csv')\n", "df = df.drop(df.columns[0], axis = 1)" ] }, { "cell_type": "code", "execution_count": null, "id": "26bafe22", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "id": "26bafe22", "outputId": "cb678f31-1c47-4d65-b89c-29d1c27e8cfe", "scrolled": true }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeweightMaleSchedSkinMinsscheduled_case_durationABRAMS, REID ALLENHENTZEN, ERIC RICHARDLEEK, BRYAN TERRYMEUNIER, MATTHEW JOHNRECHNIC, MARK...withoutwoodworkworkstationworseworstwouldwristbert_textactual_case_duration
0422220.809011500010...00000002EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...69
1383035.20909000010...00000002EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...110
2613556.81659500100...00000003Narrative & Impression EXAM DESCRIPTION: X-RAY...103
3572681.60606000010...00000002EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...70
4552625.6012012000010...00000103EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...90
..................................................................
140403488.01609000010...00000003EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMU...76
141423244.8112012010000...00000000EXAM DESCRIPTION: X-RAY ELBOW 2 VIEWS - LEFT ...162
142413577.61858510000...00000003EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...111
143642281.6012012000010...00000002EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMU...102
144512368.0015017510000...10000003EXAM DESCRIPTION: CT RT UPPER EXTREMITY CLINI...233
\n", "

145 rows × 1351 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ], "text/plain": [ " Age weight Male SchedSkinMins scheduled_case_duration \\\n", "0 42 2220.8 0 90 115 \n", "1 38 3035.2 0 90 90 \n", "2 61 3556.8 1 65 95 \n", "3 57 2681.6 0 60 60 \n", "4 55 2625.6 0 120 120 \n", ".. ... ... ... ... ... \n", "140 40 3488.0 1 60 90 \n", "141 42 3244.8 1 120 120 \n", "142 41 3577.6 1 85 85 \n", "143 64 2281.6 0 120 120 \n", "144 51 2368.0 0 150 175 \n", "\n", " ABRAMS, REID ALLEN HENTZEN, ERIC RICHARD LEEK, BRYAN TERRY \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 1 \n", "3 0 0 0 \n", "4 0 0 0 \n", ".. ... ... ... \n", "140 0 0 0 \n", "141 1 0 0 \n", "142 1 0 0 \n", "143 0 0 0 \n", "144 1 0 0 \n", "\n", " MEUNIER, MATTHEW JOHN RECHNIC, MARK ... without wood work \\\n", "0 1 0 ... 0 0 0 \n", "1 1 0 ... 0 0 0 \n", "2 0 0 ... 0 0 0 \n", "3 1 0 ... 0 0 0 \n", "4 1 0 ... 0 0 0 \n", ".. ... ... ... ... ... ... \n", "140 1 0 ... 0 0 0 \n", "141 0 0 ... 0 0 0 \n", "142 0 0 ... 0 0 0 \n", "143 1 0 ... 0 0 0 \n", "144 0 0 ... 1 0 0 \n", "\n", " workstation worse worst would wrist \\\n", "0 0 0 0 0 2 \n", "1 0 0 0 0 2 \n", "2 0 0 0 0 3 \n", "3 0 0 0 0 2 \n", "4 0 0 1 0 3 \n", ".. ... ... ... ... ... \n", "140 0 0 0 0 3 \n", "141 0 0 0 0 0 \n", "142 0 0 0 0 3 \n", "143 0 0 0 0 2 \n", "144 0 0 0 0 3 \n", "\n", " bert_text actual_case_duration \n", "0 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... 69 \n", "1 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... 110 \n", "2 Narrative & Impression EXAM DESCRIPTION: X-RAY... 103 \n", "3 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... 70 \n", "4 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... 90 \n", ".. ... ... \n", "140 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMU... 76 \n", "141 EXAM DESCRIPTION: X-RAY ELBOW 2 VIEWS - LEFT ... 162 \n", "142 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... 111 \n", "143 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMU... 102 \n", "144 EXAM DESCRIPTION: CT RT UPPER EXTREMITY CLINI... 
233 \n", "\n", "[145 rows x 1351 columns]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "code", "execution_count": null, "id": "6196d9c1", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6196d9c1", "outputId": "89307b3c-8c75-4c1b-8dff-97749c543d26" }, "outputs": [ { "data": { "text/plain": [ "['Age',\n", " 'weight',\n", " 'Male',\n", " 'SchedSkinMins',\n", " 'scheduled_case_duration',\n", " 'ABRAMS, REID ALLEN',\n", " 'HENTZEN, ERIC RICHARD',\n", " 'LEEK, BRYAN TERRY',\n", " 'MEUNIER, MATTHEW JOHN',\n", " 'RECHNIC, MARK',\n", " 'Healthy',\n", " 'Mild Systemic Disease',\n", " 'Severe Systemic Disease',\n", " 'Choice Per Patient on Day of Surgery',\n", " 'General',\n", " 'Monitored Anesthesia Care (MAC) ',\n", " 'Regional',\n", " 'height',\n", " '00',\n", " '01',\n", " '02',\n", " '03',\n", " '04',\n", " '05',\n", " '06',\n", " '07',\n", " '08',\n", " '09',\n", " '10',\n", " '104',\n", " '105',\n", " '108',\n", " '11',\n", " '116',\n", " '12',\n", " '1204',\n", " '121',\n", " '13',\n", " '1354',\n", " '138',\n", " '14',\n", " '140',\n", " '15',\n", " '16',\n", " '160',\n", " '1640',\n", " '17',\n", " '174',\n", " '18',\n", " '1814',\n", " '188',\n", " '19',\n", " '191',\n", " '1930',\n", " '196',\n", " '1st',\n", " '20',\n", " '2001',\n", " '2015',\n", " '2017',\n", " '2018',\n", " '2019',\n", " '2020',\n", " '21',\n", " '211',\n", " '2139',\n", " '22',\n", " '223',\n", " '2236',\n", " '228',\n", " '23',\n", " '2306',\n", " '24',\n", " '242',\n", " '244',\n", " '249',\n", " '25',\n", " '26',\n", " '262',\n", " '27',\n", " '270',\n", " '28',\n", " '284',\n", " '29',\n", " '297',\n", " '2nd',\n", " '30',\n", " '306',\n", " '307',\n", " '31',\n", " '32',\n", " '33',\n", " '330',\n", " '332',\n", " '34',\n", " '346',\n", " '35',\n", " '36',\n", " '360',\n", " '363',\n", " '37',\n", " '38',\n", " '383',\n", " '39',\n", " '391',\n", " '3d',\n", " '3rd',\n", " '40',\n", " 
'41',\n", " '416',\n", " '42',\n", " '43',\n", " '44',\n", " '45',\n", " '46',\n", " '467',\n", " '47',\n", " '48',\n", " '49',\n", " '50',\n", " '51',\n", " '52',\n", " '53',\n", " '54',\n", " '55',\n", " '56',\n", " '57',\n", " '574',\n", " '58',\n", " '59',\n", " '5th',\n", " '60',\n", " '61',\n", " '62',\n", " '64',\n", " '72',\n", " '74',\n", " '76',\n", " '80',\n", " '82',\n", " '84',\n", " '841',\n", " '90',\n", " '961',\n", " 'abnormalities',\n", " 'abnormality',\n", " 'abut',\n", " 'abutment',\n", " 'accession',\n", " 'accident',\n", " 'accompanying',\n", " 'accounting',\n", " 'accounts',\n", " 'acquired',\n", " 'acromioclavicular',\n", " 'across',\n", " 'acute',\n", " 'adam',\n", " 'addition',\n", " 'additional',\n", " 'additionally',\n", " 'adjacent',\n", " 'administered',\n", " 'administration',\n", " 'advanced',\n", " 'age',\n", " 'ago',\n", " 'agree',\n", " 'ailable',\n", " 'alex',\n", " 'algorithm',\n", " 'algorithms',\n", " 'aligned',\n", " 'alignment',\n", " 'allowing',\n", " 'almost',\n", " 'along',\n", " 'already',\n", " 'also',\n", " 'although',\n", " 'amilcare',\n", " 'amount',\n", " 'anatomic',\n", " 'anchor',\n", " 'anchors',\n", " 'angle',\n", " 'angular',\n", " 'angulated',\n", " 'angulation',\n", " 'ankle',\n", " 'antecubital',\n", " 'anterior',\n", " 'anteriorly',\n", " 'anterolateral',\n", " 'anterolaterally',\n", " 'anteromedial',\n", " 'anteromedially',\n", " 'ap',\n", " 'apex',\n", " 'apparent',\n", " 'appear',\n", " 'appearance',\n", " 'appears',\n", " 'appreciated',\n", " 'approaches',\n", " 'appropriate',\n", " 'approximate',\n", " 'approximately',\n", " 'arbelo',\n", " 'arcs',\n", " 'area',\n", " 'areas',\n", " 'arising',\n", " 'around',\n", " 'art',\n", " 'arthrodesis',\n", " 'arthroplasty',\n", " 'arthrosis',\n", " 'articular',\n", " 'articulate',\n", " 'articulated',\n", " 'articulating',\n", " 'articulation',\n", " 'articulations',\n", " 'artifact',\n", " 'artifactual',\n", " 'aspect',\n", " 'aspects',\n", " 'aspiration',\n", 
" 'assess',\n", " 'assessed',\n", " 'assessing',\n", " 'assessment',\n", " 'associated',\n", " 'assure',\n", " 'attachment',\n", " 'attempted',\n", " 'attention',\n", " 'automatic',\n", " 'av',\n", " 'available',\n", " 'avascular',\n", " 'avulsion',\n", " 'axial',\n", " 'axis',\n", " 'baldassarre',\n", " 'bandage',\n", " 'basal',\n", " 'base',\n", " 'basis',\n", " 'bed',\n", " 'bending',\n", " 'benefit',\n", " 'best',\n", " 'better',\n", " 'bike',\n", " 'bilateral',\n", " 'bilaterally',\n", " 'bivalve',\n", " 'blood',\n", " 'bodies',\n", " 'body',\n", " 'bold',\n", " 'bone',\n", " 'bones',\n", " 'bony',\n", " 'borderline',\n", " 'boss',\n", " 'bradley',\n", " 'brady',\n", " 'break',\n", " 'bridging',\n", " 'brogan',\n", " 'bubbles',\n", " 'bultman',\n", " 'c7',\n", " 'calcific',\n", " 'calcifications',\n", " 'calcium',\n", " 'callus',\n", " 'cannot',\n", " 'capitate',\n", " 'capitellar',\n", " 'capitellum',\n", " 'capitolunate',\n", " 'capsule',\n", " 'care',\n", " 'carpal',\n", " 'carpi',\n", " 'carpometacarpal',\n", " 'carpus',\n", " 'cast',\n", " 'casting',\n", " 'centered',\n", " 'central',\n", " 'centrally',\n", " 'change',\n", " 'changed',\n", " 'changes',\n", " 'channel',\n", " 'characterized',\n", " 'chen',\n", " 'cheng',\n", " 'chip',\n", " 'chondrocalcinosis',\n", " 'christine',\n", " 'christopher',\n", " 'chronic',\n", " 'chung',\n", " 'circumferential',\n", " 'clavicle',\n", " 'clearly',\n", " 'clinical',\n", " 'close',\n", " 'closed',\n", " 'closely',\n", " 'cm',\n", " 'cmc',\n", " 'cmcj',\n", " 'coalition',\n", " 'collapse',\n", " 'collateral',\n", " 'collection',\n", " 'collections',\n", " 'colles',\n", " 'collision',\n", " 'columnar',\n", " 'combined',\n", " 'comminuted',\n", " 'comminution',\n", " 'common',\n", " 'communicated',\n", " 'compare',\n", " 'compared',\n", " 'comparison',\n", " 'compartment',\n", " 'compartments',\n", " 'compatible',\n", " 'complete',\n", " 'complication',\n", " 'complications',\n", " 'component',\n", " 'concentric',\n", 
" 'concurrent',\n", " 'configuration',\n", " 'confirm',\n", " 'congruent',\n", " 'consequently',\n", " 'consider',\n", " 'considerable',\n", " 'considerably',\n", " 'consideration',\n", " 'consistent',\n", " 'conspicuity',\n", " 'conspicuous',\n", " 'contemplated',\n", " 'contralateral',\n", " 'contrast',\n", " 'control',\n", " 'conventional',\n", " 'coracoclavicular',\n", " 'coronal',\n", " 'coronally',\n", " 'coronoid',\n", " 'corpus',\n", " 'corrected',\n", " 'correction',\n", " 'correlate',\n", " 'corresponding',\n", " 'cortex',\n", " 'cortical',\n", " 'corticated',\n", " 'cotterill',\n", " 'cottrell',\n", " 'created',\n", " 'critical',\n", " 'cross',\n", " 'cruz',\n", " 'ct',\n", " 'ctdi',\n", " 'ctdivol',\n", " 'ctdlvol',\n", " 'ctrm',\n", " 'current',\n", " 'currently',\n", " 'cyst',\n", " 'cystic',\n", " 'cysts',\n", " 'dark',\n", " 'data',\n", " 'date',\n", " 'dated',\n", " 'david',\n", " 'day',\n", " 'days',\n", " 'decrease',\n", " 'decreased',\n", " 'dedicated',\n", " 'deep',\n", " 'definite',\n", " 'definitely',\n", " 'definitive',\n", " 'deformity',\n", " 'degenerative',\n", " 'degree',\n", " 'degrees',\n", " 'delasotta',\n", " 'delayed',\n", " 'delineated',\n", " 'demineralization',\n", " 'demineralized',\n", " 'demonstrate',\n", " 'demonstrated',\n", " 'demonstrates',\n", " 'demonstrating',\n", " 'demonstration',\n", " 'densities',\n", " 'density',\n", " 'deposition',\n", " 'depressed',\n", " 'depression',\n", " 'described',\n", " 'description',\n", " 'detail',\n", " 'detailed',\n", " 'details',\n", " 'diameter',\n", " 'diaphyseal',\n", " 'diaphysis',\n", " 'diastasis',\n", " 'diego',\n", " 'differences',\n", " 'different',\n", " 'difficult',\n", " 'diffuse',\n", " 'diffusely',\n", " 'digitorum',\n", " 'dimension',\n", " 'directed',\n", " 'discontinuity',\n", " 'disease',\n", " 'disi',\n", " 'dislocated',\n", " 'dislocation',\n", " 'dispersed',\n", " 'displaced',\n", " 'displacement',\n", " 'distal',\n", " 'distally',\n", " 'distance',\n", " 
'distances',\n", " 'distracted',\n", " 'distraction',\n", " 'disuse',\n", " 'dlp',\n", " 'doi',\n", " 'dominant',\n", " 'donated',\n", " 'done',\n", " 'dorsal',\n", " 'dorsally',\n", " 'dorsolateral',\n", " 'dorsum',\n", " 'dose',\n", " 'drift',\n", " 'due',\n", " 'dynamic',\n", " 'earlier',\n", " 'early',\n", " 'ecchymosis',\n", " 'edema',\n", " 'edward',\n", " 'effect',\n", " 'effusion',\n", " 'either',\n", " 'elbow',\n", " 'elsewhere',\n", " 'employ',\n", " 'employed',\n", " 'encounter',\n", " 'end',\n", " 'enhancement',\n", " 'enter',\n", " 'entering',\n", " 'enthesopathic',\n", " 'enthesopathy',\n", " 'entire',\n", " 'entrapment',\n", " 'epic',\n", " 'epicondyle',\n", " 'epiphysis',\n", " 'equipment',\n", " 'eric',\n", " 'especially',\n", " 'essentially',\n", " 'establish',\n", " 'evaluate',\n", " 'evaluated',\n", " 'evaluation',\n", " 'evelyn',\n", " 'evidence',\n", " 'evident',\n", " 'exam',\n", " 'examination',\n", " 'examinations',\n", " 'exams',\n", " 'exhibiting',\n", " 'exposes',\n", " 'exposure',\n", " 'extend',\n", " 'extended',\n", " 'extending',\n", " 'extends',\n", " 'extension',\n", " 'extensive',\n", " 'extensor',\n", " 'external',\n", " 'extra',\n", " 'extremity',\n", " 'facet',\n", " 'failure',\n", " 'fall',\n", " 'fat',\n", " 'fdp',\n", " 'features',\n", " 'fell',\n", " 'fellow',\n", " 'femoral',\n", " 'fiberglass',\n", " 'fibrocartilage',\n", " 'fifth',\n", " 'films',\n", " 'finding',\n", " 'findings',\n", " 'fine',\n", " 'finger',\n", " 'fingers',\n", " 'first',\n", " 'five',\n", " 'fixating',\n", " 'fixation',\n", " 'fixator',\n", " 'flake',\n", " 'flexed',\n", " 'flexion',\n", " 'flexor',\n", " 'fliszar',\n", " 'fluid',\n", " 'fluoroscopy',\n", " 'focal',\n", " 'foci',\n", " 'follow',\n", " 'following',\n", " 'followup',\n", " 'forearm',\n", " 'foreign',\n", " 'foreshortening',\n", " 'formation',\n", " 'fossa',\n", " 'fossae',\n", " 'four',\n", " 'fourth',\n", " 'fpl',\n", " 'fracture',\n", " 'fractured',\n", " 'fractures',\n", " 
'fragment',\n", " 'fragments',\n", " 'friend',\n", " 'frontal',\n", " 'fully',\n", " 'fusion',\n", " 'fx',\n", " 'galleazi',\n", " 'gap',\n", " 'gapping',\n", " 'gaps',\n", " 'gas',\n", " 'generate',\n", " 'generated',\n", " 'gentili',\n", " 'given',\n", " 'globules',\n", " 'grade',\n", " 'greater',\n", " 'greatest',\n", " 'greenstick',\n", " 'gross',\n", " 'grossly',\n", " 'half',\n", " 'hamate',\n", " 'hand',\n", " 'hannah',\n", " 'hardware',\n", " 'harris',\n", " 'head',\n", " 'healed',\n", " 'healing',\n", " 'health',\n", " 'heather',\n", " 'heavily',\n", " 'helical',\n", " 'helling',\n", " 'helpful',\n", " 'hemarthrosis',\n", " 'hematoma',\n", " 'hemorrhage',\n", " 'high',\n", " 'highly',\n", " 'history',\n", " 'hours',\n", " 'however',\n", " 'hr',\n", " 'huang',\n", " 'hughes',\n", " 'humeral',\n", " 'humerus',\n", " 'humpback',\n", " 'hydroxyapatite',\n", " 'identified',\n", " 'image',\n", " 'imaged',\n", " 'images',\n", " 'imaging',\n", " 'immediate',\n", " 'impacted',\n", " 'impaction',\n", " 'impax',\n", " 'impinge',\n", " 'implies',\n", " 'impression',\n", " 'improve',\n", " 'improved',\n", " 'improvement',\n", " 'incidental',\n", " 'inclination',\n", " 'included',\n", " 'includes',\n", " 'including',\n", " 'incomplete',\n", " 'incompletely',\n", " 'incongruent',\n", " 'incongruity',\n", " 'increase',\n", " 'increased',\n", " 'ind',\n", " 'independent',\n", " 'index',\n", " 'indicating',\n", " 'indication',\n", " 'inferior',\n", " 'inferiorly',\n", " 'initial',\n", " 'injuries',\n", " 'injury',\n", " 'instability',\n", " 'institution',\n", " 'insufficiency',\n", " 'intact',\n", " 'intercarpal',\n", " 'interfragmentary',\n", " 'internal',\n", " 'interosseous',\n", " 'interphalangeal',\n", " 'interposition',\n", " 'interpretation',\n", " 'interpreting',\n", " 'interspaces',\n", " 'interval',\n", " 'intra',\n", " 'intraarticular',\n", " 'intramedullary',\n", " 'intraosseous',\n", " 'intravenous',\n", " 'involve',\n", " 'involvement',\n", " 'involves',\n", " 
'involving',\n", " 'irregularity',\n", " 'island',\n", " 'isolated',\n", " 'iterative',\n", " 'iv',\n", " 'jazbeh',\n", " 'joint',\n", " 'joints',\n", " 'junction',\n", " 'karen',\n", " 'known',\n", " 'l1',\n", " 'laceration',\n", " 'large',\n", " 'larger',\n", " 'lateral',\n", " 'lateralization',\n", " 'lauren',\n", " 'lawrence',\n", " 'laxity',\n", " 'least',\n", " 'left',\n", " 'length',\n", " 'lesion',\n", " 'less',\n", " 'level',\n", " 'levels',\n", " 'ligament',\n", " 'ligamentous',\n", " 'ligaments',\n", " 'like',\n", " 'likely',\n", " 'likewise',\n", " 'limit',\n", " 'limited',\n", " 'limiting',\n", " 'limits',\n", " 'lin',\n", " 'line',\n", " 'linear',\n", " 'lines',\n", " 'lip',\n", " 'lister',\n", " 'location',\n", " 'long',\n", " 'longstanding',\n", " 'longus',\n", " 'loss',\n", " 'low',\n", " 'lower',\n", " 'lt',\n", " 'lucency',\n", " 'lunate',\n", " 'lunotriquetral',\n", " 'made',\n", " 'madelung',\n", " 'main',\n", " 'maintained',\n", " 'maintenance',\n", " 'major',\n", " 'makes',\n", " 'makeup',\n", " 'malalignment',\n", " 'malunion',\n", " 'malunited',\n", " 'margin',\n", " 'marginal',\n", " 'margins',\n", " 'marked',\n", " 'markedly',\n", " 'mass',\n", " 'material',\n", " 'maximal',\n", " 'maxwell',\n", " 'may',\n", " 'mcglone',\n", " 'mcp',\n", " 'measurable',\n", " 'measured',\n", " 'measures',\n", " 'measuring',\n", " 'media',\n", " 'medial',\n", " 'medially',\n", " 'mediolateral',\n", " 'medullary',\n", " 'metacarpal',\n", " 'metacarpals',\n", " 'metacarpophalangeal',\n", " 'metadiaphysis',\n", " 'metallic',\n", " 'metaphyseal',\n", " 'metaphyses',\n", " 'metaphysis',\n", " 'mexico',\n", " 'mgy',\n", " 'mid',\n", " 'midcarpal',\n", " 'middle',\n", " 'migrated',\n", " 'mild',\n", " 'mildly',\n", " 'min',\n", " 'mineralization',\n", " 'mini',\n", " 'minimal',\n", " 'minimally',\n", " 'minimum',\n", " 'minor',\n", " 'mm',\n", " 'moderate',\n", " 'moderately',\n", " 'modern',\n", " 'morphology',\n", " 'motor',\n", " 'msk',\n", " 
'multiplanar',\n", " 'multiple',\n", " 'muscle',\n", " 'muscles',\n", " 'musculotendinous',\n", " 'mvc',\n", " 'narrative',\n", " 'narrowing',\n", " 'near',\n", " 'necessary',\n", " 'neck',\n", " 'necrosis',\n", " 'need',\n", " 'negative',\n", " 'neurovascular',\n", " 'neutral',\n", " 'new',\n", " 'newly',\n", " 'noncontrast',\n", " 'nondisplaced',\n", " 'none',\n", " 'nonspecific',\n", " 'nonunion',\n", " 'nonunited',\n", " 'normal',\n", " 'normally',\n", " 'norman',\n", " 'notch',\n", " 'note',\n", " 'noted',\n", " 'number',\n", " 'numerous',\n", " 'oblique',\n", " 'obscure',\n", " 'obscured',\n", " 'obscures',\n", " 'obscuring',\n", " 'obtained',\n", " 'obvious',\n", " 'october',\n", " 'offset',\n", " 'old',\n", " 'olecranon',\n", " 'ongoing',\n", " 'open',\n", " 'operative',\n", " 'optimized',\n", " 'order',\n", " 'ordered',\n", " 'oriented',\n", " 'orif',\n", " 'origin',\n", " 'original',\n", " 'orthogonal',\n", " 'osborne',\n", " 'osseous',\n", " 'ossicle',\n", " 'ossicles',\n", " 'ossific',\n", " 'osteoarthritic',\n", " 'osteoarthrosis',\n", " 'osteonecrosis',\n", " 'osteopenia',\n", " 'osteopenic',\n", " 'osteophyte',\n", " 'osteophytes',\n", " 'osteophytosis',\n", " 'osteotomy',\n", " 'otherwise',\n", " 'outs',\n", " 'outside',\n", " 'overall',\n", " 'overlap',\n", " 'overlapping',\n", " 'overlying',\n", " 'override',\n", " 'overriding',\n", " 'pa',\n", " 'pacs',\n", " 'pads',\n", " 'pain',\n", " 'palmar',\n", " 'paris',\n", " 'part',\n", " 'partial',\n", " 'partially',\n", " 'particularly',\n", " 'partly',\n", " 'pass',\n", " 'passenger',\n", " 'passes',\n", " 'pathology',\n", " 'pathria',\n", " 'patient',\n", " 'pattern',\n", " 'peer',\n", " 'penticuff',\n", " 'per',\n", " 'percutaneous',\n", " 'performed',\n", " 'perihardware',\n", " 'perilunate',\n", " 'periosteal',\n", " 'persistent',\n", " 'petechial',\n", " 'phalangeal',\n", " 'phalanges',\n", " 'phalanx',\n", " 'physician',\n", " 'physiologic',\n", " 'physis',\n", " 'pieces',\n", " 'pin',\n", " 
'pip',\n", " 'pisiform',\n", " 'place',\n", " 'placed',\n", " 'placement',\n", " 'plain',\n", " 'plane',\n", " 'planes',\n", " 'planning',\n", " 'plaster',\n", " 'plate',\n", " 'plating',\n", " 'please',\n", " 'plus',\n", " 'point',\n", " 'pole',\n", " 'polyarticular',\n", " 'poorly',\n", " 'portion',\n", " 'portions',\n", " 'positive',\n", " 'possibility',\n", " 'possible',\n", " 'possibly',\n", " 'post',\n", " 'posterior',\n", " 'posteriorly',\n", " 'posterolaterally',\n", " 'postreduction',\n", " 'postsurgical',\n", " 'posttraumatic',\n", " 'potential',\n", " 'practice',\n", " 'pre',\n", " 'predominant',\n", " 'predominantly',\n", " 'predominately',\n", " 'preferential',\n", " 'preliminary',\n", " 'preoperative',\n", " 'presence',\n", " 'present',\n", " 'preserved',\n", " 'presumably',\n", " 'presumed',\n", " 'previous',\n", " 'previously',\n", " 'principal',\n", " 'pringle',\n", " 'prior',\n", " 'probable',\n", " 'probably',\n", " 'process',\n", " 'processed',\n", " 'processes',\n", " 'processing',\n", " 'products',\n", " 'profundus',\n", " 'progressive',\n", " 'projecting',\n", " 'projection',\n", " 'projections',\n", " 'projects',\n", " 'prominent',\n", " 'pronator',\n", " 'pronounced',\n", " 'protocol',\n", " 'provided',\n", " 'provider',\n", " 'proximal',\n", " 'proximally',\n", " 'punctate',\n", " 'quadratus',\n", " 'question',\n", " 'questionable',\n", " 'quite',\n", " 'radial',\n", " 'radialis',\n", " 'radially',\n", " 'radiation',\n", " 'radii',\n", " 'radiocapitellar',\n", " 'radiocarpal',\n", " 'radiograph',\n", " 'radiographs',\n", " 'radiologist',\n", " 'radiopaque',\n", " 'radioscaphoid',\n", " 'radioulnar',\n", " 'radius',\n", " 'raising',\n", " 'randall',\n", " 'range',\n", " ...]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(df.columns)" ] }, { "cell_type": "code", "execution_count": null, "id": "feaab124", "metadata": { "id": "feaab124" }, "outputs": [], "source": [ "X = df.iloc[:, : 18]" ] 
}, { "cell_type": "code", "execution_count": null, "id": "c5673c64", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 468 }, "id": "c5673c64", "outputId": "c4f824a0-e924-4544-932e-427168621c4b" }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeweightMaleSchedSkinMinsscheduled_case_durationABRAMS, REID ALLENHENTZEN, ERIC RICHARDLEEK, BRYAN TERRYMEUNIER, MATTHEW JOHNRECHNIC, MARKHealthyMild Systemic DiseaseSevere Systemic DiseaseChoice Per Patient on Day of SurgeryGeneralMonitored Anesthesia Care (MAC)Regionalheight
0422220.809011500010001010062.500
1383035.20909000010010010059.000
2613556.81659500100010010066.000
3572681.60606000010010000165.000
4552625.6012012000010100000164.000
.........................................................
140403488.01609000010010010069.000
141423244.8112012010000100000165.000
142413577.61858510000010100073.000
143642281.6012012000010010001062.559
144512368.0015017510000010000165.000
\n", "

145 rows × 18 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ], "text/plain": [ " Age weight Male SchedSkinMins scheduled_case_duration \\\n", "0 42 2220.8 0 90 115 \n", "1 38 3035.2 0 90 90 \n", "2 61 3556.8 1 65 95 \n", "3 57 2681.6 0 60 60 \n", "4 55 2625.6 0 120 120 \n", ".. ... ... ... ... ... \n", "140 40 3488.0 1 60 90 \n", "141 42 3244.8 1 120 120 \n", "142 41 3577.6 1 85 85 \n", "143 64 2281.6 0 120 120 \n", "144 51 2368.0 0 150 175 \n", "\n", " ABRAMS, REID ALLEN HENTZEN, ERIC RICHARD LEEK, BRYAN TERRY \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 1 \n", "3 0 0 0 \n", "4 0 0 0 \n", ".. ... ... ... \n", "140 0 0 0 \n", "141 1 0 0 \n", "142 1 0 0 \n", "143 0 0 0 \n", "144 1 0 0 \n", "\n", " MEUNIER, MATTHEW JOHN RECHNIC, MARK Healthy Mild Systemic Disease \\\n", "0 1 0 0 0 \n", "1 1 0 0 1 \n", "2 0 0 0 1 \n", "3 1 0 0 1 \n", "4 1 0 1 0 \n", ".. ... ... ... ... \n", "140 1 0 0 1 \n", "141 0 0 1 0 \n", "142 0 0 0 1 \n", "143 1 0 0 1 \n", "144 0 0 0 1 \n", "\n", " Severe Systemic Disease Choice Per Patient on Day of Surgery General \\\n", "0 1 0 1 \n", "1 0 0 1 \n", "2 0 0 1 \n", "3 0 0 0 \n", "4 0 0 0 \n", ".. ... ... ... \n", "140 0 0 1 \n", "141 0 0 0 \n", "142 0 1 0 \n", "143 0 0 0 \n", "144 0 0 0 \n", "\n", " Monitored Anesthesia Care (MAC) Regional height \n", "0 0 0 62.500 \n", "1 0 0 59.000 \n", "2 0 0 66.000 \n", "3 0 1 65.000 \n", "4 0 1 64.000 \n", ".. ... ... ... 
\n", "140 0 0 69.000 \n", "141 0 1 65.000 \n", "142 0 0 73.000 \n", "143 1 0 62.559 \n", "144 0 1 65.000 \n", "\n", "[145 rows x 18 columns]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X" ] }, { "cell_type": "code", "execution_count": null, "id": "d00d3312", "metadata": { "id": "d00d3312" }, "outputs": [], "source": [ "X['bert'] = df['bert_text']" ] }, { "cell_type": "code", "execution_count": null, "id": "a6458b26", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 659 }, "id": "a6458b26", "outputId": "5b43304c-2897-46df-9dc8-138eb0fbcb8f" }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeweightMaleSchedSkinMinsscheduled_case_durationABRAMS, REID ALLENHENTZEN, ERIC RICHARDLEEK, BRYAN TERRYMEUNIER, MATTHEW JOHNRECHNIC, MARKHealthyMild Systemic DiseaseSevere Systemic DiseaseChoice Per Patient on Day of SurgeryGeneralMonitored Anesthesia Care (MAC)Regionalheightbert
0422220.809011500010001010062.500EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...
1383035.20909000010010010059.000EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...
2613556.81659500100010010066.000Narrative & Impression EXAM DESCRIPTION: X-RAY...
3572681.60606000010010000165.000EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...
4552625.6012012000010100000164.000EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...
............................................................
140403488.01609000010010010069.000EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMU...
141423244.8112012010000100000165.000EXAM DESCRIPTION: X-RAY ELBOW 2 VIEWS - LEFT ...
142413577.61858510000010100073.000EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...
143642281.6012012000010010001062.559EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMU...
144512368.0015017510000010000165.000EXAM DESCRIPTION: CT RT UPPER EXTREMITY CLINI...
\n", "

145 rows × 19 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ], "text/plain": [ " Age weight Male SchedSkinMins scheduled_case_duration \\\n", "0 42 2220.8 0 90 115 \n", "1 38 3035.2 0 90 90 \n", "2 61 3556.8 1 65 95 \n", "3 57 2681.6 0 60 60 \n", "4 55 2625.6 0 120 120 \n", ".. ... ... ... ... ... \n", "140 40 3488.0 1 60 90 \n", "141 42 3244.8 1 120 120 \n", "142 41 3577.6 1 85 85 \n", "143 64 2281.6 0 120 120 \n", "144 51 2368.0 0 150 175 \n", "\n", " ABRAMS, REID ALLEN HENTZEN, ERIC RICHARD LEEK, BRYAN TERRY \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 1 \n", "3 0 0 0 \n", "4 0 0 0 \n", ".. ... ... ... \n", "140 0 0 0 \n", "141 1 0 0 \n", "142 1 0 0 \n", "143 0 0 0 \n", "144 1 0 0 \n", "\n", " MEUNIER, MATTHEW JOHN RECHNIC, MARK Healthy Mild Systemic Disease \\\n", "0 1 0 0 0 \n", "1 1 0 0 1 \n", "2 0 0 0 1 \n", "3 1 0 0 1 \n", "4 1 0 1 0 \n", ".. ... ... ... ... \n", "140 1 0 0 1 \n", "141 0 0 1 0 \n", "142 0 0 0 1 \n", "143 1 0 0 1 \n", "144 0 0 0 1 \n", "\n", " Severe Systemic Disease Choice Per Patient on Day of Surgery General \\\n", "0 1 0 1 \n", "1 0 0 1 \n", "2 0 0 1 \n", "3 0 0 0 \n", "4 0 0 0 \n", ".. ... ... ... \n", "140 0 0 1 \n", "141 0 0 0 \n", "142 0 1 0 \n", "143 0 0 0 \n", "144 0 0 0 \n", "\n", " Monitored Anesthesia Care (MAC) Regional height \\\n", "0 0 0 62.500 \n", "1 0 0 59.000 \n", "2 0 0 66.000 \n", "3 0 1 65.000 \n", "4 0 1 64.000 \n", ".. ... ... ... \n", "140 0 0 69.000 \n", "141 0 1 65.000 \n", "142 0 0 73.000 \n", "143 1 0 62.559 \n", "144 0 1 65.000 \n", "\n", " bert \n", "0 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... \n", "1 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... \n", "2 Narrative & Impression EXAM DESCRIPTION: X-RAY... \n", "3 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... \n", "4 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... \n", ".. ... \n", "140 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMU... \n", "141 EXAM DESCRIPTION: X-RAY ELBOW 2 VIEWS - LEFT ... \n", "142 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... 
\n", "143 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMU... \n", "144 EXAM DESCRIPTION: CT RT UPPER EXTREMITY CLINI... \n", "\n", "[145 rows x 19 columns]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X" ] }, { "cell_type": "code", "execution_count": null, "id": "d7a5c02a", "metadata": { "id": "d7a5c02a" }, "outputs": [], "source": [ "cols = list(X.columns)" ] }, { "cell_type": "code", "execution_count": null, "id": "ca16c997", "metadata": { "id": "ca16c997" }, "outputs": [], "source": [ "cols[-1] = cols[0]\n", "cols[0] = 'bert'" ] }, { "cell_type": "code", "execution_count": null, "id": "ecea416e", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 659 }, "id": "ecea416e", "outputId": "b4ca9bae-3cae-45f2-a34d-6db68bca4883" }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
bertweightMaleSchedSkinMinsscheduled_case_durationABRAMS, REID ALLENHENTZEN, ERIC RICHARDLEEK, BRYAN TERRYMEUNIER, MATTHEW JOHNRECHNIC, MARKHealthyMild Systemic DiseaseSevere Systemic DiseaseChoice Per Patient on Day of SurgeryGeneralMonitored Anesthesia Care (MAC)RegionalheightAge
0EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...2220.809011500010001010062.50042
1EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...3035.20909000010010010059.00038
2Narrative & Impression EXAM DESCRIPTION: X-RAY...3556.81659500100010010066.00061
3EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...2681.60606000010010000165.00057
4EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...2625.6012012000010100000164.00055
............................................................
140EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMU...3488.01609000010010010069.00040
141EXAM DESCRIPTION: X-RAY ELBOW 2 VIEWS - LEFT ...3244.8112012010000100000165.00042
142EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM...3577.61858510000010100073.00041
143EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMU...2281.6012012000010010001062.55964
144EXAM DESCRIPTION: CT RT UPPER EXTREMITY CLINI...2368.0015017510000010000165.00051
\n", "

145 rows × 19 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ], "text/plain": [ " bert weight Male \\\n", "0 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... 2220.8 0 \n", "1 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... 3035.2 0 \n", "2 Narrative & Impression EXAM DESCRIPTION: X-RAY... 3556.8 1 \n", "3 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... 2681.6 0 \n", "4 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... 2625.6 0 \n", ".. ... ... ... \n", "140 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMU... 3488.0 1 \n", "141 EXAM DESCRIPTION: X-RAY ELBOW 2 VIEWS - LEFT ... 3244.8 1 \n", "142 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMUM... 3577.6 1 \n", "143 EXAM DESCRIPTION: X-RAY WRIST COMPLETE MINIMU... 2281.6 0 \n", "144 EXAM DESCRIPTION: CT RT UPPER EXTREMITY CLINI... 2368.0 0 \n", "\n", " SchedSkinMins scheduled_case_duration ABRAMS, REID ALLEN \\\n", "0 90 115 0 \n", "1 90 90 0 \n", "2 65 95 0 \n", "3 60 60 0 \n", "4 120 120 0 \n", ".. ... ... ... \n", "140 60 90 0 \n", "141 120 120 1 \n", "142 85 85 1 \n", "143 120 120 0 \n", "144 150 175 1 \n", "\n", " HENTZEN, ERIC RICHARD LEEK, BRYAN TERRY MEUNIER, MATTHEW JOHN \\\n", "0 0 0 1 \n", "1 0 0 1 \n", "2 0 1 0 \n", "3 0 0 1 \n", "4 0 0 1 \n", ".. ... ... ... \n", "140 0 0 1 \n", "141 0 0 0 \n", "142 0 0 0 \n", "143 0 0 1 \n", "144 0 0 0 \n", "\n", " RECHNIC, MARK Healthy Mild Systemic Disease Severe Systemic Disease \\\n", "0 0 0 0 1 \n", "1 0 0 1 0 \n", "2 0 0 1 0 \n", "3 0 0 1 0 \n", "4 0 1 0 0 \n", ".. ... ... ... ... \n", "140 0 0 1 0 \n", "141 0 1 0 0 \n", "142 0 0 1 0 \n", "143 0 0 1 0 \n", "144 0 0 1 0 \n", "\n", " Choice Per Patient on Day of Surgery General \\\n", "0 0 1 \n", "1 0 1 \n", "2 0 1 \n", "3 0 0 \n", "4 0 0 \n", ".. ... ... \n", "140 0 1 \n", "141 0 0 \n", "142 1 0 \n", "143 0 0 \n", "144 0 0 \n", "\n", " Monitored Anesthesia Care (MAC) Regional height Age \n", "0 0 0 62.500 42 \n", "1 0 0 59.000 38 \n", "2 0 0 66.000 61 \n", "3 0 1 65.000 57 \n", "4 0 1 64.000 55 \n", ".. ... ... ... ... 
\n", "140 0 0 69.000 40 \n", "141 0 1 65.000 42 \n", "142 0 0 73.000 41 \n", "143 1 0 62.559 64 \n", "144 0 1 65.000 51 \n", "\n", "[145 rows x 19 columns]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X[cols]" ] }, { "cell_type": "code", "execution_count": null, "id": "28a791d7", "metadata": { "id": "28a791d7" }, "outputs": [], "source": [ "X = X[cols]" ] }, { "cell_type": "code", "execution_count": null, "id": "cfa7675b", "metadata": { "id": "cfa7675b" }, "outputs": [], "source": [ "Y = df['actual_case_duration']" ] }, { "cell_type": "code", "execution_count": null, "id": "8e1e612a", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8e1e612a", "outputId": "0addffbf-7a28-4864-bc52-103fcc6b80ab" }, "outputs": [ { "data": { "text/plain": [ "0 69\n", "1 110\n", "2 103\n", "3 70\n", "4 90\n", " ... \n", "140 76\n", "141 162\n", "142 111\n", "143 102\n", "144 233\n", "Name: actual_case_duration, Length: 145, dtype: int64" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Y" ] }, { "cell_type": "code", "execution_count": null, "id": "a1363545", "metadata": { "id": "a1363545" }, "outputs": [], "source": [ "import scipy.stats\n", "\n", "def mean_confidence_interval(data, confidence=0.95):\n", " a = 1.0 * np.array(data)\n", " n = len(a)\n", " m, se = np.mean(a), scipy.stats.sem(a)\n", " h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)\n", " return m, m-h, m+h\n", "# https://stackoverflow.com/questions/15033511/compute-a-confidence-interval-from-sample-data" ] }, { "cell_type": "code", "execution_count": null, "id": "2e5885af", "metadata": { "id": "2e5885af" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "SEQ_LEN = 512 # we will cut/pad our sequences to a length of 50 tokens\n", "\n", "def tokenize(sentence):\n", " tokens = tokenizer.encode_plus(sentence, max_length=SEQ_LEN,\n", " truncation=True, padding='max_length',\n", " 
add_special_tokens=True, return_attention_mask=True,\n", " return_token_type_ids=False, return_tensors='tf')\n", " return tokens['input_ids'], tokens['attention_mask']\n", "\n", "# https://gist.github.com/jamescalam/95c7fe7779244015c99a60c7fb0fa722" ] }, { "cell_type": "code", "execution_count": null, "id": "e18b418c", "metadata": { "id": "e18b418c" }, "outputs": [], "source": [ "from tensorflow.keras.layers import concatenate\n", "from keras import metrics" ] }, { "cell_type": "code", "execution_count": null, "id": "JivjISFxbFaZ", "metadata": { "id": "JivjISFxbFaZ" }, "outputs": [], "source": [ "df_test = pd.read_csv('case_testing.csv')" ] }, { "cell_type": "code", "execution_count": null, "id": "mB2EslmS03kl", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mB2EslmS03kl", "outputId": "729aae2b-e3db-4ef6-870a-057ca6865d53" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n", "Epoch 1/10\n", "16/16 [==============================] - 20s 566ms/step - loss: 46.6489 - mean_squared_error: 3996.6333 - mean_absolute_error: 46.6489 - val_loss: 24.4380 - val_mean_squared_error: 813.8113 - val_mean_absolute_error: 24.4380\n", "Epoch 2/10\n", "16/16 [==============================] - 6s 385ms/step - loss: 26.3915 - mean_squared_error: 1530.7649 - mean_absolute_error: 26.3915 - val_loss: 27.0104 - val_mean_squared_error: 919.8987 - val_mean_absolute_error: 27.0104\n", "Epoch 3/10\n", "16/16 [==============================] - 6s 385ms/step - loss: 21.5999 - mean_squared_error: 941.9378 - mean_absolute_error: 21.5999 - val_loss: 29.7836 - val_mean_squared_error: 1279.8384 - val_mean_absolute_error: 29.7836\n", "Epoch 4/10\n", "16/16 [==============================] - 6s 389ms/step - loss: 18.4751 - mean_squared_error: 799.2047 - mean_absolute_error: 18.4751 - val_loss: 32.0392 - val_mean_squared_error: 1281.1172 - val_mean_absolute_error: 32.0392\n", "Epoch 5/10\n", "16/16 [==============================] - 6s 388ms/step - loss: 
15.9001 - mean_squared_error: 657.3810 - mean_absolute_error: 15.9001 - val_loss: 32.6698 - val_mean_squared_error: 1403.7461 - val_mean_absolute_error: 32.6698\n", "Epoch 6/10\n", "16/16 [==============================] - 6s 395ms/step - loss: 12.3270 - mean_squared_error: 481.6977 - mean_absolute_error: 12.3270 - val_loss: 20.7796 - val_mean_squared_error: 589.2974 - val_mean_absolute_error: 20.7796\n", "Epoch 7/10\n", "16/16 [==============================] - 7s 414ms/step - loss: 9.6944 - mean_squared_error: 335.4913 - mean_absolute_error: 9.6944 - val_loss: 23.5291 - val_mean_squared_error: 751.3484 - val_mean_absolute_error: 23.5291\n", "Epoch 8/10\n", "16/16 [==============================] - 6s 397ms/step - loss: 8.3951 - mean_squared_error: 254.5211 - mean_absolute_error: 8.3951 - val_loss: 27.5824 - val_mean_squared_error: 1263.9397 - val_mean_absolute_error: 27.5824\n", "Epoch 9/10\n", "16/16 [==============================] - 6s 395ms/step - loss: 8.6806 - mean_squared_error: 201.5634 - mean_absolute_error: 8.6806 - val_loss: 28.3671 - val_mean_squared_error: 1193.1681 - val_mean_absolute_error: 28.3671\n", "Epoch 10/10\n", "16/16 [==============================] - 6s 396ms/step - loss: 6.5423 - mean_squared_error: 121.8022 - mean_absolute_error: 6.5423 - val_loss: 18.9814 - val_mean_squared_error: 610.5544 - val_mean_absolute_error: 18.9814\n", "dict_keys(['loss', 'mean_squared_error', 'mean_absolute_error', 'val_loss', 'val_mean_squared_error', 'val_mean_absolute_error'])\n", "mae: 18.981361389160156\n", "mse: 610.554443359375\n", "rmse: 24.70939989881128\n", "2\n", "Epoch 1/10\n", "16/16 [==============================] - 26s 648ms/step - loss: 46.8759 - mean_squared_error: 4013.0264 - mean_absolute_error: 46.8759 - val_loss: 36.7672 - val_mean_squared_error: 1711.2057 - val_mean_absolute_error: 36.7672\n", "Epoch 2/10\n", "16/16 [==============================] - 6s 396ms/step - loss: 25.6901 - mean_squared_error: 1296.9333 - mean_absolute_error: 
25.6901 - val_loss: 30.5125 - val_mean_squared_error: 2853.5093 - val_mean_absolute_error: 30.5125\n", "Epoch 3/10\n", "16/16 [==============================] - 7s 413ms/step - loss: 21.2274 - mean_squared_error: 907.5854 - mean_absolute_error: 21.2274 - val_loss: 42.8112 - val_mean_squared_error: 3679.4487 - val_mean_absolute_error: 42.8112\n", "Epoch 4/10\n", "16/16 [==============================] - 6s 395ms/step - loss: 18.8736 - mean_squared_error: 805.0837 - mean_absolute_error: 18.8736 - val_loss: 28.4393 - val_mean_squared_error: 2047.5968 - val_mean_absolute_error: 28.4393\n", "Epoch 5/10\n", "16/16 [==============================] - 7s 418ms/step - loss: 15.6844 - mean_squared_error: 629.3807 - mean_absolute_error: 15.6844 - val_loss: 16.2689 - val_mean_squared_error: 475.0890 - val_mean_absolute_error: 16.2689\n", "Epoch 6/10\n", "16/16 [==============================] - 6s 398ms/step - loss: 12.6370 - mean_squared_error: 475.4661 - mean_absolute_error: 12.6370 - val_loss: 24.7103 - val_mean_squared_error: 2106.9885 - val_mean_absolute_error: 24.7103\n", "Epoch 7/10\n", "16/16 [==============================] - 7s 421ms/step - loss: 11.6137 - mean_squared_error: 411.6028 - mean_absolute_error: 11.6137 - val_loss: 30.1132 - val_mean_squared_error: 2033.3340 - val_mean_absolute_error: 30.1132\n", "Epoch 8/10\n", "16/16 [==============================] - 7s 425ms/step - loss: 10.3672 - mean_squared_error: 299.5610 - mean_absolute_error: 10.3672 - val_loss: 34.2402 - val_mean_squared_error: 3203.9238 - val_mean_absolute_error: 34.2402\n", "Epoch 9/10\n", "16/16 [==============================] - 6s 405ms/step - loss: 7.9803 - mean_squared_error: 221.0783 - mean_absolute_error: 7.9803 - val_loss: 17.2296 - val_mean_squared_error: 493.0542 - val_mean_absolute_error: 17.2296\n", "Epoch 10/10\n", "16/16 [==============================] - 6s 405ms/step - loss: 7.1923 - mean_squared_error: 169.4480 - mean_absolute_error: 7.1923 - val_loss: 14.7769 - 
val_mean_squared_error: 398.1079 - val_mean_absolute_error: 14.7769\n", "dict_keys(['loss', 'mean_squared_error', 'mean_absolute_error', 'val_loss', 'val_mean_squared_error', 'val_mean_absolute_error'])\n", "mae: 14.776934623718262\n", "mse: 398.1079406738281\n", "rmse: 19.952642448403374\n", "3\n", "Epoch 1/10\n", "16/16 [==============================] - 26s 726ms/step - loss: 54.8101 - mean_squared_error: 4972.4326 - mean_absolute_error: 54.8101 - val_loss: 16.3773 - val_mean_squared_error: 416.7838 - val_mean_absolute_error: 16.3773\n", "Epoch 2/10\n", "16/16 [==============================] - 6s 396ms/step - loss: 27.3585 - mean_squared_error: 1425.0876 - mean_absolute_error: 27.3585 - val_loss: 15.3837 - val_mean_squared_error: 448.0744 - val_mean_absolute_error: 15.3837\n", "Epoch 3/10\n", "16/16 [==============================] - 6s 401ms/step - loss: 22.7017 - mean_squared_error: 1138.1315 - mean_absolute_error: 22.7017 - val_loss: 19.0043 - val_mean_squared_error: 512.5349 - val_mean_absolute_error: 19.0043\n", "Epoch 4/10\n", "16/16 [==============================] - 6s 402ms/step - loss: 20.4842 - mean_squared_error: 970.0302 - mean_absolute_error: 20.4842 - val_loss: 16.8925 - val_mean_squared_error: 482.2722 - val_mean_absolute_error: 16.8925\n", "Epoch 5/10\n", "16/16 [==============================] - 7s 426ms/step - loss: 17.7727 - mean_squared_error: 798.7726 - mean_absolute_error: 17.7727 - val_loss: 17.9508 - val_mean_squared_error: 452.0657 - val_mean_absolute_error: 17.9508\n", "Epoch 6/10\n", "16/16 [==============================] - 6s 406ms/step - loss: 13.8149 - mean_squared_error: 553.7028 - mean_absolute_error: 13.8149 - val_loss: 13.3315 - val_mean_squared_error: 252.4995 - val_mean_absolute_error: 13.3315\n", "Epoch 7/10\n", "16/16 [==============================] - 6s 403ms/step - loss: 12.8694 - mean_squared_error: 456.7632 - mean_absolute_error: 12.8694 - val_loss: 19.0132 - val_mean_squared_error: 428.1066 - 
val_mean_absolute_error: 19.0132\n", "Epoch 8/10\n", "16/16 [==============================] - 6s 397ms/step - loss: 12.1824 - mean_squared_error: 436.2264 - mean_absolute_error: 12.1824 - val_loss: 17.3174 - val_mean_squared_error: 512.1711 - val_mean_absolute_error: 17.3174\n", "Epoch 9/10\n", "16/16 [==============================] - 6s 403ms/step - loss: 11.2290 - mean_squared_error: 346.5132 - mean_absolute_error: 11.2290 - val_loss: 15.6005 - val_mean_squared_error: 416.6655 - val_mean_absolute_error: 15.6005\n", "Epoch 10/10\n", "16/16 [==============================] - 6s 407ms/step - loss: 8.7218 - mean_squared_error: 241.5640 - mean_absolute_error: 8.7218 - val_loss: 16.3083 - val_mean_squared_error: 419.8352 - val_mean_absolute_error: 16.3083\n", "dict_keys(['loss', 'mean_squared_error', 'mean_absolute_error', 'val_loss', 'val_mean_squared_error', 'val_mean_absolute_error'])\n", "mae: 16.308277130126953\n", "mse: 419.835205078125\n", "rmse: 20.48988055304679\n", "4\n", "Epoch 1/10\n", "16/16 [==============================] - 25s 656ms/step - loss: 46.2391 - mean_squared_error: 3857.1001 - mean_absolute_error: 46.2391 - val_loss: 30.2236 - val_mean_squared_error: 1434.1438 - val_mean_absolute_error: 30.2236\n", "Epoch 2/10\n", "16/16 [==============================] - 7s 421ms/step - loss: 23.7975 - mean_squared_error: 1240.3252 - mean_absolute_error: 23.7975 - val_loss: 17.3686 - val_mean_squared_error: 464.5977 - val_mean_absolute_error: 17.3686\n", "Epoch 3/10\n", "16/16 [==============================] - 6s 401ms/step - loss: 20.0567 - mean_squared_error: 919.6838 - mean_absolute_error: 20.0567 - val_loss: 25.4258 - val_mean_squared_error: 1748.2682 - val_mean_absolute_error: 25.4258\n", "Epoch 4/10\n", "16/16 [==============================] - 6s 401ms/step - loss: 17.2865 - mean_squared_error: 757.0361 - mean_absolute_error: 17.2865 - val_loss: 26.1409 - val_mean_squared_error: 1273.0339 - val_mean_absolute_error: 26.1409\n", "Epoch 5/10\n", "16/16 
[==============================] - 6s 402ms/step - loss: 14.7643 - mean_squared_error: 580.7484 - mean_absolute_error: 14.7643 - val_loss: 23.1793 - val_mean_squared_error: 1176.3580 - val_mean_absolute_error: 23.1793\n", "Epoch 6/10\n", "16/16 [==============================] - 6s 404ms/step - loss: 11.4672 - mean_squared_error: 382.0633 - mean_absolute_error: 11.4672 - val_loss: 14.4316 - val_mean_squared_error: 287.1018 - val_mean_absolute_error: 14.4316\n", "Epoch 7/10\n", "16/16 [==============================] - 6s 403ms/step - loss: 9.5826 - mean_squared_error: 301.8816 - mean_absolute_error: 9.5826 - val_loss: 21.8422 - val_mean_squared_error: 1267.8655 - val_mean_absolute_error: 21.8422\n", "Epoch 8/10\n", "16/16 [==============================] - 6s 406ms/step - loss: 9.1591 - mean_squared_error: 271.8867 - mean_absolute_error: 9.1591 - val_loss: 30.4112 - val_mean_squared_error: 1265.4841 - val_mean_absolute_error: 30.4112\n", "Epoch 9/10\n", "16/16 [==============================] - 6s 402ms/step - loss: 7.8254 - mean_squared_error: 171.3663 - mean_absolute_error: 7.8254 - val_loss: 18.6115 - val_mean_squared_error: 426.3449 - val_mean_absolute_error: 18.6115\n", "Epoch 10/10\n", "16/16 [==============================] - 7s 423ms/step - loss: 6.9302 - mean_squared_error: 141.0138 - mean_absolute_error: 6.9302 - val_loss: 14.5334 - val_mean_squared_error: 284.0012 - val_mean_absolute_error: 14.5334\n", "dict_keys(['loss', 'mean_squared_error', 'mean_absolute_error', 'val_loss', 'val_mean_squared_error', 'val_mean_absolute_error'])\n", "mae: 14.533406257629395\n", "mse: 284.00115966796875\n", "rmse: 16.852333953134465\n", "5\n", "Epoch 1/10\n", "16/16 [==============================] - 21s 589ms/step - loss: 50.4853 - mean_squared_error: 4311.5381 - mean_absolute_error: 50.4853 - val_loss: 36.6460 - val_mean_squared_error: 1511.0992 - val_mean_absolute_error: 36.6460\n", "Epoch 2/10\n", "16/16 [==============================] - 6s 402ms/step - loss: 
26.5454 - mean_squared_error: 1335.6312 - mean_absolute_error: 26.5454 - val_loss: 29.6112 - val_mean_squared_error: 1844.8209 - val_mean_absolute_error: 29.6112\n", "Epoch 3/10\n", "16/16 [==============================] - 7s 426ms/step - loss: 23.2541 - mean_squared_error: 1134.9839 - mean_absolute_error: 23.2541 - val_loss: 25.5423 - val_mean_squared_error: 1164.0361 - val_mean_absolute_error: 25.5423\n", "Epoch 4/10\n", "16/16 [==============================] - 6s 404ms/step - loss: 20.0800 - mean_squared_error: 960.0741 - mean_absolute_error: 20.0800 - val_loss: 22.4707 - val_mean_squared_error: 635.8213 - val_mean_absolute_error: 22.4707\n", "Epoch 5/10\n", "16/16 [==============================] - 7s 424ms/step - loss: 17.4392 - mean_squared_error: 784.3231 - mean_absolute_error: 17.4392 - val_loss: 32.4798 - val_mean_squared_error: 1422.3438 - val_mean_absolute_error: 32.4798\n", "Epoch 6/10\n", "16/16 [==============================] - 7s 411ms/step - loss: 14.9957 - mean_squared_error: 597.9553 - mean_absolute_error: 14.9957 - val_loss: 19.1292 - val_mean_squared_error: 578.0411 - val_mean_absolute_error: 19.1292\n", "Epoch 7/10\n", "16/16 [==============================] - 6s 403ms/step - loss: 13.6220 - mean_squared_error: 497.3582 - mean_absolute_error: 13.6220 - val_loss: 26.1810 - val_mean_squared_error: 1085.8984 - val_mean_absolute_error: 26.1810\n", "Epoch 8/10\n", "16/16 [==============================] - 7s 425ms/step - loss: 11.0724 - mean_squared_error: 388.7359 - mean_absolute_error: 11.0724 - val_loss: 21.7325 - val_mean_squared_error: 623.1096 - val_mean_absolute_error: 21.7325\n", "Epoch 9/10\n", "16/16 [==============================] - 7s 405ms/step - loss: 8.4866 - mean_squared_error: 291.3825 - mean_absolute_error: 8.4866 - val_loss: 20.6490 - val_mean_squared_error: 572.5185 - val_mean_absolute_error: 20.6490\n", "Epoch 10/10\n", "16/16 [==============================] - 7s 429ms/step - loss: 8.1422 - mean_squared_error: 236.3636 - 
mean_absolute_error: 8.1422 - val_loss: 28.7561 - val_mean_squared_error: 1200.6166 - val_mean_absolute_error: 28.7561\n", "dict_keys(['loss', 'mean_squared_error', 'mean_absolute_error', 'val_loss', 'val_mean_squared_error', 'val_mean_absolute_error'])\n", "mae: 28.75607681274414\n", "mse: 1200.6165771484375\n", "rmse: 34.64991453306108\n", "6\n", "Epoch 1/10\n", "16/16 [==============================] - 20s 620ms/step - loss: 43.1024 - mean_squared_error: 3136.5237 - mean_absolute_error: 43.1024 - val_loss: 23.8601 - val_mean_squared_error: 710.6071 - val_mean_absolute_error: 23.8601\n", "Epoch 2/10\n", "16/16 [==============================] - 7s 423ms/step - loss: 22.8502 - mean_squared_error: 1098.3306 - mean_absolute_error: 22.8502 - val_loss: 33.2904 - val_mean_squared_error: 2639.2778 - val_mean_absolute_error: 33.2904\n", "Epoch 3/10\n", "16/16 [==============================] - 6s 404ms/step - loss: 19.7562 - mean_squared_error: 780.0094 - mean_absolute_error: 19.7562 - val_loss: 36.3724 - val_mean_squared_error: 2942.1360 - val_mean_absolute_error: 36.3724\n", "Epoch 4/10\n", "16/16 [==============================] - 6s 405ms/step - loss: 16.0628 - mean_squared_error: 646.6743 - mean_absolute_error: 16.0628 - val_loss: 18.8508 - val_mean_squared_error: 561.0654 - val_mean_absolute_error: 18.8508\n", "Epoch 5/10\n", "16/16 [==============================] - 7s 427ms/step - loss: 12.6541 - mean_squared_error: 422.1867 - mean_absolute_error: 12.6541 - val_loss: 21.2715 - val_mean_squared_error: 833.6233 - val_mean_absolute_error: 21.2715\n", "Epoch 6/10\n", "16/16 [==============================] - 7s 406ms/step - loss: 10.2284 - mean_squared_error: 297.6600 - mean_absolute_error: 10.2284 - val_loss: 24.1720 - val_mean_squared_error: 781.6510 - val_mean_absolute_error: 24.1720\n", "Epoch 7/10\n", "16/16 [==============================] - 7s 424ms/step - loss: 9.2804 - mean_squared_error: 250.9150 - mean_absolute_error: 9.2804 - val_loss: 37.1640 - 
val_mean_squared_error: 2659.8718 - val_mean_absolute_error: 37.1640\n", "Epoch 8/10\n", "16/16 [==============================] - 6s 403ms/step - loss: 7.9451 - mean_squared_error: 194.1178 - mean_absolute_error: 7.9451 - val_loss: 26.9092 - val_mean_squared_error: 964.8079 - val_mean_absolute_error: 26.9092\n", "Epoch 9/10\n", "16/16 [==============================] - 6s 395ms/step - loss: 7.5933 - mean_squared_error: 147.2343 - mean_absolute_error: 7.5933 - val_loss: 35.5275 - val_mean_squared_error: 2593.9673 - val_mean_absolute_error: 35.5275\n", "Epoch 10/10\n", "16/16 [==============================] - 6s 398ms/step - loss: 7.1485 - mean_squared_error: 126.6482 - mean_absolute_error: 7.1485 - val_loss: 27.7680 - val_mean_squared_error: 1043.2466 - val_mean_absolute_error: 27.7680\n", "dict_keys(['loss', 'mean_squared_error', 'mean_absolute_error', 'val_loss', 'val_mean_squared_error', 'val_mean_absolute_error'])\n", "mae: 27.76795196533203\n", "mse: 1043.24658203125\n", "rmse: 32.2993278882278\n", "7\n", "Epoch 1/10\n", "16/16 [==============================] - 15s 506ms/step - loss: 43.4369 - mean_squared_error: 3173.8240 - mean_absolute_error: 43.4369 - val_loss: 36.6975 - val_mean_squared_error: 1515.2175 - val_mean_absolute_error: 36.6975\n", "Epoch 2/10\n", "16/16 [==============================] - 6s 395ms/step - loss: 28.6534 - mean_squared_error: 1622.8112 - mean_absolute_error: 28.6534 - val_loss: 32.2223 - val_mean_squared_error: 2169.3262 - val_mean_absolute_error: 32.2223\n", "Epoch 3/10\n", "16/16 [==============================] - 6s 397ms/step - loss: 21.6958 - mean_squared_error: 1007.1511 - mean_absolute_error: 21.6958 - val_loss: 20.4647 - val_mean_squared_error: 1044.1357 - val_mean_absolute_error: 20.4647\n", "Epoch 4/10\n", "16/16 [==============================] - 6s 395ms/step - loss: 18.6510 - mean_squared_error: 832.8838 - mean_absolute_error: 18.6510 - val_loss: 28.7958 - val_mean_squared_error: 1532.1514 - val_mean_absolute_error: 
28.7958\n", "Epoch 5/10\n", "16/16 [==============================] - 6s 398ms/step - loss: 16.0831 - mean_squared_error: 652.8909 - mean_absolute_error: 16.0831 - val_loss: 20.7326 - val_mean_squared_error: 799.2883 - val_mean_absolute_error: 20.7326\n", "Epoch 6/10\n", "16/16 [==============================] - 6s 397ms/step - loss: 14.1850 - mean_squared_error: 529.4182 - mean_absolute_error: 14.1850 - val_loss: 17.7058 - val_mean_squared_error: 492.8959 - val_mean_absolute_error: 17.7058\n", "Epoch 7/10\n", "16/16 [==============================] - 6s 397ms/step - loss: 10.4267 - mean_squared_error: 380.6010 - mean_absolute_error: 10.4267 - val_loss: 32.5175 - val_mean_squared_error: 1867.6582 - val_mean_absolute_error: 32.5175\n", "Epoch 8/10\n", "16/16 [==============================] - 6s 396ms/step - loss: 9.3113 - mean_squared_error: 295.4985 - mean_absolute_error: 9.3113 - val_loss: 31.1850 - val_mean_squared_error: 1913.2338 - val_mean_absolute_error: 31.1850\n", "Epoch 9/10\n", "16/16 [==============================] - 6s 399ms/step - loss: 8.1765 - mean_squared_error: 241.0834 - mean_absolute_error: 8.1765 - val_loss: 23.8425 - val_mean_squared_error: 977.6897 - val_mean_absolute_error: 23.8425\n", "Epoch 10/10\n", "16/16 [==============================] - 6s 400ms/step - loss: 6.8420 - mean_squared_error: 181.0989 - mean_absolute_error: 6.8420 - val_loss: 18.1302 - val_mean_squared_error: 728.7874 - val_mean_absolute_error: 18.1302\n", "dict_keys(['loss', 'mean_squared_error', 'mean_absolute_error', 'val_loss', 'val_mean_squared_error', 'val_mean_absolute_error'])\n", "mae: 18.130205154418945\n", "mse: 728.787353515625\n", "rmse: 26.996061814931913\n", "8\n", "Epoch 1/10\n", "16/16 [==============================] - 16s 508ms/step - loss: 49.1335 - mean_squared_error: 3766.9285 - mean_absolute_error: 49.1335 - val_loss: 38.7321 - val_mean_squared_error: 1929.2648 - val_mean_absolute_error: 38.7321\n", "Epoch 2/10\n", "16/16 
[==============================] - 6s 395ms/step - loss: 25.8629 - mean_squared_error: 1225.3267 - mean_absolute_error: 25.8629 - val_loss: 37.9464 - val_mean_squared_error: 2863.6711 - val_mean_absolute_error: 37.9464\n", "Epoch 3/10\n", "16/16 [==============================] - 6s 394ms/step - loss: 22.3968 - mean_squared_error: 1133.4264 - mean_absolute_error: 22.3968 - val_loss: 38.0297 - val_mean_squared_error: 2026.7555 - val_mean_absolute_error: 38.0297\n", "Epoch 4/10\n", "16/16 [==============================] - 6s 394ms/step - loss: 19.2315 - mean_squared_error: 972.4156 - mean_absolute_error: 19.2315 - val_loss: 22.5312 - val_mean_squared_error: 710.9309 - val_mean_absolute_error: 22.5312\n", "Epoch 5/10\n", "16/16 [==============================] - 6s 395ms/step - loss: 16.3772 - mean_squared_error: 767.7222 - mean_absolute_error: 16.3772 - val_loss: 36.1681 - val_mean_squared_error: 2000.5560 - val_mean_absolute_error: 36.1681\n", "Epoch 6/10\n", "16/16 [==============================] - 6s 395ms/step - loss: 14.2181 - mean_squared_error: 603.9561 - mean_absolute_error: 14.2181 - val_loss: 27.7901 - val_mean_squared_error: 1434.0776 - val_mean_absolute_error: 27.7901\n", "Epoch 7/10\n", "16/16 [==============================] - 6s 395ms/step - loss: 12.7447 - mean_squared_error: 541.6820 - mean_absolute_error: 12.7447 - val_loss: 36.6375 - val_mean_squared_error: 1987.8778 - val_mean_absolute_error: 36.6375\n", "Epoch 8/10\n", "16/16 [==============================] - 6s 397ms/step - loss: 12.6958 - mean_squared_error: 427.4815 - mean_absolute_error: 12.6958 - val_loss: 29.1779 - val_mean_squared_error: 1521.8394 - val_mean_absolute_error: 29.1779\n", "Epoch 9/10\n", "16/16 [==============================] - 6s 399ms/step - loss: 9.8992 - mean_squared_error: 363.2347 - mean_absolute_error: 9.8992 - val_loss: 18.1991 - val_mean_squared_error: 510.3538 - val_mean_absolute_error: 18.1991\n", "Epoch 10/10\n", "16/16 [==============================] - 6s 
396ms/step - loss: 8.3326 - mean_squared_error: 291.5731 - mean_absolute_error: 8.3326 - val_loss: 24.1199 - val_mean_squared_error: 956.7493 - val_mean_absolute_error: 24.1199\n", "dict_keys(['loss', 'mean_squared_error', 'mean_absolute_error', 'val_loss', 'val_mean_squared_error', 'val_mean_absolute_error'])\n", "mae: 24.119884490966797\n", "mse: 956.749267578125\n", "rmse: 30.931363816975885\n", "9\n", "Epoch 1/10\n", "16/16 [==============================] - 15s 509ms/step - loss: 45.7833 - mean_squared_error: 3447.8528 - mean_absolute_error: 45.7833 - val_loss: 26.7600 - val_mean_squared_error: 913.0746 - val_mean_absolute_error: 26.7600\n", "Epoch 2/10\n", "16/16 [==============================] - 6s 397ms/step - loss: 25.9168 - mean_squared_error: 1439.4403 - mean_absolute_error: 25.9168 - val_loss: 15.0151 - val_mean_squared_error: 621.3068 - val_mean_absolute_error: 15.0151\n", "Epoch 3/10\n", "16/16 [==============================] - 6s 395ms/step - loss: 21.0310 - mean_squared_error: 999.2832 - mean_absolute_error: 21.0310 - val_loss: 21.1435 - val_mean_squared_error: 917.4957 - val_mean_absolute_error: 21.1435\n", "Epoch 4/10\n", "16/16 [==============================] - 6s 395ms/step - loss: 17.4418 - mean_squared_error: 749.8430 - mean_absolute_error: 17.4418 - val_loss: 13.2612 - val_mean_squared_error: 325.4115 - val_mean_absolute_error: 13.2612\n", "Epoch 5/10\n", "16/16 [==============================] - 6s 396ms/step - loss: 14.5540 - mean_squared_error: 634.7337 - mean_absolute_error: 14.5540 - val_loss: 30.6246 - val_mean_squared_error: 1517.2434 - val_mean_absolute_error: 30.6246\n", "Epoch 6/10\n", "16/16 [==============================] - 6s 398ms/step - loss: 13.5227 - mean_squared_error: 541.6465 - mean_absolute_error: 13.5227 - val_loss: 24.8897 - val_mean_squared_error: 1112.8219 - val_mean_absolute_error: 24.8897\n", "Epoch 7/10\n", "16/16 [==============================] - 6s 399ms/step - loss: 10.8897 - mean_squared_error: 341.6058 - 
mean_absolute_error: 10.8897 - val_loss: 28.2217 - val_mean_squared_error: 1226.0479 - val_mean_absolute_error: 28.2217\n", "Epoch 8/10\n", "16/16 [==============================] - 6s 398ms/step - loss: 12.2341 - mean_squared_error: 343.0167 - mean_absolute_error: 12.2341 - val_loss: 16.6716 - val_mean_squared_error: 479.9389 - val_mean_absolute_error: 16.6716\n", "Epoch 9/10\n", "16/16 [==============================] - 6s 397ms/step - loss: 12.2951 - mean_squared_error: 349.8379 - mean_absolute_error: 12.2951 - val_loss: 26.3770 - val_mean_squared_error: 1158.8776 - val_mean_absolute_error: 26.3770\n", "Epoch 10/10\n", "16/16 [==============================] - 6s 397ms/step - loss: 8.1908 - mean_squared_error: 203.6116 - mean_absolute_error: 8.1908 - val_loss: 15.7148 - val_mean_squared_error: 649.5447 - val_mean_absolute_error: 15.7148\n", "dict_keys(['loss', 'mean_squared_error', 'mean_absolute_error', 'val_loss', 'val_mean_squared_error', 'val_mean_absolute_error'])\n", "mae: 15.714774131774902\n", "mse: 649.544677734375\n", "rmse: 25.486166399330735\n", "10\n", "Epoch 1/10\n", "16/16 [==============================] - 15s 509ms/step - loss: 47.8582 - mean_squared_error: 3851.1575 - mean_absolute_error: 47.8582 - val_loss: 42.8036 - val_mean_squared_error: 2715.5071 - val_mean_absolute_error: 42.8036\n", "Epoch 2/10\n", "16/16 [==============================] - 6s 398ms/step - loss: 28.4303 - mean_squared_error: 1377.1080 - mean_absolute_error: 28.4303 - val_loss: 21.5701 - val_mean_squared_error: 802.0464 - val_mean_absolute_error: 21.5701\n", "Epoch 3/10\n", "16/16 [==============================] - 6s 399ms/step - loss: 21.7284 - mean_squared_error: 1101.9967 - mean_absolute_error: 21.7284 - val_loss: 22.7482 - val_mean_squared_error: 841.3483 - val_mean_absolute_error: 22.7482\n", "Epoch 4/10\n", "16/16 [==============================] - 6s 400ms/step - loss: 19.4773 - mean_squared_error: 885.0285 - mean_absolute_error: 19.4773 - val_loss: 16.6339 - 
val_mean_squared_error: 368.2636 - val_mean_absolute_error: 16.6339\n", "Epoch 5/10\n", "16/16 [==============================] - 6s 398ms/step - loss: 14.9218 - mean_squared_error: 540.1773 - mean_absolute_error: 14.9218 - val_loss: 34.0125 - val_mean_squared_error: 2316.7109 - val_mean_absolute_error: 34.0125\n", "Epoch 6/10\n", "16/16 [==============================] - 6s 398ms/step - loss: 13.4875 - mean_squared_error: 552.4150 - mean_absolute_error: 13.4875 - val_loss: 13.6835 - val_mean_squared_error: 257.8040 - val_mean_absolute_error: 13.6835\n", "Epoch 7/10\n", "16/16 [==============================] - 6s 400ms/step - loss: 11.9964 - mean_squared_error: 497.4959 - mean_absolute_error: 11.9964 - val_loss: 27.9454 - val_mean_squared_error: 1854.7687 - val_mean_absolute_error: 27.9454\n", "Epoch 8/10\n", "16/16 [==============================] - 6s 402ms/step - loss: 10.0358 - mean_squared_error: 372.3074 - mean_absolute_error: 10.0358 - val_loss: 20.3163 - val_mean_squared_error: 534.4374 - val_mean_absolute_error: 20.3163\n", "Epoch 9/10\n", "16/16 [==============================] - 6s 402ms/step - loss: 9.5315 - mean_squared_error: 289.7698 - mean_absolute_error: 9.5315 - val_loss: 29.8330 - val_mean_squared_error: 1831.0917 - val_mean_absolute_error: 29.8330\n", "Epoch 10/10\n", "16/16 [==============================] - 6s 404ms/step - loss: 9.5033 - mean_squared_error: 256.6641 - mean_absolute_error: 9.5033 - val_loss: 16.8077 - val_mean_squared_error: 563.9155 - val_mean_absolute_error: 16.8077\n", "dict_keys(['loss', 'mean_squared_error', 'mean_absolute_error', 'val_loss', 'val_mean_squared_error', 'val_mean_absolute_error'])\n", "mae: 16.80767059326172\n", "mse: 563.9154663085938\n", "rmse: 23.74690435211701\n" ] } ], "source": [
"# 10-fold cross-validation of the clin_bert + tabular regressor on df.\n",
"# X layout: 'bert' text column first, then the 18 tabular columns with the first one\n",
"# rotated to the end. assign() avoids writing a new column into an iloc slice.\n",
"tab_frame = df.iloc[:, :18]\n",
"tab_cols = list(tab_frame.columns)\n",
"X = tab_frame.assign(bert=df['bert_text'])[['bert'] + tab_cols[1:] + tab_cols[:1]]\n",
"Y = df['actual_case_duration']\n",
"\n",
"kfold = KFold(n_splits=10)\n",
"rmse_scores = []\n",
"mse_scores = []\n",
"mae_scores = []\n",
"\n",
"N_TAB_FEATURES = 18  # scalar inputs f1..f18 that follow the 'bert' text column in X\n",
"BATCH_SIZE = 8\n",
"FEATURE_NAMES = [f'f{k}' for k in range(1, N_TAB_FEATURES + 1)]\n",
"\n",
"\n",
"def encode_texts(texts):\n",
"    \"\"\"Tokenize each string in `texts`; return (input_ids, attention_mask) arrays of shape (n, SEQ_LEN).\"\"\"\n",
"    texts = list(texts)\n",
"    ids = np.zeros((len(texts), SEQ_LEN))\n",
"    masks = np.zeros((len(texts), SEQ_LEN))\n",
"    for i, sentence in enumerate(texts):\n",
"        ids[i, :], masks[i, :] = tokenize(sentence)\n",
"    return ids, masks\n",
"\n",
"\n",
"def make_dataset(features, labels, training):\n",
"    \"\"\"Build a batched tf.data.Dataset of ({'input_ids', 'attention_mask', 'f1'..'f18'}, label).\n",
"\n",
"    Column 0 of `features` is the text; columns 1..N_TAB_FEATURES feed f1..fN.\n",
"    Only the training split is shuffled and trimmed to full batches; the evaluation\n",
"    split keeps every row in order, so the reported metrics cover the whole fold.\n",
"    \"\"\"\n",
"    ids, masks = encode_texts(features[features.columns[0]])\n",
"    inputs = {'input_ids': ids, 'attention_mask': masks}\n",
"    for k, name in enumerate(FEATURE_NAMES, start=1):\n",
"        inputs[name] = np.array(features[features.columns[k]])\n",
"    dataset = tf.data.Dataset.from_tensor_slices((inputs, np.array(labels)))\n",
"    if training:\n",
"        return dataset.shuffle(len(features)).batch(BATCH_SIZE, drop_remainder=True)\n",
"    return dataset.batch(BATCH_SIZE)\n",
"\n",
"\n",
"def build_model():\n",
"    \"\"\"Frozen clin_bert text branch (flattened last hidden state) + scalar inputs -> MLP regressor.\"\"\"\n",
"    input_ids = tf.keras.layers.Input(shape=(SEQ_LEN,), name='input_ids', dtype='int32')\n",
"    mask = tf.keras.layers.Input(shape=(SEQ_LEN,), name='attention_mask', dtype='int32')\n",
"    # Freeze BERT by reference rather than through a hard-coded model.layers[...] index.\n",
"    clin_bert.trainable = False\n",
"    embeddings = clin_bert(input_ids, attention_mask=mask)[0]\n",
"    text_vec = tf.keras.layers.Flatten()(embeddings)\n",
"\n",
"    tab_inputs = {name: tf.keras.layers.Input(shape=(1,), name=name) for name in FEATURE_NAMES}\n",
"    tab_vec = tf.keras.layers.Flatten()(concatenate(list(tab_inputs.values())))\n",
"\n",
"    z = concatenate([text_vec, tab_vec])\n",
"    z = tf.keras.layers.Dense(256, activation='relu')(z)\n",
"    z = tf.keras.layers.Dense(128, activation='relu')(z)\n",
"    z = tf.keras.layers.Dense(1, activation='linear')(z)\n",
"\n",
"    model = tf.keras.Model(inputs={'input_ids': input_ids, 'attention_mask': mask, **tab_inputs},\n",
"                           outputs=z)\n",
"    # NOTE: `decay` is only accepted by the legacy Keras optimizers; newer TF releases\n",
"    # reject it on tf.keras.optimizers.Adam (use tf.keras.optimizers.legacy.Adam there).\n",
"    opt = tf.keras.optimizers.Adam(learning_rate=0.00005, decay=0.000001)\n",
"    model.compile(loss='mean_absolute_error', optimizer=opt,\n",
"                  metrics=[metrics.mean_squared_error, metrics.mean_absolute_error])\n",
"    return model\n",
"\n",
"\n",
"# X / Y stay the data frames for the whole loop; Keras objects use their own names.\n",
"for fold, (train_idx, test_idx) in enumerate(kfold.split(X, Y), start=1):\n",
"    print(fold)\n",
"    train_dataset = make_dataset(X.iloc[train_idx], Y.iloc[train_idx], training=True)\n",
"    test_dataset = make_dataset(X.iloc[test_idx], Y.iloc[test_idx], training=False)\n",
"\n",
"    model = build_model()\n",
"    history = model.fit(train_dataset, validation_data=test_dataset, epochs=10)\n",
"\n",
"    print(history.history.keys())\n",
"    val_mae = history.history['val_mean_absolute_error'][-1]\n",
"    val_mse = history.history['val_mean_squared_error'][-1]\n",
"    print('mae:', val_mae)\n",
"    mae_scores.append(val_mae)\n",
"    print('mse:', val_mse)\n",
"    mse_scores.append(val_mse)\n",
"    print('rmse:', np.sqrt(val_mse))\n",
"    rmse_scores.append(np.sqrt(val_mse))\n",
"\n",
"    # Release this fold's graph before the next fold builds a fresh model.\n",
"    del model, history, train_dataset, test_dataset\n",
"    keras.backend.clear_session()"
 ] }, { "cell_type": "code", "execution_count": null, "id": "iph-6CIW03hE", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "iph-6CIW03hE", "outputId": "014aca3f-3f39-4628-95bc-bb7816f1d8aa" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mae_scores: (19.58965425491333, 15.773073885746182, 23.40623462408048)\n", "mse_scores: (685.5358673095703, 471.6130029669391, 899.4587316522015)\n", "rmse_scores: (25.61139956580403, 21.509459952669204, 29.71333917893886)\n" ] } ], "source": [ "print(\"mae_scores:\", mean_confidence_interval(mae_scores))\n", "print(\"mse_scores:\", mean_confidence_interval(mse_scores))\n", "print(\"rmse_scores:\", mean_confidence_interval(rmse_scores))" ] }, { "cell_type": "code", "execution_count": null, "id": "FtZg_gx0C3OD", "metadata": { "id": "FtZg_gx0C3OD" }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "D2yRR0di-zte", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "D2yRR0di-zte", "outputId": "172bfbd6-f0ff-422f-e642-c37186b4b09f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model_2\"\n", "__________________________________________________________________________________________________\n", " Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", " input_ids (InputLayer) [(None, 512)] 0 [] \n", " \n", " attention_mask (InputLayer) [(None, 512)] 0 [] \n", " \n", " f1 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f2 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f3 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f4 (InputLayer) [(None, 1)] 0 [] \n", " \n", " 
f5 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f6 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f7 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f8 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f9 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f10 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f11 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f12 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f13 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f14 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f15 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f16 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f17 (InputLayer) [(None, 1)] 0 [] \n", " \n", " f18 (InputLayer) [(None, 1)] 0 [] \n", " \n", " tf_bert_model (TFBertModel) TFBaseModelOutputWi 108310272 ['input_ids[0][0]', \n", " thPoolingAndCrossAt 'attention_mask[0][0]'] \n", " tentions(last_hidde \n", " n_state=(None, 512, \n", " 768), \n", " pooler_output=(Non \n", " e, 768), \n", " past_key_values=No \n", " ne, hidden_states=N \n", " one, attentions=Non \n", " e, cross_attentions \n", " =None) \n", " \n", " concatenate (Concatenate) (None, 18) 0 ['f1[0][0]', \n", " 'f2[0][0]', \n", " 'f3[0][0]', \n", " 'f4[0][0]', \n", " 'f5[0][0]', \n", " 'f6[0][0]', \n", " 'f7[0][0]', \n", " 'f8[0][0]', \n", " 'f9[0][0]', \n", " 'f10[0][0]', \n", " 'f11[0][0]', \n", " 'f12[0][0]', \n", " 'f13[0][0]', \n", " 'f14[0][0]', \n", " 'f15[0][0]', \n", " 'f16[0][0]', \n", " 'f17[0][0]', \n", " 'f18[0][0]'] \n", " \n", " flatten (Flatten) (None, 393216) 0 ['tf_bert_model[9][0]'] \n", " \n", " flatten_1 (Flatten) (None, 18) 0 ['concatenate[0][0]'] \n", " \n", " concatenate_1 (Concatenate) (None, 393234) 0 ['flatten[0][0]', \n", " 'flatten_1[0][0]'] \n", " \n", " dense (Dense) (None, 256) 100668160 ['concatenate_1[0][0]'] \n", " \n", " dense_1 (Dense) (None, 128) 32896 ['dense[0][0]'] \n", " \n", " dense_2 (Dense) (None, 1) 129 ['dense_1[0][0]'] \n", " \n", "==================================================================================================\n", "Total params: 
209,011,457\n", "Trainable params: 100,701,185\n", "Non-trainable params: 108,310,272\n", "__________________________________________________________________________________________________\n", "Epoch 1/10\n", "18/18 [==============================] - 18s 620ms/step - loss: 47.3423 - mean_squared_error: 4008.8945 - mean_absolute_error: 47.3423 - val_loss: 29.9976 - val_mean_squared_error: 1130.0610 - val_mean_absolute_error: 29.9976\n", "Epoch 2/10\n", "18/18 [==============================] - 10s 536ms/step - loss: 29.1246 - mean_squared_error: 1455.5372 - mean_absolute_error: 29.1246 - val_loss: 15.6915 - val_mean_squared_error: 458.8918 - val_mean_absolute_error: 15.6915\n", "Epoch 3/10\n", "18/18 [==============================] - 10s 549ms/step - loss: 22.2707 - mean_squared_error: 1057.6720 - mean_absolute_error: 22.2707 - val_loss: 15.6626 - val_mean_squared_error: 436.4203 - val_mean_absolute_error: 15.6626\n", "Epoch 4/10\n", "18/18 [==============================] - 10s 538ms/step - loss: 18.5569 - mean_squared_error: 853.0248 - mean_absolute_error: 18.5569 - val_loss: 15.9120 - val_mean_squared_error: 449.7988 - val_mean_absolute_error: 15.9120\n", "Epoch 5/10\n", "18/18 [==============================] - 9s 526ms/step - loss: 15.9935 - mean_squared_error: 747.2337 - mean_absolute_error: 15.9935 - val_loss: 15.8423 - val_mean_squared_error: 413.9730 - val_mean_absolute_error: 15.8423\n", "Epoch 6/10\n", "18/18 [==============================] - 9s 519ms/step - loss: 13.8896 - mean_squared_error: 556.6750 - mean_absolute_error: 13.8896 - val_loss: 17.9933 - val_mean_squared_error: 490.4284 - val_mean_absolute_error: 17.9933\n", "Epoch 7/10\n", "18/18 [==============================] - 9s 516ms/step - loss: 11.0643 - mean_squared_error: 413.2108 - mean_absolute_error: 11.0643 - val_loss: 19.5996 - val_mean_squared_error: 561.2172 - val_mean_absolute_error: 19.5996\n", "Epoch 8/10\n", "18/18 [==============================] - 9s 518ms/step - loss: 10.9985 - 
mean_squared_error: 327.9755 - mean_absolute_error: 10.9985 - val_loss: 25.4740 - val_mean_squared_error: 888.8347 - val_mean_absolute_error: 25.4740\n", "Epoch 9/10\n", "18/18 [==============================] - 9s 523ms/step - loss: 10.1875 - mean_squared_error: 271.1446 - mean_absolute_error: 10.1875 - val_loss: 26.7409 - val_mean_squared_error: 983.5016 - val_mean_absolute_error: 26.7409\n", "Epoch 10/10\n", "18/18 [==============================] - 9s 528ms/step - loss: 19.5605 - mean_squared_error: 586.5504 - mean_absolute_error: 19.5605 - val_loss: 16.1380 - val_mean_squared_error: 463.7458 - val_mean_absolute_error: 16.1380\n", "7/7 [==============================] - 4s 322ms/step\n" ] } ], "source": [
"# Final model: fit on all of df, then evaluate and predict on the held-out df_test.\n",
"# The k-fold score lists (mae_scores, mse_scores, rmse_scores) are deliberately not reset here.\n",
"N_TAB_FEATURES = 18  # scalar inputs f1..f18 that follow the 'bert' text column\n",
"BATCH_SIZE = 8\n",
"FEATURE_NAMES = [f'f{k}' for k in range(1, N_TAB_FEATURES + 1)]\n",
"\n",
"# Text column first, then the 18 tabular columns with the first one rotated to the end\n",
"# (same layout as the k-fold cell). assign() avoids writing into an iloc slice.\n",
"train_tab = df.iloc[:, :18]\n",
"train_cols = list(train_tab.columns)\n",
"x_train_f = train_tab.assign(bert=df['bert_text'])[['bert'] + train_cols[1:] + train_cols[:1]]\n",
"y_train_f = df['actual_case_duration']\n",
"\n",
"# NOTE(review): df_test is sliced 1:19 while df is sliced :18 -- presumably df_test carries an\n",
"# extra leading column; confirm both select the same 18 features in the same order.\n",
"test_tab = df_test.iloc[:, 1:19]\n",
"test_cols = list(test_tab.columns)\n",
"x_test_f = test_tab.assign(bert=df_test['bert_text'])[['bert'] + test_cols[1:] + test_cols[:1]]\n",
"y_test_f = df_test['actual_case_duration']\n",
"\n",
"# Tokenize the text column and collect f1..f18 for each split.\n",
"split_data = {}\n",
"for split, (frame, target) in {'train': (x_train_f, y_train_f), 'test': (x_test_f, y_test_f)}.items():\n",
"    ids = np.zeros((len(frame), SEQ_LEN))\n",
"    masks = np.zeros((len(frame), SEQ_LEN))\n",
"    for i, sentence in enumerate(frame[frame.columns[0]]):\n",
"        ids[i, :], masks[i, :] = tokenize(sentence)\n",
"    inputs = {'input_ids': ids, 'attention_mask': masks}\n",
"    for k, name in enumerate(FEATURE_NAMES, start=1):\n",
"        inputs[name] = np.array(frame[frame.columns[k]])\n",
"    split_data[split] = tf.data.Dataset.from_tensor_slices((inputs, np.array(target)))\n",
"\n",
"# Only the training data is shuffled / trimmed to full batches (same as the k-fold cell).\n",
"# The test set keeps every row in df_test order, so preds[i] lines up with y_test_f.iloc[i].\n",
"train_dataset = split_data['train'].shuffle(len(x_train_f)).batch(BATCH_SIZE, drop_remainder=True)\n",
"test_dataset = split_data['test'].batch(BATCH_SIZE)\n",
"\n",
"input_ids = tf.keras.layers.Input(shape=(SEQ_LEN,), name='input_ids', dtype='int32')\n",
"mask = tf.keras.layers.Input(shape=(SEQ_LEN,), name='attention_mask', dtype='int32')\n",
"clin_bert.trainable = False  # freeze BERT by reference, not via a model.layers[...] index\n",
"embeddings = clin_bert(input_ids, attention_mask=mask)[0]\n",
"text_vec = tf.keras.layers.Flatten()(embeddings)\n",
"\n",
"tab_inputs = {name: tf.keras.layers.Input(shape=(1,), name=name) for name in FEATURE_NAMES}\n",
"tab_vec = tf.keras.layers.Flatten()(concatenate(list(tab_inputs.values())))\n",
"\n",
"z = concatenate([text_vec, tab_vec])\n",
"z = tf.keras.layers.Dense(256, activation='relu')(z)\n",
"z = tf.keras.layers.Dense(128, activation='relu')(z)\n",
"z = tf.keras.layers.Dense(1, activation='linear')(z)\n",
"\n",
"model = tf.keras.Model(inputs={'input_ids': input_ids, 'attention_mask': mask, **tab_inputs},\n",
"                       outputs=z)\n",
"model.summary()\n",
"\n",
"# NOTE: `decay` is only accepted by the legacy Keras optimizers; newer TF releases\n",
"# reject it on tf.keras.optimizers.Adam (use tf.keras.optimizers.legacy.Adam there).\n",
"opt = tf.keras.optimizers.Adam(learning_rate=0.00005, decay=0.000001)\n",
"model.compile(loss='mean_absolute_error', optimizer=opt,\n",
"              metrics=[metrics.mean_squared_error, metrics.mean_absolute_error])\n",
"\n",
"history = model.fit(train_dataset, validation_data=test_dataset, epochs=10)\n",
"\n",
"# One prediction per df_test row, in df_test order; y_true holds the matching targets.\n",
"# NOTE(review): later cells hard-code `correct` / `baseline` lists; comparing preds against\n",
"# y_true keeps predictions and targets aligned across re-runs.\n",
"preds = list(model.predict(test_dataset)[:, 0])\n",
"y_true = list(y_test_f)\n",
"\n",
"# test_dataset is kept for inspection in the following cells.\n",
"del train_dataset, model, history, opt\n",
"keras.backend.clear_session()"
 ] }, { "cell_type": "code", "execution_count": null, "id": "LWXYOcDb-zrH", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "LWXYOcDb-zrH", "outputId": "6f9172f9-3e06-4ce5-d8df-93e8ef9cb74b", "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "[89.342575,\n", " 84.96334,\n", " 73.85574,\n", " 89.143105,\n", " 74.004524,\n", " 89.45623,\n", " 100.072495,\n", " 102.35107,\n", " 62.37254,\n", " 90.96574,\n", " 118.32195,\n", " 111.199905,\n", " 122.317406,\n", " 106.99516,\n", " 84.8193,\n", " 69.322365,\n", " 83.56809,\n", " 116.3028,\n", " 102.3071,\n", " 67.5948,\n", " 99.52764,\n", " 112.82618,\n", " 97.93588,\n", " 60.369976,\n", " 101.93672,\n", " 98.007835,\n", " 111.98784,\n", " 114.860405,\n", " 79.936325,\n", " 64.58143,\n", " 78.6915,\n", " 77.065025,\n", " 117.50474,\n", " 107.42674,\n", " 83.88853,\n", " 101.741875,\n", " 80.27219,\n", " 76.06034,\n", " 86.25531,\n", " 109.2905,\n", " 129.69717,\n", " 68.16295,\n", " 87.04166,\n", " 85.51395,\n", " 110.27611,\n", " 99.25715,\n", " 79.004684,\n", " 89.421646,\n", " 85.628296,\n", " 90.18194,\n", " 90.29712,\n", " 89.49869,\n", " 92.2751,\n", " 82.29098,\n", " 83.80102,\n", " 86.92345]" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds" ] }, { "cell_type": "code", "execution_count": null, "id": "McABzZiT-zop", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "McABzZiT-zop", "outputId": "1014eec2-d913-4b83-9170-842ee2426562" }, "outputs": [ { "data": { "text/plain": [ 
"[({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " ),\n", " ({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " ),\n", " ({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " ),\n", " ({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " ),\n", " ({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " ),\n", " ({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " ),\n", " ({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 
'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " )]" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(test_dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "zM4W0q4q-zmU", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "zM4W0q4q-zmU", "outputId": "62b61211-bbb8-431c-d3e3-fa37d78ce52f" }, "outputs": [ { "data": { "text/plain": [ "[({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " ),\n", " ({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " ),\n", " ({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " ),\n", " ({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " ),\n", " ({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 
'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " ),\n", " ({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " ),\n", " ({'input_ids': ,\n", " 'attention_mask': ,\n", " 'f1': ,\n", " 'f2': ,\n", " 'f3': ,\n", " 'f4': ,\n", " 'f5': ,\n", " 'f6': ,\n", " 'f7': ,\n", " 'f8': ,\n", " 'f9': ,\n", " 'f10': ,\n", " 'f11': ,\n", " 'f12': ,\n", " 'f13': ,\n", " 'f14': ,\n", " 'f15': ,\n", " 'f16': ,\n", " 'f17': ,\n", " 'f18': },\n", " )]" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(test_dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "jAhNNfSX-zkV", "metadata": { "id": "jAhNNfSX-zkV" }, "outputs": [], "source": [ "baseline = [90, 90, 60, 115, 80, 85, 65, 100, 60, 90, 90, 85, 75, 90, 55, 90, 60, 85, 75, 90, 90, 115, 105, 65, 65, 90, 85, 95, 60, 90, 55, 85, 95, 95, 75, 15, 110, 75, 60, 100, 115, 60, 90, 105, 85, 100, 60, 90, 85, 60, 85, 140, 145, 60, 60, 95]" ] }, { "cell_type": "code", "execution_count": null, "id": "ZVTGbBQZ-ziH", "metadata": { "id": "ZVTGbBQZ-ziH" }, "outputs": [], "source": [ "correct = [108, 92, 107, 80, 119, 114, 94, 89, 74, 103, 104, 108, 97, 86, 77, 77, 94, 126, 114, 153, 128, 112, 97, 68, 82, 81, 108, 90, 71, 102, 93, 72, 88, 147, 96, 69, 87, 101, 82, 121, 140, 72, 115, 80, 132, 83, 74, 85, 91, 81, 123, 78, 66, 80, 94, 90]" ] }, { "cell_type": "code", "execution_count": null, "id": "w3nfywk19PsP", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "w3nfywk19PsP", "outputId": "ecf02c93-c467-4da6-f826-68a61d8b6bd5" }, "outputs": [ { "data": { 
"text/plain": [ "16.138039316449845" ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mae(preds, correct)" ] }, { "cell_type": "code", "execution_count": null, "id": "6nj5Mj_D9PqE", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6nj5Mj_D9PqE", "outputId": "5048a5bb-284e-44ca-9ac6-5ca7b7a2f7f6" }, "outputs": [ { "data": { "text/plain": [ "463.7457526501909" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mse(correct, preds)" ] }, { "cell_type": "code", "execution_count": null, "id": "Ub1WYq9B9PoB", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Ub1WYq9B9PoB", "outputId": "a72e5863-580b-440a-f35c-c910db2ea0b6" }, "outputs": [ { "data": { "text/plain": [ "21.534756851429528" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mse(correct, preds, squared=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "9jHnlW4o9Pl8", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9jHnlW4o9Pl8", "outputId": "ea658761-fccd-41e6-ea0a-6f4850e34678" }, "outputs": [ { "data": { "text/plain": [ "866.9464285714286" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mse(baseline, correct)" ] }, { "cell_type": "code", "execution_count": null, "id": "p2luPxOG9Pj2", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "p2luPxOG9Pj2", "outputId": "309e970e-af18-437a-9987-00b1b904d003" }, "outputs": [ { "data": { "text/plain": [ "29.443954024067974" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mse(correct, baseline, squared=False)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { 
"name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "09729d15af3946229be18564bbf46e8c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "24367f3f608b45769eb3d3451f3b5ab1": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "25cc16b8d2c04213a33aa7b7fc51c02e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": 
[], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b7e2ed5256384def91cdf5a687a7459d", "placeholder": "​", "style": "IPY_MODEL_b55aec6e6a554532954d751f5af6d917", "value": "Downloading: 100%" } }, "292304f6679d41c896f50b367a6f0f76": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "2f5a7c7b350a41b5826ac3b1ced445b2": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "32ba13799e5140ef9c58149b40d86c2f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_f07e4555320b4d9d9e4ef8d66b15d05f", "placeholder": "​", "style": "IPY_MODEL_744aaa73d0fa473da416e443e9ec2084", "value": "Downloading: 100%" } }, "3727962fa5b9484bb10504c2c59419f2": { 
"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "423fe2f31065474c992d39f158757e8a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_61ebaa62bb7e42098f2ea3cd95b2977e", "max": 435778770, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_94213341c0e14639bea3f607bfe4ddd9", "value": 435778770 } }, "43a59377f06044b5a1591fce94a6a5e3": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": 
"@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_a29169a9668844e4bebbe39f6dc699be", "IPY_MODEL_8d19c84ce93b49fc8e1cc944d9d9f28d", "IPY_MODEL_a15a1c50d6434021b5d6de4f3c888e0a" ], "layout": "IPY_MODEL_7d9fb47db87f44a48df5c0de2655f0fd" } }, "61ebaa62bb7e42098f2ea3cd95b2977e": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6f8fb8284fc649118c0ba53d049e5c56": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", 
"box_style": "", "children": [ "IPY_MODEL_32ba13799e5140ef9c58149b40d86c2f", "IPY_MODEL_423fe2f31065474c992d39f158757e8a", "IPY_MODEL_bfd1e9a341e844729652573bee473b90" ], "layout": "IPY_MODEL_d2c94f525d9f4537bb86209aeacbb713" } }, "744aaa73d0fa473da416e443e9ec2084": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "7d9fb47db87f44a48df5c0de2655f0fd": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "872fd9f0c86d4f7399b24b6234d30326": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], 
"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_98950ddfbd4e48eca132bd709cee6762", "max": 385, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_fa85e893966046399e87cc6b292eb47d", "value": 385 } }, "8c8cd3d51f9d4d28ad2eafc2abec6576": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "8d19c84ce93b49fc8e1cc944d9d9f28d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_baea1444bda940a79c52ba497124fedc", "max": 213450, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_292304f6679d41c896f50b367a6f0f76", "value": 213450 } }, "94213341c0e14639bea3f607bfe4ddd9": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", 
"_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "9626e8c5ef854eb68afc9b627ed10dfc": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "98950ddfbd4e48eca132bd709cee6762": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, 
"justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a15a1c50d6434021b5d6de4f3c888e0a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_24367f3f608b45769eb3d3451f3b5ab1", "placeholder": "​", "style": "IPY_MODEL_c5123eb9cb94456b8d9c460373eda0bd", "value": " 213k/213k [00:00<00:00, 290kB/s]" } }, "a29169a9668844e4bebbe39f6dc699be": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ca0c6ecd06854fffa5ffccbdd15107b3", "placeholder": "​", "style": "IPY_MODEL_09729d15af3946229be18564bbf46e8c", "value": "Downloading: 100%" } }, "a8abbd6c0aa642ea971de0354445d0f6": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", 
"align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ab98d6bdc3c44097837b57d68ed89f84": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_9626e8c5ef854eb68afc9b627ed10dfc", "placeholder": "​", "style": "IPY_MODEL_8c8cd3d51f9d4d28ad2eafc2abec6576", "value": " 385/385 [00:00<00:00, 9.57kB/s]" } }, "b55aec6e6a554532954d751f5af6d917": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "b7e2ed5256384def91cdf5a687a7459d": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": 
"@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "baea1444bda940a79c52ba497124fedc": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, 
"overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bfd1e9a341e844729652573bee473b90": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_3727962fa5b9484bb10504c2c59419f2", "placeholder": "​", "style": "IPY_MODEL_2f5a7c7b350a41b5826ac3b1ced445b2", "value": " 436M/436M [00:17<00:00, 14.2MB/s]" } }, "c5123eb9cb94456b8d9c460373eda0bd": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ca0c6ecd06854fffa5ffccbdd15107b3": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": 
null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d2c94f525d9f4537bb86209aeacbb713": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f07b2752f6394b638f441abb111e93e6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_25cc16b8d2c04213a33aa7b7fc51c02e", 
"IPY_MODEL_872fd9f0c86d4f7399b24b6234d30326", "IPY_MODEL_ab98d6bdc3c44097837b57d68ed89f84" ], "layout": "IPY_MODEL_a8abbd6c0aa642ea971de0354445d0f6" } }, "f07e4555320b4d9d9e4ef8d66b15d05f": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "fa85e893966046399e87cc6b292eb47d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } } } } }, "nbformat": 4, "nbformat_minor": 5 }