# %% [markdown]
# **Clinical Reasoning with GPT-3**

# %% [markdown]
# **Load Necessary Dependencies**

# %%
# Standard-library imports only; the third-party stack is imported further
# down, after the install guard and the cache configuration have run.
import getpass
import os
import subprocess
import sys

# %%
# Install guard. It must run BEFORE `datasets` / `dsp` are imported, otherwise
# a missing package aborts the notebook before the fallback can install it.
try:
    import datasets
    root_path = '.'
except ModuleNotFoundError:
    subprocess.check_call(['git', 'clone', 'https://github.com/cgpotts/cs224u/'])
    # `sys.executable -m pip` installs into the environment of the running kernel.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-r', 'cs224u/requirements.txt'])
    root_path = 'dsp'

# %%
# Credentials and cache location.
# The key is never stored in the notebook: it is read from OPENAI_API_KEY when
# set, otherwise it is requested interactively.
key = os.environ.get("OPENAI_API_KEY") or getpass.getpass("OpenAI API key: ")
# NOTE(review): set before `import dsp` on the assumption that dsp reads this
# variable at import time -- confirm against the installed dsp version.
os.environ["DSP_NOTEBOOK_CACHEDIR"] = os.path.join(root_path, 'cache')
openai_key = key

# %%
import cohere
import numpy as np
import openai
import pandas as pd
from datasets import load_dataset

import dsp

# %%
# Run configuration: everything a reader might want to tune.
TEST_JSONL_PATH = "/content/test.jsonl"
QUESTION_SET_PATH = 'test_question_set.csv'
SC_CHECKPOINT_PATH = '/content/drive/MyDrive/file1.csv'
N_SC_SAMPLES = 23        # completions sampled per question for self-consistency
SC_TEMPERATURE = 0.7     # sampling temperature for those completions
MAX_QUESTIONS = 519      # upper bound on the number of questions processed

# %% [markdown]
# **Load and clean questions**

# %%
test = pd.read_json(TEST_JSONL_PATH, lines=True)

# %%
# Drop the rows tagged 'step1'.
test = test[test['meta_info'] != 'step1']

# %%
test_orig = test.reset_index(drop=True)

# %%
# Turn the multiple-choice phrasing into an open question. The patterns are
# plain text, so regex matching is switched off explicitly.
test_orig['question'] = test_orig['question'].str.replace('Which of the following', 'What', regex=False)
test_orig['question'] = test_orig['question'].str.replace('which of the following', 'what', regex=False)

# %%
test_orig.to_csv('dev_cases.csv', index=False)

# %% [markdown]
# **Selected Diagnosis Task Questions**

# %%
diag_q = pd.read_csv(QUESTION_SET_PATH)

# %% [markdown]
# Load the DSP module and the OpenAI API with a key.
# (Note: you must provide your own key.)

# %%
lm = dsp.GPT3(model='text-davinci-003', api_key=key)
dsp.settings.configure(lm=lm)

# %%
# Base diagnosis template. `Question` and `Answer` are rebound with different
# descriptions in the chain-of-thought section below; only the instructions of
# `qa_template` are reused afterwards.
Question = dsp.Type(prefix="Question:", desc="${A patient passage to be analyzed and determine the patient's diagnosis}")
Answer = dsp.Type(prefix="Answer:", desc="${a short factoid diagnosis, often between 1 and 9 words}", format=dsp.format_answers)
qa_template = dsp.Template(instructions="Diagnose the patients condition using the given passage with a short factoid answer.", question=Question(), answer=Answer())

# %% [markdown]
# **Diagnosis: Few Shot with Chain of Thought (CoT)**

# %%
# Few-shot demonstrations. The text is prompt material and is kept verbatim.
# NOTE(review): the second rationale contains the typo "suggestig"; it is left
# untouched because changing it changes the prompt sent to the model.
demo_example_CoT = [
    {
        'question': (
            'Shortly after undergoing a bipolar prosthesis for a displaced femoral neck fracture of the left hip acquired after a fall the day before, an 80-year-old woman suddenly develops dyspnea. '
            'The surgery under general anesthesia with sevoflurane was uneventful, lasting 98 minutes, during which the patient maintained oxygen saturation readings of 96–100% on 8 L of oxygen. '
            'She has a history of hypertension, osteoporosis, and osteoarthritis of her right knee. '
            'Her medications include ramipril, naproxen, ranitidine, and a multivitamin. '
            'She appears cyanotic, drowsy, and is oriented only to person. '
            'Her temperature is 38.6°C (101.5°F), pulse is 135/minute, respirations are 36/min, and blood pressure is 155/95 mm Hg. '
            'Pulse oximetry on room air shows an oxygen saturation of 81%. '
            'There are several scattered petechiae on the anterior chest wall. '
            'Laboratory studies show a hemoglobin concentration of 10.5 g/dL, a leukocyte count of 9,000/mm3, a platelet count of 145,000/mm3, and a creatine kinase of 190 U/L. '
            'An ECG shows sinus tachycardia. '
            'What is the most likely diagnosis?'
        ),
        'answer': 'Fat embolism',
        'rationale': (
            'The patient had a surgical repair of a displaced femoral neck fracture. '
            'The patient has petechiae. '
            'The patient has a new oxygen requirement, meaning they are having difficulty with their breathing. '
            'This patient most likely has a fat embolism'
        ),
    },
    {
        'question': (
            'A 55-year-old man comes to the emergency department because of a dry cough and severe chest pain beginning that morning. '
            'Two months ago, he was diagnosed with inferior wall myocardial infarction and was treated with stent implantation of the right coronary artery. '
            'He has a history of hypertension and hypercholesterolemia. '
            'His medications include aspirin, clopidogrel, atorvastatin, and enalapril. '
            'His temperature is 38.5°C (101.3°F), pulse is 92/min, respirations are 22/min, and blood pressure is 130/80 mm Hg. '
            'Cardiac examination shows a high-pitched scratching sound best heard while sitting upright and during expiration. '
            'The remainder of the examination shows no abnormalities. '
            'An ECG shows diffuse ST elevations. '
            'Serum studies show a troponin I of 0.005 ng/mL (N < 0.01). '
            'What is the most likely cause of this patients symptoms?'
        ),
        'answer': 'Dresslers Syndrome',
        'rationale': (
            'This patient is having chest pain. '
            'They recently had a heart attack and have new chest pain, suggestig he may have a problem with his heart. '
            'The EKG has diffuse ST elevations and scratching murmur. '
            'This patient likely has Dressler Syndrome.'
        ),
    },
]

# %%
Rationale = dsp.Type(
    prefix="Rationale: a step-by-step deduction that identifies the correct response.",
    desc="${a step-by-step deduction that identifies the correct response}"
)
Question = dsp.Type(prefix="Question:", desc="${the question to be answered}")
Answer = dsp.Type(prefix="Answer:", desc="${a short factoid answer, often between 1 and 5 words}", format=dsp.format_answers)

# %%
qa_template_with_CoT = dsp.Template(
    instructions=qa_template.instructions, question=Question(), rationale=Rationale(), answer=Answer()
)

# %%
# Result columns hold strings and lists, so they are created with object dtype
# instead of float NaN columns that would have to be up-cast on first write.
diag_q['SC'] = pd.Series(np.nan, index=diag_q.index, dtype=object)
diag_q['SC__rationale'] = pd.Series(np.nan, index=diag_q.index, dtype=object)


# %%
@dsp.transformation
def QA_predict(example: dsp.Example, sc=True):
    """Answer `example` with the chain-of-thought template.

    Args:
        example: the question (plus demos) to complete.
        sc: when true, sample N_SC_SAMPLES completions at SC_TEMPERATURE and
            keep the majority answer (self-consistency); otherwise generate
            with the template's default settings and use that result directly.

    Returns:
        A copy of the example with `answer` set to the selected answer and
        `rationale` set to the list of generated rationales, each suffixed
        with '__'.
    """
    if sc:
        example, completions = dsp.generate(
            qa_template_with_CoT, n=N_SC_SAMPLES, temperature=SC_TEMPERATURE
        )(example, stage='qa')
        completions_total = dsp.majority(completions)
    else:
        example, completions = dsp.generate(qa_template_with_CoT)(example, stage='qa')
        # Nothing to vote on. Previously this branch left `completions_total`
        # undefined and the return statement raised NameError.
        completions_total = completions

    rats = [completions[i].rationale + '__' for i in range(len(completions))]
    return example.copy(answer=completions_total.answer, rationale=rats)


# %%
def vanilla_LM_QA_FewCoT_SC(question: str) -> tuple:
    """Run few-shot CoT with self-consistency on one question.

    Returns:
        (answer, rationales) taken from a single call to `QA_predict`, so both
        values come from the same set of sampled completions.
    """
    example = dsp.Example(question=question, demos=demo_example_CoT)
    prediction = QA_predict(example, sc=True)
    return prediction.answer, prediction.rationale


# %%
# Slicing the index caps the run at MAX_QUESTIONS without failing when the
# question set is shorter than that.
for i in diag_q.index[:MAX_QUESTIONS]:
    print(i)
    answer, rationales = vanilla_LM_QA_FewCoT_SC(diag_q.at[i, 'question'])
    # `.at` writes into the frame itself; chained `diag_q[col][i] = ...` may
    # write into a temporary copy instead.
    diag_q.at[i, 'SC'] = answer
    diag_q.at[i, 'SC__rationale'] = rationales
    # Checkpoint after every question so an interrupted run keeps its progress.
    diag_q.to_csv(SC_CHECKPOINT_PATH, index=False)

# %%
diag_q.to_csv('SC.csv', index=False)

# %% [markdown]
# **Clinical Reasoning Prompts**

# %% [markdown]
# **Clinical Reasoning**
#
# Here we use a modified CoT for Self-consistency.
# Below are the different rationales for the different types of questions:
# - Differential Diagnosis
# - Intuitive Reasoning
# - Analytic Reasoning
# - Bayesian Inference

# %%
# One (answer, rationale, run-count) column triple per reasoning style, in the
# order IR, AR, BR, DR. Answer and rationale columns receive strings / lists,
# so they are created with object dtype; the run count stays numeric.
for _method in ('IR', 'AR', 'BR', 'DR'):
    diag_q[f'diag_{_method}'] = pd.Series(np.nan, index=diag_q.index, dtype=object)
    diag_q[f'diag_{_method}_rationale'] = pd.Series(np.nan, index=diag_q.index, dtype=object)
    diag_q[f'{_method}_numb'] = np.nan

# %%
Context = dsp.Type(
    prefix="Context:\n",
    desc="${sources that may contain relevant content}",
    format=dsp.passages2text
)

# %% [markdown]
# Examples for each type of reasoning:

# %%
# The two case vignettes shared by every demonstration set. They are prompt
# text and are reproduced verbatim; only the rationale differs per style.
_FAT_EMBOLISM_QUESTION = (
    'Shortly after undergoing a bipolar prosthesis for a displaced femoral neck fracture of the left hip acquired after a fall the day before, an 80-year-old woman suddenly develops dyspnea. '
    'The surgery under general anesthesia with sevoflurane was uneventful, lasting 98 minutes, during which the patient maintained oxygen saturation readings of 96–100% on 8 L of oxygen. '
    'She has a history of hypertension, osteoporosis, and osteoarthritis of her right knee. '
    'Her medications include ramipril, naproxen, ranitidine, and a multivitamin. '
    'She appears cyanotic, drowsy, and is oriented only to person. '
    'Her temperature is 38.6°C (101.5°F), pulse is 135/minute, respirations are 36/min, and blood pressure is 155/95 mm Hg. '
    'Pulse oximetry on room air shows an oxygen saturation of 81%. '
    'There are several scattered petechiae on the anterior chest wall. '
    'Laboratory studies show a hemoglobin concentration of 10.5 g/dL, a leukocyte count of 9,000/mm3, a platelet count of 145,000/mm3, and a creatine kinase of 190 U/L. '
    'An ECG shows sinus tachycardia. '
    'What is the most likely diagnosis?'
)
_DRESSLER_QUESTION = (
    'A 55-year-old man comes to the emergency department because of a dry cough and severe chest pain beginning that morning. '
    'Two months ago, he was diagnosed with inferior wall myocardial infarction and was treated with stent implantation of the right coronary artery. '
    'He has a history of hypertension and hypercholesterolemia. '
    'His medications include aspirin, clopidogrel, atorvastatin, and enalapril. '
    'His temperature is 38.5°C (101.3°F), pulse is 92/min, respirations are 22/min, and blood pressure is 130/80 mm Hg. '
    'Cardiac examination shows a high-pitched scratching sound best heard while sitting upright and during expiration. '
    'The remainder of the examination shows no abnormalities. '
    'An ECG shows diffuse ST elevations. '
    'Serum studies show a troponin I of 0.005 ng/mL (N < 0.01). '
    'What is the most likely cause of this patients symptoms?'
)


def _demo_pair(fat_embolism_rationale, dressler_rationale):
    """Build the two-demonstration list for one reasoning style."""
    return [
        {'question': _FAT_EMBOLISM_QUESTION, 'answer': 'Fat embolism', 'rationale': fat_embolism_rationale},
        {'question': _DRESSLER_QUESTION, 'answer': 'Dresslers Syndrome', 'rationale': dressler_rationale},
    ]


# Differential-diagnosis rationales.
demo_example_DR = _demo_pair(
    (
        'This patient has shortness of breath after a long bone surgery.  '
        'The differential for this patient is pulmonary embolism, fat embolism, myocardial infarction, blood loss, anaphylaxis, or a drug reaction.  '
        'The patient has petechiae which makes fat embolism more likely.  '
        'This patient most likely has a fat embolism.'
    ),
    (
        'This patient has chest pain with diffuse ST elevations after a recent myocardial infarction.  '
        'The differential for this patient includes: myocardial infarction, pulmonary embolism, pericarditis, Dressler syndrome, aortic dissection, and costochondritis.   '
        'This patient likely has a high-pitched scratching sound on auscultation associated with pericarditis and Dressler Syndrome.  '
        'This patient has diffuse ST elevations associated with Dressler Syndrome.  '
        'This patient most likely has Dressler Syndrome.'
    ),
)

# Intuitive-reasoning rationales.
demo_example_IR = _demo_pair(
    (
        'This patient has findings of petechiae, altered mental status, shortness of breath, and recent surgery suggesting a diagnosis of fat emboli.  '
        'The patient most likely has a fat embolism. '
    ),
    (
        'This patient had a recent myocardial infarction with new development of diffuse ST elevations, chest pain, and a high pitched scratching murmur which are found in Dresslers syndrome.   '
        'This patient likely has Dresslers Syndrome.'
    ),
)

# Analytic-reasoning rationales.
demo_example_AR = _demo_pair(
    (
        'The patient recently had large bone surgery making fat emboli a potential cause because the bone marrow was manipulated.  '
        'Petechiae can form in response to capillary inflammation caused by fat emboli.  '
        'Fat micro globules cause CNS microcirculation occlusion causing confusion and altered mental status.  '
        'Fat obstruction in the pulmonary arteries can cause tachycardia and shortness of breath as seen in this patient.   '
        'This patient most likely has a fat embolism.'
    ),
    (
        'This patient had a recent myocardial infarction which can cause myocardial inflammation that causes pericarditis and Dressler Syndrome.  '
        'The diffuse ST elevations and high pitched scratching murmur can be signs of pericardial inflammation as the inflamed pericardium rubs against the pleura as seen with Dressler Syndrome.  '
        'This patient likely has Dressler Syndrome.'
    ),
)

# Bayesian-inference rationales.
# NOTE(review): the second rationale says "the posterior probability of
# myocardial infarction is 55%" although the chain argues for Dressler
# Syndrome -- presumably a slip; kept verbatim, confirm with the prompt authors.
demo_example_BR = _demo_pair(
    (
        'The prior probability of fat embolism is 0.05% however the patient has petechiae on exam which is seen with fat emboli, which increases the posterior probability of fat embolism to 5%.  '
        'Altered mental status increases the probability further to 10%.  '
        'Recent orthopedic surgery increases the probability of fat emboli syndrome to 60%.  '
        'This patient most likely has a fat embolism. '
    ),
    (
        'The prior probability of Dressler Syndrome is 0.01%. '
        'The patient has diffuse ST elevations, increasing the probability of Dressler Syndrome to 5%.  '
        'The patient has a scratching murmur which increases the probability to 10%.  '
        'In the setting of a recent MI the posterior probability of myocardial infarction is 55%.  '
        'This patient likely has Dressler Syndrome.'
    ),
)

# %%
# Step list shared by the intuitive and differential rationale descriptions.
# NOTE(review): the numbering jumps from 4) to 6) in the original prompt; kept
# as-is so the prompt wording stays unchanged.
_DIFFERENTIAL_STEPS = (
    "Follow the following steps: "
    "1) list a broad differential of 6 diagnoses that answer the question. "
    "2) reference the question to find patient information or test results that make certain diagnoses on the differential more likely. "
    "3) Narrow the differential to 3 diagnoses based. "
    "4) Again reference the question to find information that makes one diagnosis more likely. "
    "6) Answer with the most likely diagnosis"
)

# Intuitive Rationale
# NOTE(review): this description repeats the differential-diagnosis steps --
# confirm that is intended for the intuitive-reasoning prompt.
# Fixed: the placeholder used to start with "${ollow" (missing "F").
Rationale_diag_IR = dsp.Type(
    prefix="Rationale: a step-by-step deduction that identifies the correct response",
    desc="${" + _DIFFERENTIAL_STEPS + ". }")

# Differential Rationale
# Fixed: the placeholder was missing its closing brace.
Rationale_diag_DR = dsp.Type(
    prefix="Rationale: a step-by-step deduction that identifies the correct response",
    desc="${" + _DIFFERENTIAL_STEPS + "}")

# Analytic Rationale
Rationale_diag_AR = dsp.Type(
    prefix="Rationale: a step-by-step deduction that identifies the correct response.",
    desc="${Create a differential diagnosis, then use analytic reasoning to deduce the physiologic or biochemical pathophysiology of the patient and identify the correct response.}")

# Bayesian Rationale
Rationale_diag_BR = dsp.Type(
    prefix="Rationale: a step-by-step deduction that identifies the correct response.",
    desc="${Create a differential diagnosis, then use a step-by-step Bayesian inference to deduce the correct response. }")

# %%
# One template per reasoning style; they differ only in the rationale field.
qa_template_diag_IR = dsp.Template(
    instructions=qa_template.instructions, question=Question(), rationale=Rationale_diag_IR(), answer=Answer()
)

qa_template_diag_AR = dsp.Template(
    instructions=qa_template.instructions, question=Question(), rationale=Rationale_diag_AR(), answer=Answer()
)

qa_template_diag_BR = dsp.Template(
    instructions=qa_template.instructions, question=Question(), rationale=Rationale_diag_BR(), answer=Answer()
)

qa_template_diag_DR = dsp.Template(
    instructions=qa_template.instructions, question=Question(), rationale=Rationale_diag_DR(), answer=Answer()
)

# %%
# Self-consistency sampling settings used by every reasoning style.
_DIAG_SC_SAMPLES = 23
_DIAG_SC_TEMPERATURE = 0.7


def _predict_with_template(template, example, sc):
    """Generate with `template` and package answer, rationales and run count.

    With `sc` true, _DIAG_SC_SAMPLES completions are sampled and the majority
    answer is kept; otherwise the default generation result is used directly.
    """
    if sc:
        example, completions = dsp.generate(
            template, n=_DIAG_SC_SAMPLES, temperature=_DIAG_SC_TEMPERATURE
        )(example, stage='qa')
        completions_total = dsp.majority(completions)
    else:
        example, completions = dsp.generate(template)(example, stage='qa')
        # Previously `completions_total` and `runs` were undefined on this
        # branch, so calling with sc=False raised NameError.
        completions_total = completions

    runs = len(completions)
    rats = [completions[i].rationale + '__' for i in range(runs)]
    return example.copy(answer=completions_total.answer, rationale=rats, run_numb=runs)


@dsp.transformation
def QA_predict_IR(example: dsp.Example, sc=True):
    """Predict with the intuitive-reasoning template.

    Returns a copy of `example` with `answer`, `rationale` (list of rationales,
    each suffixed with '__') and `run_numb` (number of completions returned).
    """
    return _predict_with_template(qa_template_diag_IR, example, sc)


@dsp.transformation
def QA_predict_AR(example: dsp.Example, sc=True):
    """Predict with the analytic-reasoning template; same result fields as `QA_predict_IR`."""
    return _predict_with_template(qa_template_diag_AR, example, sc)


@dsp.transformation
def QA_predict_BR(example: dsp.Example, sc=True):
    """Predict with the Bayesian-inference template; same result fields as `QA_predict_IR`."""
    return _predict_with_template(qa_template_diag_BR, example, sc)


@dsp.transformation
def QA_predict_DR(example: dsp.Example, sc=True):
    """Predict with the differential-diagnosis template; same result fields as `QA_predict_IR`."""
    return _predict_with_template(qa_template_diag_DR, example, sc)
# %%
import random


def _answer_with(predict, demos, question):
    """Run `predict` once on `question` with `demos`.

    Returns:
        (answer, rationales, run_numb) read from a single prediction, so the
        three values describe the same set of sampled completions.
    """
    example = dsp.Example(question=question, demos=demos)
    prediction = predict(example, sc=True)
    return prediction.answer, prediction.rationale, prediction.run_numb


# `selected` is accepted for backward compatibility and ignored: no caller in
# this notebook defines it. The former `context=passages` argument was dropped
# because `passages` is not defined anywhere and the templates have no context
# field.
def retrieve_then_QA_IR(question: str, selected=None) -> tuple:
    """Intuitive reasoning: returns (answer, rationales, run_numb)."""
    return _answer_with(QA_predict_IR, demo_example_IR, question)


def retrieve_then_QA_AR(question: str, selected=None) -> tuple:
    """Analytic reasoning: returns (answer, rationales, run_numb)."""
    return _answer_with(QA_predict_AR, demo_example_AR, question)


def retrieve_then_QA_BR(question: str, selected=None) -> tuple:
    """Bayesian inference: returns (answer, rationales, run_numb)."""
    return _answer_with(QA_predict_BR, demo_example_BR, question)


def retrieve_then_QA_DR(question: str, selected=None) -> tuple:
    """Differential diagnosis: returns (answer, rationales, run_numb)."""
    return _answer_with(QA_predict_DR, demo_example_DR, question)


# %%
# Reasoning styles in the order they are run for each question.
_REASONING_RUNS = (
    ('IR', retrieve_then_QA_IR),
    ('DR', retrieve_then_QA_DR),
    ('AR', retrieve_then_QA_AR),
    ('BR', retrieve_then_QA_BR),
)
_MAX_DIAG_QUESTIONS = 519  # upper bound on the number of questions processed

# Answers are strings and rationales are lists; make sure their columns can
# hold arbitrary Python objects before writing into them.
_text_columns = [
    column
    for method, _ in _REASONING_RUNS
    for column in (f'diag_{method}', f'diag_{method}_rationale')
]
diag_q[_text_columns] = diag_q[_text_columns].astype(object)

# Slicing the index caps the run without failing on a shorter question set.
for i in diag_q.index[:_MAX_DIAG_QUESTIONS]:
    print(i)
    question = diag_q.at[i, 'question']
    for method, run in _REASONING_RUNS:
        answer, rationales, n_runs = run(question)
        # Write into the columns created earlier (diag_IR, diag_DR, diag_AR,
        # diag_BR). The loop used to target diag_DR5 / diag_AR6 / diag_BR4,
        # which were never created.
        diag_q.at[i, f'diag_{method}'] = answer
        diag_q.at[i, f'diag_{method}_rationale'] = rationales
        diag_q.at[i, f'{method}_numb'] = n_runs

# %%
diag_q.to_csv('Test_Results.csv', index=False)

# Notebook metadata carried by the original file: Colab, GPU accelerator,
# "hm" machine shape, Python 3 kernel, nbformat 4.