Source code for prediction of COVID-19 test results. This is supplemental material to the publication:
Wojtusiak J, Bagais W, Vang J, Guralnik E, Roess A, Alemi F, "The Role of Symptom Clusters in Triage of COVID-19 Patients," Quality Management in Health Care, 2022.
Source code by Wejdan Bagais and Jee Vang, with contributions from the other authors.
# import libraries
import pickle
import timeit

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as shc
from joblib import Parallel, delayed
from scipy.cluster.hierarchy import fcluster
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold

from models import cut_hierarchy_columns, select_attributes

# reset matplotlib to its defaults
mpl.rcParams.update(mpl.rcParamsDefault)
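Note: `cut_hierarchy_columns` and `select_attributes` are defined in the accompanying `models` module, which is not reproduced in this listing. For orientation only, below is a minimal sketch of what `cut_hierarchy_columns` plausibly does, assuming it clusters the symptom columns with the same complete-linkage/cityblock setup used for the dendrograms further down, cuts the tree into the requested number of flat clusters, and merges each cluster of binary columns into a single column whose name joins its members with ':' (the separator the downstream code splits on). This is an illustrative reconstruction, not the authors' implementation.

# Hypothetical sketch of models.cut_hierarchy_columns -- an assumption for
# illustration, NOT the actual implementation from the models module.
def cut_hierarchy_columns_sketch(train, test, n_clusters,
                                 method="complete", metric="cityblock"):
    y_tr = train['TestPositive']
    y_te = test['TestPositive']
    X_tr = train.drop(columns=['TestPositive'])
    X_te = test.drop(columns=['TestPositive'])
    # cluster the transposed data so the symptom columns are the observations
    links = shc.linkage(X_tr.T, method=method, metric=metric)
    labels = fcluster(links, t=n_clusters, criterion='maxclust')
    def merge(X):
        merged = pd.DataFrame(index=X.index)
        for k in sorted(set(labels)):
            cols = X_tr.columns[labels == k]
            # merged binary feature: 1 if any symptom in the cluster is present
            merged[':'.join(cols)] = X[cols].max(axis=1)
        return merged
    return merge(X_tr), merge(X_te), y_tr, y_te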
start = timeit.default_timer()
# list of values for inverse of regularization strength
c_list = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.5,2]
# accumulators for the grid-search results
split_ids = []
cluster_n = []
cs = []
source = []
AUCs = []
prec = []
rec = []
vars_cnt = []
vars_lists = []
ys_test = []
ys_pred = []
for i in range(0, 30):
    # read one of the 30 data splits
    tr_path = "../data/30_splits_data/binary-transformed_tr_" + str(i) + ".csv"
    ts_path = "../data/30_splits_data/binary-transformed_ts_" + str(i) + ".csv"
    train = pd.read_csv(tr_path)
    test = pd.read_csv(ts_path)

    # select the list of original predictors
    symptoms = train.columns.tolist()
    symptoms.remove('TestPositive')

    # for each cutoff point, run the lasso model
    for cc in range(2, len(symptoms)):
        # create the hierarchical clusters based on the selected cutoff point
        XT, Xt, yT, yt = cut_hierarchy_columns(train, test, cc)
        # build the model for each value of the inverse of regularization strength
        for c in c_list:
            auc, recall, precision, valid_cols, y_test, y_pred = select_attributes(XT, yT, Xt, yt, c)
            # add results to the lists
            split_ids.append(i)
            cluster_n.append(cc)
            cs.append(c)
            source.append('cut_hierarchy')
            AUCs.append(auc)
            prec.append(precision)
            rec.append(recall)
            vars_cnt.append(len(valid_cols))
            vars_lists.append(valid_cols)
            if y_test is not np.nan:
                ys_test.append(yt.values.tolist())
                ys_pred.append(y_pred.tolist())
            else:
                ys_test.append(np.nan)
                ys_pred.append(np.nan)
            print(f'ID {i}, number of clusters={cc:02}, C={c:.2f}, AUC={auc:.5f}, '
                  f'Precision={precision:.5f}, Recall={recall:.5f}, cls# {len(valid_cols)}')
stop = timeit.default_timer()
print('Time: ', stop - start)
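`select_attributes` (also from the unshown `models` module) is sketched below under explicit assumptions. Judging from the validation code later in this notebook (StratifiedKFold, L1-penalized LogisticRegression, and the 95% sign-consistency filter in `get_profile`), it plausibly fits a lasso logistic regression per fold, keeps the variables whose coefficient sign is stable across folds, refits on those, and returns the test AUC, recall, precision, selected columns, and test predictions; the fold count and the 0.5 classification threshold here are guesses.

# Hypothetical sketch of models.select_attributes -- an assumption modeled on the
# cross-validation code further down, NOT the actual implementation.
from sklearn.metrics import precision_score, recall_score

def select_attributes_sketch(X_tr, y_tr, X_te, y_te, c, n_splits=10, stability=0.95):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=37)
    rows = []
    for tr_idx, _ in skf.split(X_tr, y_tr):
        clf = LogisticRegression(penalty='l1', solver='saga', C=c, max_iter=10000)
        clf.fit(X_tr.iloc[tr_idx], y_tr.iloc[tr_idx])
        rows.append(clf.coef_[0])
    coefs = pd.DataFrame(rows, columns=X_tr.columns)
    # keep variables whose coefficient keeps one sign in >= 95% of the folds
    keep = (coefs > 0).mean() >= stability
    keep |= (coefs < 0).mean() >= stability
    valid_cols = list(coefs.columns[keep])
    if not valid_cols:
        return np.nan, np.nan, np.nan, valid_cols, np.nan, np.nan
    # refit on the stable variables and score the held-out split
    clf = LogisticRegression(penalty='l1', solver='saga', C=c, max_iter=10000)
    clf.fit(X_tr[valid_cols], y_tr)
    y_pred = clf.predict_proba(X_te[valid_cols])[:, 1]
    y_hat = (y_pred >= 0.5).astype(int)
    return (roc_auc_score(y_te, y_pred), recall_score(y_te, y_hat, zero_division=0),
            precision_score(y_te, y_hat, zero_division=0), valid_cols, y_te, y_pred)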
# create dataframe for all results
ff = pd.DataFrame({'split_ids' : split_ids,
'cluster_n' : cluster_n,
'cs' : cs,
'source' : source,
'AUCs' : AUCs,
'prec' : prec,
'rec' : rec,
'vars_cnt' : vars_cnt,
'vars_lists' : vars_lists,
'y_test' : ys_test,
'y_pred' : ys_pred
})
# identify the list of unique selected predictors
unq_var = ff['vars_lists'].values.tolist()
for i in range(0, len(unq_var)):
    unq_var[i] = [sub.replace(':', ',') for sub in unq_var[i]]

sympt_lists = []
for i in range(0, len(unq_var)):
    l = ",".join(unq_var[i])
    l2 = list(set(l.split(',')))
    sympt_lists.append(l2)

# add the list of unique predictors and its count to the dataframe
ff['sympt_lists'] = sympt_lists
ff['sympt_cnt'] = ff['sympt_lists'].apply(lambda x: len(x))
# save the results
ff.to_csv('../data/results/hierarchical_model.csv', index=False)
table = pd.pivot_table(ff, values=['AUCs', 'vars_cnt', 'sympt_cnt']
, index=['cluster_n', 'cs']
, aggfunc=np.mean).round(decimals=4)
table['sympt_cnt'] = table['sympt_cnt'].round().astype(int)
table['vars_cnt'] = table['vars_cnt'].round().astype(int)
pd.set_option('display.max_rows', None)
table
Mean values over the 30 splits (a blank cluster_n cell repeats the value above):

cluster_n | cs | AUCs | sympt_cnt | vars_cnt
---|---|---|---|---
2 | 0.1 | 0.5385 | 1 | 0 |
0.2 | 0.5468 | 1 | 0 | |
0.3 | 0.5482 | 1 | 1 | |
0.4 | 0.5331 | 1 | 1 | |
0.5 | 0.5289 | 1 | 1 | |
0.6 | 0.5291 | 1 | 1 | |
0.7 | 0.5291 | 1 | 1 | |
0.8 | 0.5305 | 1 | 1 | |
0.9 | 0.5316 | 1 | 1 | |
1.0 | 0.5305 | 1 | 1 | |
1.5 | 0.5308 | 2 | 2 | |
2.0 | 0.5282 | 2 | 2 | |
3 | 0.1 | 0.5730 | 1 | 1 |
0.2 | 0.6236 | 2 | 2 | |
0.3 | 0.6204 | 2 | 2 | |
0.4 | 0.6185 | 2 | 2 | |
0.5 | 0.6233 | 2 | 2 | |
0.6 | 0.6251 | 3 | 3 | |
0.7 | 0.6245 | 3 | 3 | |
0.8 | 0.6256 | 3 | 3 | |
0.9 | 0.6275 | 3 | 3 | |
1.0 | 0.6294 | 3 | 3 | |
1.5 | 0.6336 | 3 | 3 | |
2.0 | 0.6365 | 3 | 3 | |
4 | 0.1 | 0.6517 | 1 | 1 |
0.2 | 0.7060 | 2 | 2 | |
0.3 | 0.7167 | 3 | 3 | |
0.4 | 0.7169 | 4 | 4 | |
0.5 | 0.7188 | 4 | 4 | |
0.6 | 0.7220 | 4 | 4 | |
0.7 | 0.7224 | 4 | 4 | |
0.8 | 0.7244 | 4 | 4 | |
0.9 | 0.7249 | 4 | 4 | |
1.0 | 0.7246 | 4 | 4 | |
1.5 | 0.7240 | 5 | 5 | |
2.0 | 0.7247 | 5 | 5 | |
5 | 0.1 | 0.7298 | 2 | 2 |
0.2 | 0.7585 | 3 | 3 | |
0.3 | 0.7612 | 3 | 3 | |
0.4 | 0.7620 | 4 | 4 | |
0.5 | 0.7655 | 4 | 4 | |
0.6 | 0.7674 | 4 | 4 | |
0.7 | 0.7659 | 4 | 4 | |
0.8 | 0.7663 | 5 | 5 | |
0.9 | 0.7648 | 5 | 5 | |
1.0 | 0.7643 | 5 | 5 | |
1.5 | 0.7611 | 5 | 5 | |
2.0 | 0.7611 | 5 | 5 | |
6 | 0.1 | 0.7287 | 2 | 2 |
0.2 | 0.7557 | 4 | 4 | |
0.3 | 0.7601 | 4 | 4 | |
0.4 | 0.7636 | 5 | 5 | |
0.5 | 0.7655 | 5 | 5 | |
0.6 | 0.7665 | 5 | 5 | |
0.7 | 0.7649 | 6 | 6 | |
0.8 | 0.7648 | 6 | 6 | |
0.9 | 0.7651 | 6 | 6 | |
1.0 | 0.7649 | 6 | 6 | |
1.5 | 0.7622 | 6 | 6 | |
2.0 | 0.7616 | 6 | 6 | |
7 | 0.1 | 0.7284 | 2 | 2 |
0.2 | 0.7491 | 4 | 4 | |
0.3 | 0.7569 | 5 | 5 | |
0.4 | 0.7603 | 6 | 6 | |
0.5 | 0.7616 | 6 | 6 | |
0.6 | 0.7639 | 6 | 6 | |
0.7 | 0.7632 | 6 | 6 | |
0.8 | 0.7644 | 6 | 6 | |
0.9 | 0.7657 | 7 | 7 | |
1.0 | 0.7658 | 7 | 7 | |
1.5 | 0.7665 | 7 | 7 | |
2.0 | 0.7659 | 7 | 7 | |
8 | 0.1 | 0.7321 | 2 | 2 |
0.2 | 0.7499 | 5 | 5 | |
0.3 | 0.7587 | 6 | 6 | |
0.4 | 0.7607 | 6 | 6 | |
0.5 | 0.7627 | 7 | 7 | |
0.6 | 0.7634 | 7 | 7 | |
0.7 | 0.7652 | 7 | 7 | |
0.8 | 0.7670 | 7 | 7 | |
0.9 | 0.7684 | 7 | 7 | |
1.0 | 0.7681 | 7 | 7 | |
1.5 | 0.7700 | 8 | 8 | |
2.0 | 0.7705 | 8 | 8 | |
9 | 0.1 | 0.7355 | 2 | 2 |
0.2 | 0.7513 | 5 | 5 | |
0.3 | 0.7636 | 7 | 7 | |
0.4 | 0.7664 | 7 | 7 | |
0.5 | 0.7680 | 8 | 8 | |
0.6 | 0.7687 | 8 | 8 | |
0.7 | 0.7701 | 8 | 8 | |
0.8 | 0.7706 | 8 | 8 | |
0.9 | 0.7710 | 8 | 8 | |
1.0 | 0.7703 | 8 | 8 | |
1.5 | 0.7724 | 9 | 9 | |
2.0 | 0.7725 | 9 | 9 | |
10 | 0.1 | 0.7479 | 3 | 3 |
0.2 | 0.7640 | 6 | 6 | |
0.3 | 0.7744 | 7 | 7 | |
0.4 | 0.7749 | 8 | 8 | |
0.5 | 0.7758 | 8 | 8 | |
0.6 | 0.7765 | 8 | 8 | |
0.7 | 0.7760 | 9 | 9 | |
0.8 | 0.7772 | 9 | 9 | |
0.9 | 0.7773 | 9 | 9 | |
1.0 | 0.7784 | 9 | 9 | |
1.5 | 0.7789 | 10 | 10 | |
2.0 | 0.7797 | 10 | 10 | |
11 | 0.1 | 0.7732 | 4 | 4 |
0.2 | 0.7824 | 7 | 7 | |
0.3 | 0.7860 | 8 | 8 | |
0.4 | 0.7846 | 9 | 9 | |
0.5 | 0.7863 | 9 | 9 | |
0.6 | 0.7864 | 10 | 10 | |
0.7 | 0.7861 | 10 | 10 | |
0.8 | 0.7859 | 10 | 10 | |
0.9 | 0.7865 | 11 | 11 | |
1.0 | 0.7872 | 11 | 11 | |
1.5 | 0.7873 | 11 | 11 | |
2.0 | 0.7879 | 12 | 12 | |
12 | 0.1 | 0.7731 | 4 | 4 |
0.2 | 0.7830 | 7 | 7 | |
0.3 | 0.7851 | 8 | 8 | |
0.4 | 0.7847 | 9 | 9 | |
0.5 | 0.7855 | 10 | 10 | |
0.6 | 0.7872 | 10 | 10 | |
0.7 | 0.7869 | 11 | 11 | |
0.8 | 0.7869 | 11 | 11 | |
0.9 | 0.7864 | 11 | 11 | |
1.0 | 0.7880 | 12 | 12 | |
1.5 | 0.7869 | 12 | 12 | |
2.0 | 0.7875 | 12 | 12 | |
13 | 0.1 | 0.7734 | 4 | 4 |
0.2 | 0.7832 | 7 | 7 | |
0.3 | 0.7825 | 8 | 8 | |
0.4 | 0.7832 | 9 | 9 | |
0.5 | 0.7844 | 10 | 10 | |
0.6 | 0.7846 | 11 | 11 | |
0.7 | 0.7853 | 11 | 11 | |
0.8 | 0.7855 | 11 | 11 | |
0.9 | 0.7848 | 12 | 12 | |
1.0 | 0.7858 | 12 | 12 | |
1.5 | 0.7861 | 13 | 13 | |
2.0 | 0.7858 | 13 | 13 | |
14 | 0.1 | 0.7747 | 4 | 4 |
0.2 | 0.7815 | 7 | 7 | |
0.3 | 0.7821 | 8 | 8 | |
0.4 | 0.7842 | 9 | 9 | |
0.5 | 0.7871 | 10 | 10 | |
0.6 | 0.7883 | 11 | 11 | |
0.7 | 0.7889 | 11 | 11 | |
0.8 | 0.7891 | 11 | 11 | |
0.9 | 0.7887 | 12 | 12 | |
1.0 | 0.7887 | 12 | 12 | |
1.5 | 0.7876 | 13 | 13 | |
2.0 | 0.7855 | 13 | 13 | |
15 | 0.1 | 0.7743 | 4 | 4 |
0.2 | 0.7817 | 7 | 7 | |
0.3 | 0.7839 | 8 | 8 | |
0.4 | 0.7858 | 10 | 10 | |
0.5 | 0.7885 | 10 | 10 | |
0.6 | 0.7898 | 11 | 11 | |
0.7 | 0.7891 | 11 | 11 | |
0.8 | 0.7885 | 12 | 12 | |
0.9 | 0.7887 | 12 | 12 | |
1.0 | 0.7884 | 12 | 12 | |
1.5 | 0.7865 | 13 | 13 | |
2.0 | 0.7843 | 14 | 14 | |
16 | 0.1 | 0.7747 | 4 | 4 |
0.2 | 0.7826 | 7 | 7 | |
0.3 | 0.7828 | 8 | 8 | |
0.4 | 0.7853 | 10 | 10 | |
0.5 | 0.7877 | 10 | 10 | |
0.6 | 0.7885 | 11 | 11 | |
0.7 | 0.7878 | 12 | 12 | |
0.8 | 0.7880 | 12 | 12 | |
0.9 | 0.7877 | 12 | 12 | |
1.0 | 0.7864 | 13 | 13 | |
1.5 | 0.7856 | 14 | 14 | |
2.0 | 0.7828 | 15 | 15 | |
17 | 0.1 | 0.7749 | 4 | 4 |
0.2 | 0.7826 | 7 | 7 | |
0.3 | 0.7823 | 8 | 8 | |
0.4 | 0.7854 | 10 | 10 | |
0.5 | 0.7872 | 11 | 11 | |
0.6 | 0.7871 | 11 | 11 | |
0.7 | 0.7860 | 12 | 12 | |
0.8 | 0.7859 | 12 | 12 | |
0.9 | 0.7853 | 13 | 13 | |
1.0 | 0.7847 | 13 | 13 | |
1.5 | 0.7834 | 15 | 15 | |
2.0 | 0.7811 | 16 | 16 | |
18 | 0.1 | 0.7754 | 5 | 5 |
0.2 | 0.7821 | 7 | 7 | |
0.3 | 0.7805 | 9 | 9 | |
0.4 | 0.7858 | 10 | 10 | |
0.5 | 0.7877 | 11 | 11 | |
0.6 | 0.7872 | 12 | 12 | |
0.7 | 0.7872 | 12 | 12 | |
0.8 | 0.7870 | 13 | 13 | |
0.9 | 0.7864 | 14 | 14 | |
1.0 | 0.7858 | 14 | 14 | |
1.5 | 0.7853 | 15 | 15 | |
2.0 | 0.7827 | 16 | 16 | |
19 | 0.1 | 0.7757 | 5 | 5 |
0.2 | 0.7824 | 8 | 8 | |
0.3 | 0.7818 | 9 | 9 | |
0.4 | 0.7852 | 10 | 10 | |
0.5 | 0.7868 | 11 | 11 | |
0.6 | 0.7855 | 12 | 12 | |
0.7 | 0.7854 | 13 | 13 | |
0.8 | 0.7846 | 14 | 14 | |
0.9 | 0.7842 | 14 | 14 | |
1.0 | 0.7838 | 15 | 15 | |
1.5 | 0.7819 | 16 | 16 | |
2.0 | 0.7782 | 18 | 18 | |
20 | 0.1 | 0.7771 | 5 | 5 |
0.2 | 0.7840 | 8 | 8 | |
0.3 | 0.7843 | 9 | 9 | |
0.4 | 0.7882 | 10 | 10 | |
0.5 | 0.7896 | 11 | 11 | |
0.6 | 0.7894 | 13 | 13 | |
0.7 | 0.7886 | 13 | 13 | |
0.8 | 0.7872 | 14 | 14 | |
0.9 | 0.7866 | 15 | 15 | |
1.0 | 0.7851 | 15 | 15 | |
1.5 | 0.7813 | 17 | 17 | |
2.0 | 0.7784 | 19 | 19 | |
21 | 0.1 | 0.7754 | 5 | 5 |
0.2 | 0.7826 | 8 | 8 | |
0.3 | 0.7828 | 9 | 9 | |
0.4 | 0.7859 | 11 | 11 | |
0.5 | 0.7866 | 12 | 12 | |
0.6 | 0.7849 | 13 | 13 | |
0.7 | 0.7853 | 14 | 14 | |
0.8 | 0.7827 | 15 | 15 | |
0.9 | 0.7810 | 15 | 15 | |
1.0 | 0.7793 | 16 | 16 | |
1.5 | 0.7757 | 18 | 18 | |
2.0 | 0.7727 | 19 | 19 | |
22 | 0.1 | 0.7737 | 5 | 5 |
0.2 | 0.7812 | 8 | 8 | |
0.3 | 0.7813 | 10 | 10 | |
0.4 | 0.7808 | 11 | 11 | |
0.5 | 0.7798 | 12 | 12 | |
0.6 | 0.7788 | 13 | 13 | |
0.7 | 0.7772 | 14 | 14 | |
0.8 | 0.7740 | 15 | 15 | |
0.9 | 0.7730 | 16 | 16 | |
1.0 | 0.7715 | 16 | 16 | |
1.5 | 0.7687 | 18 | 18 | |
2.0 | 0.7641 | 20 | 20 | |
23 | 0.1 | 0.7740 | 5 | 5 |
0.2 | 0.7805 | 8 | 8 | |
0.3 | 0.7798 | 10 | 10 | |
0.4 | 0.7800 | 11 | 11 | |
0.5 | 0.7799 | 12 | 12 | |
0.6 | 0.7785 | 13 | 13 | |
0.7 | 0.7768 | 14 | 14 | |
0.8 | 0.7737 | 15 | 15 | |
0.9 | 0.7723 | 16 | 16 | |
1.0 | 0.7706 | 16 | 16 | |
1.5 | 0.7666 | 18 | 18 | |
2.0 | 0.7635 | 20 | 20 | |
24 | 0.1 | 0.7744 | 6 | 6 |
0.2 | 0.7809 | 8 | 8 | |
0.3 | 0.7808 | 10 | 10 | |
0.4 | 0.7806 | 11 | 11 | |
0.5 | 0.7807 | 12 | 12 | |
0.6 | 0.7792 | 13 | 13 | |
0.7 | 0.7769 | 14 | 14 | |
0.8 | 0.7746 | 15 | 15 | |
0.9 | 0.7726 | 16 | 16 | |
1.0 | 0.7712 | 16 | 16 | |
1.5 | 0.7672 | 18 | 18 | |
2.0 | 0.7627 | 20 | 20 | |
25 | 0.1 | 0.7784 | 5 | 5 |
0.2 | 0.7810 | 8 | 8 | |
0.3 | 0.7825 | 10 | 10 | |
0.4 | 0.7821 | 11 | 11 | |
0.5 | 0.7832 | 12 | 12 | |
0.6 | 0.7802 | 13 | 13 | |
0.7 | 0.7770 | 14 | 14 | |
0.8 | 0.7756 | 15 | 15 | |
0.9 | 0.7741 | 16 | 16 | |
1.0 | 0.7718 | 16 | 16 | |
1.5 | 0.7670 | 19 | 19 | |
2.0 | 0.7626 | 20 | 20 | |
26 | 0.1 | 0.7784 | 5 | 5 |
0.2 | 0.7812 | 8 | 8 | |
0.3 | 0.7837 | 10 | 10 | |
0.4 | 0.7841 | 11 | 11 | |
0.5 | 0.7850 | 12 | 12 | |
0.6 | 0.7825 | 13 | 13 | |
0.7 | 0.7795 | 14 | 14 | |
0.8 | 0.7772 | 15 | 15 | |
0.9 | 0.7752 | 16 | 16 | |
1.0 | 0.7745 | 16 | 16 | |
1.5 | 0.7700 | 18 | 18 | |
2.0 | 0.7668 | 20 | 20 | |
27 | 0.1 | 0.7793 | 5 | 5 |
0.2 | 0.7839 | 8 | 8 | |
0.3 | 0.7835 | 9 | 9 | |
0.4 | 0.7848 | 11 | 11 | |
0.5 | 0.7854 | 12 | 12 | |
0.6 | 0.7834 | 13 | 13 | |
0.7 | 0.7815 | 14 | 14 | |
0.8 | 0.7791 | 14 | 14 | |
0.9 | 0.7759 | 15 | 15 | |
1.0 | 0.7754 | 16 | 16 | |
1.5 | 0.7694 | 19 | 19 | |
2.0 | 0.7666 | 20 | 20 | |
28 | 0.1 | 0.7822 | 4 | 4 |
0.2 | 0.7856 | 8 | 8 | |
0.3 | 0.7866 | 9 | 9 | |
0.4 | 0.7861 | 11 | 11 | |
0.5 | 0.7861 | 12 | 12 | |
0.6 | 0.7838 | 13 | 13 | |
0.7 | 0.7810 | 14 | 14 | |
0.8 | 0.7788 | 15 | 15 | |
0.9 | 0.7772 | 16 | 16 | |
1.0 | 0.7755 | 16 | 16 | |
1.5 | 0.7698 | 19 | 19 | |
2.0 | 0.7684 | 21 | 21 | |
29 | 0.1 | 0.7814 | 4 | 4 |
0.2 | 0.7845 | 7 | 7 | |
0.3 | 0.7863 | 9 | 9 | |
0.4 | 0.7860 | 10 | 10 | |
0.5 | 0.7850 | 12 | 12 | |
0.6 | 0.7825 | 13 | 13 | |
0.7 | 0.7803 | 14 | 14 | |
0.8 | 0.7782 | 15 | 15 | |
0.9 | 0.7770 | 16 | 16 | |
1.0 | 0.7761 | 16 | 16 | |
1.5 | 0.7695 | 19 | 19 | |
2.0 | 0.7648 | 21 | 21 | |
30 | 0.1 | 0.7812 | 4 | 4 |
0.2 | 0.7844 | 7 | 7 | |
0.3 | 0.7850 | 9 | 9 | |
0.4 | 0.7849 | 10 | 10 | |
0.5 | 0.7836 | 12 | 12 | |
0.6 | 0.7817 | 13 | 13 | |
0.7 | 0.7782 | 14 | 14 | |
0.8 | 0.7769 | 15 | 15 | |
0.9 | 0.7750 | 16 | 16 | |
1.0 | 0.7745 | 16 | 16 | |
1.5 | 0.7680 | 20 | 20 | |
2.0 | 0.7654 | 21 | 21 | |
31 | 0.1 | 0.7815 | 4 | 4 |
0.2 | 0.7853 | 8 | 8 | |
0.3 | 0.7846 | 10 | 10 | |
0.4 | 0.7849 | 11 | 11 | |
0.5 | 0.7838 | 12 | 12 | |
0.6 | 0.7816 | 13 | 13 | |
0.7 | 0.7778 | 15 | 15 | |
0.8 | 0.7752 | 16 | 16 | |
0.9 | 0.7733 | 16 | 16 | |
1.0 | 0.7720 | 17 | 17 | |
1.5 | 0.7682 | 20 | 20 | |
2.0 | 0.7622 | 22 | 22 | |
32 | 0.1 | 0.7819 | 5 | 5 |
0.2 | 0.7848 | 7 | 7 | |
0.3 | 0.7842 | 10 | 10 | |
0.4 | 0.7848 | 11 | 11 | |
0.5 | 0.7835 | 12 | 12 | |
0.6 | 0.7797 | 14 | 14 | |
0.7 | 0.7764 | 15 | 15 | |
0.8 | 0.7746 | 16 | 16 | |
0.9 | 0.7735 | 17 | 17 | |
1.0 | 0.7719 | 17 | 17 | |
1.5 | 0.7677 | 20 | 20 | |
2.0 | 0.7636 | 22 | 22 | |
33 | 0.1 | 0.7801 | 5 | 5 |
0.2 | 0.7832 | 8 | 8 | |
0.3 | 0.7838 | 10 | 10 | |
0.4 | 0.7839 | 12 | 12 | |
0.5 | 0.7810 | 13 | 13 | |
0.6 | 0.7778 | 14 | 14 | |
0.7 | 0.7743 | 15 | 15 | |
0.8 | 0.7721 | 16 | 16 | |
0.9 | 0.7702 | 17 | 17 | |
1.0 | 0.7692 | 18 | 18 | |
1.5 | 0.7645 | 21 | 21 | |
2.0 | 0.7598 | 22 | 22 | |
34 | 0.1 | 0.7785 | 5 | 5 |
0.2 | 0.7819 | 8 | 8 | |
0.3 | 0.7818 | 10 | 10 | |
0.4 | 0.7817 | 12 | 12 | |
0.5 | 0.7786 | 13 | 13 | |
0.6 | 0.7761 | 15 | 15 | |
0.7 | 0.7734 | 16 | 16 | |
0.8 | 0.7714 | 16 | 16 | |
0.9 | 0.7688 | 17 | 17 | |
1.0 | 0.7679 | 18 | 18 | |
1.5 | 0.7616 | 21 | 21 | |
2.0 | 0.7577 | 23 | 23 | |
35 | 0.1 | 0.7774 | 6 | 6 |
0.2 | 0.7822 | 8 | 8 | |
0.3 | 0.7813 | 10 | 10 | |
0.4 | 0.7809 | 12 | 12 | |
0.5 | 0.7787 | 13 | 13 | |
0.6 | 0.7755 | 15 | 15 | |
0.7 | 0.7733 | 16 | 16 | |
0.8 | 0.7726 | 17 | 17 | |
0.9 | 0.7703 | 18 | 18 | |
1.0 | 0.7695 | 19 | 19 | |
1.5 | 0.7620 | 21 | 21 | |
2.0 | 0.7587 | 24 | 24 | |
36 | 0.1 | 0.7769 | 6 | 6 |
0.2 | 0.7822 | 8 | 8 | |
0.3 | 0.7817 | 10 | 10 | |
0.4 | 0.7806 | 12 | 12 | |
0.5 | 0.7787 | 13 | 13 | |
0.6 | 0.7749 | 15 | 15 | |
0.7 | 0.7726 | 16 | 16 | |
0.8 | 0.7727 | 17 | 17 | |
0.9 | 0.7700 | 18 | 18 | |
1.0 | 0.7691 | 19 | 19 | |
1.5 | 0.7619 | 22 | 22 | |
2.0 | 0.7589 | 23 | 23 | |
37 | 0.1 | 0.7763 | 6 | 6 |
0.2 | 0.7823 | 9 | 9 | |
0.3 | 0.7821 | 10 | 10 | |
0.4 | 0.7809 | 12 | 12 | |
0.5 | 0.7787 | 14 | 14 | |
0.6 | 0.7757 | 16 | 16 | |
0.7 | 0.7750 | 17 | 17 | |
0.8 | 0.7745 | 18 | 18 | |
0.9 | 0.7724 | 19 | 19 | |
1.0 | 0.7709 | 20 | 20 | |
1.5 | 0.7642 | 22 | 22 | |
2.0 | 0.7603 | 24 | 24 | |
38 | 0.1 | 0.7779 | 6 | 6 |
0.2 | 0.7841 | 9 | 9 | |
0.3 | 0.7845 | 11 | 11 | |
0.4 | 0.7821 | 13 | 13 | |
0.5 | 0.7805 | 15 | 15 | |
0.6 | 0.7775 | 16 | 16 | |
0.7 | 0.7784 | 18 | 18 | |
0.8 | 0.7766 | 19 | 19 | |
0.9 | 0.7755 | 20 | 20 | |
1.0 | 0.7751 | 21 | 21 | |
1.5 | 0.7694 | 23 | 23 | |
2.0 | 0.7651 | 25 | 25 | |
39 | 0.1 | 0.7779 | 6 | 6 |
0.2 | 0.7843 | 9 | 9 | |
0.3 | 0.7841 | 11 | 11 | |
0.4 | 0.7820 | 13 | 13 | |
0.5 | 0.7801 | 15 | 15 | |
0.6 | 0.7772 | 16 | 16 | |
0.7 | 0.7777 | 18 | 18 | |
0.8 | 0.7760 | 19 | 19 | |
0.9 | 0.7753 | 20 | 20 | |
1.0 | 0.7751 | 21 | 21 | |
1.5 | 0.7682 | 23 | 23 | |
2.0 | 0.7654 | 25 | 25 | |
40 | 0.1 | 0.7779 | 6 | 6 |
0.2 | 0.7843 | 9 | 9 | |
0.3 | 0.7843 | 11 | 11 | |
0.4 | 0.7821 | 13 | 13 | |
0.5 | 0.7804 | 15 | 15 | |
0.6 | 0.7767 | 16 | 16 | |
0.7 | 0.7772 | 18 | 18 | |
0.8 | 0.7754 | 19 | 19 | |
0.9 | 0.7745 | 20 | 20 | |
1.0 | 0.7749 | 21 | 21 | |
1.5 | 0.7680 | 23 | 23 | |
2.0 | 0.7643 | 25 | 25 | |
41 | 0.1 | 0.7779 | 6 | 6 |
0.2 | 0.7843 | 9 | 9 | |
0.3 | 0.7841 | 11 | 11 | |
0.4 | 0.7819 | 13 | 13 | |
0.5 | 0.7791 | 14 | 14 | |
0.6 | 0.7752 | 16 | 16 | |
0.7 | 0.7751 | 18 | 18 | |
0.8 | 0.7730 | 19 | 19 | |
0.9 | 0.7729 | 20 | 20 | |
1.0 | 0.7719 | 21 | 21 | |
1.5 | 0.7652 | 23 | 23 | |
2.0 | 0.7616 | 25 | 25 | |
42 | 0.1 | 0.7779 | 6 | 6 |
0.2 | 0.7841 | 9 | 9 | |
0.3 | 0.7845 | 11 | 11 | |
0.4 | 0.7819 | 13 | 13 | |
0.5 | 0.7789 | 14 | 14 | |
0.6 | 0.7749 | 16 | 16 | |
0.7 | 0.7740 | 18 | 18 | |
0.8 | 0.7726 | 19 | 19 | |
0.9 | 0.7722 | 20 | 20 | |
1.0 | 0.7702 | 21 | 21 | |
1.5 | 0.7615 | 23 | 23 | |
2.0 | 0.7570 | 25 | 25 | |
43 | 0.1 | 0.7779 | 6 | 6 |
0.2 | 0.7843 | 9 | 9 | |
0.3 | 0.7843 | 11 | 11 | |
0.4 | 0.7818 | 13 | 13 | |
0.5 | 0.7789 | 14 | 14 | |
0.6 | 0.7751 | 16 | 16 | |
0.7 | 0.7739 | 18 | 18 | |
0.8 | 0.7728 | 19 | 19 | |
0.9 | 0.7723 | 20 | 20 | |
1.0 | 0.7702 | 21 | 21 | |
1.5 | 0.7612 | 23 | 23 | |
2.0 | 0.7553 | 25 | 25 | |
44 | 0.1 | 0.7779 | 6 | 6 |
0.2 | 0.7843 | 9 | 9 | |
0.3 | 0.7841 | 11 | 11 | |
0.4 | 0.7819 | 13 | 13 | |
0.5 | 0.7789 | 14 | 14 | |
0.6 | 0.7751 | 16 | 16 | |
0.7 | 0.7740 | 18 | 18 | |
0.8 | 0.7726 | 19 | 19 | |
0.9 | 0.7722 | 20 | 20 | |
1.0 | 0.7701 | 21 | 21 | |
1.5 | 0.7607 | 24 | 24 | |
2.0 | 0.7552 | 26 | 26 | |
45 | 0.1 | 0.7779 | 6 | 6 |
0.2 | 0.7843 | 9 | 9 | |
0.3 | 0.7843 | 11 | 11 | |
0.4 | 0.7819 | 13 | 13 | |
0.5 | 0.7789 | 14 | 14 | |
0.6 | 0.7749 | 16 | 16 | |
0.7 | 0.7739 | 18 | 18 | |
0.8 | 0.7726 | 19 | 19 | |
0.9 | 0.7722 | 20 | 20 | |
1.0 | 0.7700 | 21 | 21 | |
1.5 | 0.7607 | 24 | 24 | |
2.0 | 0.7552 | 26 | 26 | |
46 | 0.1 | 0.7779 | 6 | 6 |
0.2 | 0.7843 | 9 | 9 | |
0.3 | 0.7843 | 11 | 11 | |
0.4 | 0.7819 | 13 | 13 | |
0.5 | 0.7789 | 14 | 14 | |
0.6 | 0.7751 | 16 | 16 | |
0.7 | 0.7739 | 18 | 18 | |
0.8 | 0.7727 | 19 | 19 | |
0.9 | 0.7721 | 20 | 20 | |
1.0 | 0.7700 | 21 | 21 | |
1.5 | 0.7607 | 24 | 24 | |
2.0 | 0.7554 | 26 | 26 |
column = table["AUCs"]
column.idxmax()
(15, 0.6)
table.loc[column.idxmax()]
AUCs          0.7898
sympt_cnt    11.0000
vars_cnt     11.0000
Name: (15, 0.6), dtype: float64
n_cluster = column.idxmax()[0]
_C = column.idxmax()[1]
n_cluster, _C
(15, 0.6)
# read data
path = "../data/preprocessed.csv"
df = pd.read_csv(path)
# replace underscores so symptom names read naturally as dendrogram labels
df.columns = [s.replace('_', ' ') for s in df.columns]
metric = "cityblock"
method = "complete"

# dendrogram of symptom co-occurrence among patients who tested positive
data = df[df['TestPositive'] == 1]
data = data.drop(columns=['TestPositive'])
plt.figure(figsize=(40, 20))
shc.set_link_color_palette(['black'])
dend3 = shc.dendrogram(shc.linkage(data.T, method=method, metric=metric), color_threshold=6500,
                       leaf_font_size=30, labels=data.T.index, count_sort=True, leaf_rotation=90)
plt.axhline(y=33, c='grey', lw=5, linestyle='dashed')
plt.yticks(fontsize=30)
plt.show()
# dendrogram of symptom co-occurrence among patients who tested negative
data = df[df['TestPositive'] == 0]
data = data.drop(columns=['TestPositive'])
plt.figure(figsize=(40, 20))
shc.set_link_color_palette(['black'])
dend3 = shc.dendrogram(shc.linkage(data.T, method=method, metric=metric), color_threshold=15400,
                       leaf_font_size=30, labels=data.T.index, count_sort=True, leaf_rotation=90)
plt.axhline(y=89, c='grey', lw=5, linestyle='dashed')
shc.set_link_color_palette(None)  # reset to default after use
plt.yticks(fontsize=30)
plt.show()
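The dashed horizontal lines mark the heights at which the dendrograms are read off. `fcluster` (imported above but not used directly in this listing) is the programmatic counterpart of cutting at such a height; for example, for the negative-test data currently in `data` and the dashed line at y=89:

# Illustration only: cut the negative-test dendrogram at the dashed height (y=89)
# to list which symptoms fall into each flat cluster.
links = shc.linkage(data.T, method=method, metric=metric)
flat = fcluster(links, t=89, criterion='distance')
for k in sorted(set(flat)):
    print(f'cluster {k}:', list(data.columns[flat == k]))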
# rebuild the features at the chosen cutoff and cache them for the CV workers
X, Xt, y, yt = cut_hierarchy_columns(df, df, n_cluster)
Xy_pickle = 'BinaryDataX_hierarchical.p'
pickle.dump({'X': X, 'y': y}, open(Xy_pickle, 'wb'))
def do_validation(fold, tr, te, c=_C):
    data = pickle.load(open(Xy_pickle, 'rb'))
    X, y = data['X'], data['y']
    X_tr, X_te = X.iloc[tr], X.iloc[te]
    y_tr, y_te = y.iloc[tr].values.ravel(), y.iloc[te].values.ravel()
    print(f'fold {fold:02}')
    regressor = LogisticRegression(penalty='l1', solver='saga', C=c, n_jobs=-1, max_iter=5000*2)
    regressor.fit(X_tr, y_tr)
    y_pr = regressor.predict_proba(X_te)[:, 1]
    score = roc_auc_score(y_te, y_pr)
    print(f'fold {fold:02}, score={score:.5f}')
    return score, regressor.coef_[0]
skf = StratifiedKFold(n_splits=24, shuffle=True, random_state=37)
outputs = Parallel(n_jobs=-1)(delayed(do_validation)(fold, tr, te, _C)
for fold, (tr, te) in enumerate(skf.split(X, y)))
scores = pd.Series([score for score, _ in outputs])
coefs = pd.DataFrame([coef for _, coef in outputs], columns=X.columns)
def get_profile(df, col):
    s = df[col]
    s_pos = s[s > 0]
    s_neg = s[s < 0]
    n = df.shape[0]
    p_pos = len(s_pos) / n
    p_neg = len(s_neg) / n
    return {
        'field': col,
        'n_pos': len(s_pos),
        'n_neg': len(s_neg),
        'pct_pos': p_pos,
        'pct_neg': p_neg,
        'is_valid': 1 if p_pos >= 0.95 or p_neg >= 0.95 else 0
    }
valid_coefs = pd.DataFrame([get_profile(coefs, c) for c in coefs.columns]).sort_values(['is_valid'], ascending=False)
valid_coefs = valid_coefs[valid_coefs.is_valid == 1]
valid_cols = list(valid_coefs.field)
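The `is_valid` flag implements the stability filter: a variable survives only if its coefficient carries the same sign in at least 95% of the 24 folds. A toy check of the rule on made-up coefficients:

# Toy illustration of the 95% sign-consistency rule (made-up numbers).
demo = pd.DataFrame({
    'stable_pos': [0.4, 0.5, 0.3, 0.6],   # positive in 4/4 folds -> is_valid = 1
    'unstable':   [0.2, -0.1, 0.0, 0.3],  # mixed signs           -> is_valid = 0
})
print(pd.DataFrame([get_profile(demo, col) for col in demo.columns]))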
regressor = LogisticRegression(penalty='l1', solver='saga', C=_C, n_jobs=-1,
                               max_iter=5000*2, random_state=37)
regressor.fit(X[valid_cols], y.values.ravel())
LogisticRegression(C=0.6, max_iter=10000, n_jobs=-1, penalty='l1', random_state=37, solver='saga')
y_pred = regressor.predict_proba(X[valid_cols])[:,1]
t = X[valid_cols].copy()
t['y_pred'] = y_pred
t['y_actual'] = y
t.to_csv("../data/results/prediction_hierarchical_model.csv", index=False)
c = pd.Series(regressor.coef_[0], valid_cols)
plt.style.use('ggplot')
i = pd.Series([regressor.intercept_[0]], index=['intercept'])
s = pd.concat([c[c > 0], c[c < 0]]).sort_index()
s = pd.concat([i, s])
color = ['r' if v > 0 else 'b' for v in s]
ax = s.plot(kind='bar', color=color, figsize=(20, 4))
_ = ax.set_title(f'Logistic Regression, validated auc={scores.mean():.5f}')
s_odds = np.exp(s)
color = ['r' if v > 1 else 'b' for v in s_odds]
ax = s_odds.plot(kind='bar', color=color, figsize=(20, 4))
_ = ax.set_title('Logistic Regression, coefficient odds')
pd.DataFrame({
'coefficient': s,
'coefficient_odds': s_odds
}).round(decimals=2)
 | coefficient | coefficient_odds |
---|---|---|
intercept | -2.03 | 0.13 |
Age 30 and over | -0.55 | 0.58 |
Cough & Fever & Headaches & Runny nose | 1.57 | 4.81 |
Female | 0.33 | 1.39 |
Headaches | 0.61 | 1.85 |
History of respiratory symptoms | 0.24 | 1.27 |
Loss of taste | 0.77 | 2.15 |
Muscle aches | 0.22 | 1.25 |
Race White | 0.38 | 1.46 |
Runny nose | -0.55 | 0.57 |
Shortness of breath & Wheezing & Chills & Chest pain & Difficulty breathing | 2.22 | 9.18 |
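As a quick sanity check, the coefficient_odds column is simply the exponentiated coefficient:

# e.g. the Cough & Fever & Headaches & Runny nose cluster: exp(1.57) ~ 4.81
print(round(np.exp(1.57), 2))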
(pd.DataFrame({
'coefficient': s,
'coefficient_odds': s_odds
}).round(decimals=4)).shape
(11, 2)