Extended Data Fig. 6. PyTorch-style pseudocode for transformer-based masked multi-label classification.
Inputs to our masked multi-label classification algorithm are listed in lines 1–5. The vision encoder and genetic encoder are pretrained in our implementation but can be randomly initialized and trained end-to-end. The label mask is an L-dimensional binary mask with a variable percentage of the labels removed, and subsequently predicted, in each feedforward pass. An image is augmented and undergoes a feedforward pass through the vision encoder. The image representation is then normalized. The labels are embedded using our pretrained genetic embedding model, and the label mask is applied. The label embeddings are then concatenated with the image embedding and passed into the transformer encoder as input tokens. Unlike previous transformer-based methods for multi-label classification [31], we enforce that the transformer encoder outputs lie in the same vector space as the pretrained genetic embedding model. We perform a batch matrix multiplication between the transformer outputs and the embedding-layer weights; the main-diagonal elements are the inner products between each transformer encoder output and the corresponding embedding weights. We then compute the masked binary cross-entropy loss. In our implementation, this loss is used to train only the transformer encoder.
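A minimal PyTorch sketch of the training step described in this caption is given below. It is a reconstruction from the caption, not the figure's pseudocode: the function name masked_multilabel_step, the augment callable, the uniform sampling of the masking percentage, and the use of binary_cross_entropy_with_logits are all assumptions, and the transformer encoder is assumed to accept batch-first (B, 1 + L, D) token sequences.

```python
# Reconstruction of the caption's algorithm; names and masking scheme are assumptions.
import torch
import torch.nn.functional as F

def masked_multilabel_step(image, labels, augment,
                           vision_encoder,       # pretrained (or randomly initialized) image encoder
                           label_embed,          # pretrained genetic embedding: nn.Embedding(L, D)
                           transformer_encoder): # the only module trained by this loss
    B, L = labels.shape

    # Augment the image, pass it through the vision encoder, normalize the representation
    img_repr = F.normalize(vision_encoder(augment(image)), dim=-1)     # (B, D)

    # Embed all L labels with the pretrained genetic embedding model
    label_tokens = label_embed.weight.unsqueeze(0).expand(B, -1, -1)   # (B, L, D)

    # L-dimensional binary label mask: a variable percentage of labels is
    # removed (zeroed out) in each feedforward pass and then predicted
    ratio = torch.rand((), device=labels.device)
    mask = torch.rand(B, L, device=labels.device) < ratio              # True = masked
    label_tokens = label_tokens * (~mask).unsqueeze(-1)

    # Concatenate image and label embeddings as input tokens
    tokens = torch.cat([img_repr.unsqueeze(1), label_tokens], dim=1)   # (B, 1 + L, D)

    # Transformer outputs are constrained to the genetic embedding space;
    # drop the image token, keep the L label positions
    out = transformer_encoder(tokens)[:, 1:, :]                        # (B, L, D)

    # Batch matrix multiply the outputs with the embedding-layer weights; the main
    # diagonal holds the inner product of output l with embedding weight l
    logits = torch.bmm(out, label_embed.weight.t().expand(B, -1, -1))  # (B, L, L)
    logits = logits.diagonal(dim1=1, dim2=2)                           # (B, L)

    # Masked binary cross-entropy: only the removed labels contribute to the loss
    bce = F.binary_cross_entropy_with_logits(logits, labels.float(), reduction='none')
    return (bce * mask).sum() / mask.sum().clamp(min=1)
```

Note that the bmm-plus-diagonal step mirrors the caption exactly, but the same per-label inner products could be computed more cheaply as (out * label_embed.weight).sum(-1) without forming the full (B, L, L) product.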