Training a CNN model to predict relative activities of compromised sgRNAs

This python notebook accompanies the manuscript "Titrating gene expression using libraries of systematically attenuated CRISPR guide RNAs" by Jost and Santos et al., 2019.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy
import itertools
from scipy.stats import pearsonr
from collections import Counter
from keras.layers import Conv2D, Activation, MaxPool2D, Flatten, Dense, Dropout
from keras.models import Sequential
Using TensorFlow backend.
In [2]:
np.set_printoptions(linewidth=100)
In [3]:
%matplotlib inline
In [ ]:
 

import data

  • Supplementary Table S8 contains the input data for training this model.
  • Mean relative gamma values have been calculated from the large-scale screen, with the following filtering criteria:
    1. sgRNAs with two mismatches have been excluded. The invariant G at the 5' end of each sgRNA is never counted as a mismatch.
    2. sgRNA series in which the perfectly-matched sgRNA had a growth phenotype less than 10 z-scores outside the distribution of negative control phenotypes were excluded.
    3. Mean relative gammas were calculated from K562 and Jurkat cells, two biological replicates each, if both cell types passed the z-score filter (4 measurements). If only one cell type passed the filter, the mean relative gamma was calculated from biological replicates in that cell only (2 measurements).
  • Genome and sgRNA input sequences include the protospacer with 2 and 4 flanking genomic bases on the PAM-distal and PAM-proximal ends, respectively. Each sequence is 26 bases in length.
In [4]:
data = pd.read_table('tables/Table_S8_machine_learning_input.txt', index_col=0)
In [5]:
data.head()
Out[5]:
perfect match sgRNA gene sgRNA sequence mismatch position new pairing K562 Jurkat mean relative gamma genome input sgRNA input
AAR2_-_34824434.23-P1P2_01 AAR2_-_34824434.23-P1P2 AAR2 GTGAGGCGAGGCGGTGAGTG -17.0 rA:dC True False 0.665625 ACGTGGGGCGAGGCGGTGAGTGTGGC ACGTGAGGCGAGGCGGTGAGTGTGGC
AAR2_-_34824434.23-P1P2_02 AAR2_-_34824434.23-P1P2 AAR2 GTGGGACGAGGCGGTGAGTG -15.0 rA:dC True False 0.596488 ACGTGGGGCGAGGCGGTGAGTGTGGC ACGTGGGACGAGGCGGTGAGTGTGGC
AAR2_-_34824434.23-P1P2_03 AAR2_-_34824434.23-P1P2 AAR2 GTGGGGGGAGGCGGTGAGTG -14.0 rG:dG True False 0.629915 ACGTGGGGCGAGGCGGTGAGTGTGGC ACGTGGGGGGAGGCGGTGAGTGTGGC
AAR2_-_34824434.23-P1P2_04 AAR2_-_34824434.23-P1P2 AAR2 GTGGGGCGAAGCGGTGAGTG -11.0 rA:dC True False -0.012634 ACGTGGGGCGAGGCGGTGAGTGTGGC ACGTGGGGCGAAGCGGTGAGTGTGGC
AAR2_-_34824434.23-P1P2_05 AAR2_-_34824434.23-P1P2 AAR2 GTGGGGCGACGCGGTGAGTG -11.0 rC:dC True False 0.058972 ACGTGGGGCGAGGCGGTGAGTGTGGC ACGTGGGGCGACGCGGTGAGTGTGGC
In [6]:
# print some statistics from this dataset

print 'total sgRNAs:', len(data)
print 'total series:', len(data['perfect match sgRNA'].unique())
print 'total genes:', len(data['gene'].unique())
print 'phenotypes from K562 and Jurkat: %.2f%%'%(len(data[(data.K562==True) & (data.Jurkat==True)])/
                                                 float(len(data))*100)
total sgRNAs: 26248
total series: 2034
total genes: 1292
phenotypes from K562 and Jurkat: 39.61%
In [ ]:
 

generate X and y (feature and target arrays)

In [7]:
# make a list of tuples pairing genome and input sequences

sequence_tuples = zip(data['genome input'], data['sgRNA input'])
In [8]:
def binarize_sequence(sequence):
    """
    converts a 26-base nucleotide string to a binarized array of shape (4,26)
    """
    arr = np.zeros((4, 26))
    for i in range(26):
        if sequence[i] == 'A':
            arr[0,i] = 1
        elif sequence[i] == 'C':
            arr[1,i] = 1
        elif sequence[i] == 'G':
            arr[2,i] = 1
        elif sequence[i] == 'T':
            arr[3,i] = 1
        else:
            raise Exception('sequence contains characters other than A,G,C,T \n%s'%sequence)
 
    return arr
In [9]:
# example of what the above function is doing
# using the genome input sequence from the first row of the data table

test_sequence = sequence_tuples[0][0]
print test_sequence, '\n'
print binarize_sequence(test_sequence), '\n'
print binarize_sequence(test_sequence).shape, '\n'

fig,ax=plt.subplots()
ax.imshow(binarize_sequence(test_sequence))
ax.set_yticks(range(4))
ax.set_yticklabels(list('ACGT'))
ax.set_xticks([]);
ACGTGGGGCGAGGCGGTGAGTGTGGC 

[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 1. 0. 1. 1. 1. 1. 0. 1. 0. 1. 1. 0. 1. 1. 0. 1. 0. 1. 0. 1. 0. 1. 1. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 0. 0. 0.]] 

(4, 26) 

In [10]:
# for each tuple, binarize the sequences and stack them into a 3D array of shape (4,26,2)

stacked_arrs = [np.stack((binarize_sequence(genome_input), binarize_sequence(sgrna_input)), axis=2) 
                for (genome_input, sgrna_input) in sequence_tuples]
In [11]:
# the feature input X is a 4D array containing all of the 3D arrays generated above

X = np.concatenate([arr[np.newaxis] for arr in stacked_arrs])
In [12]:
# the target input y is a 1D array of relative activities

y = data['mean relative gamma'].values
In [13]:
# an array of series IDs will allow mapping of each element in X or y
# to be used in assigning each series to the training or validation set

series = data['perfect match sgRNA']
In [14]:
# check the shape of each array

print 'X:', X.shape
print 'y:', y.shape
print 'series:', series.shape
X: (26248, 4, 26, 2)
y: (26248,)
series: (26248,)
In [ ]:
 

split into training and validation sets

In [15]:
# randomly select 20% of sgRNA series to be set aside for validation

np.random.seed(99)
val_series = np.random.choice(np.unique(series), size=int(len(np.unique(series))*.20), replace=False)
val_indices = np.where(np.isin(series, val_series))
train_indices = np.where(~np.isin(series, val_series))
In [16]:
# define train and validation sets

X_train = X[train_indices]
X_val = X[val_indices]
y_train = y[train_indices]
y_val = y[val_indices]
In [17]:
# check the shape of each array

print 'X train:', X_train.shape
print 'y train:', y_train.shape, '\n'
print 'X validation:', X_val.shape
print 'y validation:', y_val.shape
X train: (21007, 4, 26, 2)
y train: (21007,) 

X validation: (5241, 4, 26, 2)
y validation: (5241,)
In [ ]:
 

calculate sample weights

  • Since the activity distribution is heavily imbalanced (the vast majority of sgRNAs are inactive), the contribution of each training instance to the model will be inversely proportional to the total number of instances with similar activity. Therefore the model will be penalized more for errors associated with intermediate activity sgRNAs, which are relatively rare.
  • The inactive guide class ("class 0") is by far the most populated, and therefore the model's performance will be highly sensitive to the weight assigned to that class. During hyperparameter optimization we also included a scaling factor to adjust this weight, and found that a slight increase by 1.5x minimized the errors associated with intermediate activity sgRNAs.
In [18]:
# assign training target values to 5 bins with evenly spaced edges
# relative activities below 0 or above 1 are included in the lowest or highest bin, respectively

nbins=5
y_train_clipped = y_train.clip(0,1)
y_train_binned, histbins = pd.cut(y_train_clipped, np.linspace(0,1,nbins+1), labels=range(nbins), include_lowest=True, retbins=True)
In [19]:
print 'bin edges:', histbins
bin edges: [0.  0.2 0.4 0.6 0.8 1. ]
In [20]:
# calculate a weight for each bin, inversely proportional to the population in that bin

class_weights = {k:1/float(v) for k,v in Counter(y_train_binned).items()}
In [21]:
# increase the class 0 weight by multiplying by 1.5
# this empirically improved model accuracy during parameter optimization on the training data

class_weights[0] = class_weights[0] * 1.5
In [22]:
# scale weights to sum to 1

class_weights = {k:v/sum(class_weights.values()) for k,v in class_weights.items()}
class_weights
Out[22]:
{0: 0.050152613839668826,
 1: 0.2815653766397307,
 2: 0.2879094606727001,
 3: 0.24366538202754034,
 4: 0.13670716682036013}
In [23]:
# generate a list mapping each element in y_train to its class weight

sample_weights = [class_weights[Y] for Y in y_train_binned]
In [ ]:
 

build and train CNN model

  • Hyperparameters were selected based on a cross-validated randomized grid search, using only the training data.
  • For illustrative purposes only one model is being trained in the example below; in practice we trained 20 such models on the same data, and determined predicted activities based on the average prediction of the ensemble.
In [24]:
# build and compile the CNN model

model = Sequential()
model.add(Conv2D(filters=32, kernel_size=(4,4), activation='relu', padding='same', input_shape=(4,26,2), data_format='channels_last'))
model.add(Conv2D(filters=32, kernel_size=(4,4), activation='relu', padding='same', data_format='channels_last'))
model.add(MaxPool2D(pool_size=(1,2), padding='same', data_format='channels_last'))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(units=128, activation='sigmoid'))
model.add(Dropout(0.1))
model.add(Dense(units=32, activation='sigmoid'))
model.add(Dropout(0.1))
model.add(Dense(1, activation='linear'))
model.compile(loss='logcosh', metrics=['mse'], optimizer='adam')
In [25]:
# train the model for 8 epochs

model_history = model.fit(X_train, 
                          y_train.ravel(), 
                          sample_weight=np.array(sample_weights), 
                          batch_size=32, 
                          epochs=8, 
                          validation_data=(X_val, y_val.ravel()))
Train on 21007 samples, validate on 5241 samples
Epoch 1/8
21007/21007 [==============================] - 10s 471us/step - loss: 0.0068 - mean_squared_error: 0.1572 - val_loss: 0.0412 - val_mean_squared_error: 0.0859
Epoch 2/8
21007/21007 [==============================] - 9s 437us/step - loss: 0.0042 - mean_squared_error: 0.0827 - val_loss: 0.0365 - val_mean_squared_error: 0.0756
Epoch 3/8
21007/21007 [==============================] - 9s 441us/step - loss: 0.0038 - mean_squared_error: 0.0729 - val_loss: 0.0361 - val_mean_squared_error: 0.0750
Epoch 4/8
21007/21007 [==============================] - 9s 446us/step - loss: 0.0035 - mean_squared_error: 0.0664 - val_loss: 0.0353 - val_mean_squared_error: 0.0734
Epoch 5/8
21007/21007 [==============================] - 9s 442us/step - loss: 0.0033 - mean_squared_error: 0.0618 - val_loss: 0.0299 - val_mean_squared_error: 0.0622
Epoch 6/8
21007/21007 [==============================] - 9s 442us/step - loss: 0.0031 - mean_squared_error: 0.0591 - val_loss: 0.0284 - val_mean_squared_error: 0.0592
Epoch 7/8
21007/21007 [==============================] - 9s 443us/step - loss: 0.0030 - mean_squared_error: 0.0555 - val_loss: 0.0321 - val_mean_squared_error: 0.0665
Epoch 8/8
21007/21007 [==============================] - 9s 440us/step - loss: 0.0028 - mean_squared_error: 0.0522 - val_loss: 0.0310 - val_mean_squared_error: 0.0645
In [26]:
# plot measured vs. predicted relative activity

fig,ax = plt.subplots(figsize=(5,5))
ax.scatter(model.predict(X_val), y_val, marker='.', alpha=.2)
ax.set_xlabel('predicted activity')
ax.set_ylabel('measured activity')
ax.set_title('performance on validation set');
In [27]:
print 'r squared = %.3f'%pearsonr(y_val, model.predict(X_val).ravel())[0]**2
r squared = 0.618
In [ ]:
 

get predicted activity of an arbitrary mismatched sgRNA

In [28]:
# starting from the sgRNA and corresponding genomic sequences

                             #
sgrna  = 'TAGGTACTGAGCGCGCGAGCTGAGGA'
genome = 'TAGGTACTGAGCGCGCGAGGTGAGGA'
                             # -3 rC:dC
In [29]:
def get_predicted_activity(genome_seq, sgrna_seq, cnn_model):
    """
    takes 26-nt sgRNA and genome sequences, plus a trained model
    outputs the predicted relative activity of the sgRNA
    """
    X = np.stack((binarize_sequence(genome_seq), 
                  binarize_sequence(sgrna_seq)), axis=2)[np.newaxis]
    
    return cnn_model.predict(X)[0][0]
In [30]:
print get_predicted_activity(genome, sgrna, model)
0.06621526
In [ ]: