Published in final edited form as: Nat Med. 2023 Mar 23;29(4):828–832. doi: 10.1038/s41591-023-02252-4

Extended Data Fig 6. PyTorch-style pseudocode for transformer-based masked multi-label classification.


Inputs to our masked multi-label classification algorithm are listed in lines 1–5. The vision encoder and genetic encoder are pretrained in our implementation, but they can instead be randomly initialized and trained end-to-end. The label mask is an L-dimensional binary mask in which a variable percentage of the labels is removed and subsequently predicted on each forward pass. An image x is augmented and passed through the vision encoder f, and the resulting image representation is ℓ2-normalized. The labels are embedded using our pretrained genetic embedding model and the label mask is applied. The label embeddings are then concatenated with the image embedding and passed into the transformer encoder as input tokens. Unlike previous transformer-based methods for multi-label classification [31], we require the transformer encoder to output into the same vector space as the pretrained genetic embedding model. We perform a batch matrix multiplication between the transformer outputs and the embedding-layer weights; the main-diagonal elements are the inner products between each transformer output and the corresponding embedding weight vector. We then compute the masked binary cross-entropy loss. In our implementation, this loss is used to train only the transformer encoder.
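Because the pseudocode figure itself is not reproduced here, the following is a minimal, runnable PyTorch sketch of the training step the caption describes. All names (MaskedMultiLabelClassifier, mask_token, masked_bce_loss, the layer counts) are illustrative assumptions rather than the authors' code, and replacing masked label embeddings with a learned mask token is one plausible reading of "the label mask is applied".

```python
# A hedged sketch of masked multi-label classification with a transformer
# encoder. Assumed names: MaskedMultiLabelClassifier, mask_token,
# masked_bce_loss; hyperparameters are placeholders.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedMultiLabelClassifier(nn.Module):
    def __init__(self, vision_encoder, label_emb_weights, dim, n_layers=4, n_heads=8):
        super().__init__()
        self.vision_encoder = vision_encoder  # pretrained (or trained end-to-end)
        # Frozen pretrained genetic embedding: one row per label, each of size `dim`.
        self.label_emb = nn.Embedding.from_pretrained(label_emb_weights, freeze=True)
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)
        # Learned token substituted at masked label positions (an assumption).
        self.mask_token = nn.Parameter(torch.zeros(dim))

    def forward(self, x, label_mask):
        """x: (B, C, H, W) augmented images; label_mask: (B, L) binary, 1 = masked."""
        B, L = label_mask.shape
        # Encode the image and l2-normalize the representation.
        img = F.normalize(self.vision_encoder(x), dim=-1)           # (B, dim)
        # Embed all L labels, then apply the label mask.
        lab = self.label_emb.weight.unsqueeze(0).expand(B, -1, -1)  # (B, L, dim)
        lab = torch.where(label_mask.unsqueeze(-1).bool(),
                          self.mask_token.expand_as(lab), lab)
        # Concatenate the image token with the label tokens as transformer input.
        tokens = torch.cat([img.unsqueeze(1), lab], dim=1)          # (B, 1+L, dim)
        out = self.transformer(tokens)[:, 1:, :]                    # (B, L, dim)
        # Batch matrix multiplication with the embedding weights; the diagonal
        # terms are the inner products between each label's transformer output
        # and its own embedding vector.
        logits = torch.einsum("bld,ld->bl", out, self.label_emb.weight)  # (B, L)
        return logits

def masked_bce_loss(logits, targets, label_mask):
    """Binary cross-entropy averaged over the masked label positions only."""
    m = label_mask.float()
    loss = F.binary_cross_entropy_with_logits(logits, targets.float(), reduction="none")
    return (loss * m).sum() / m.sum().clamp(min=1)
```

Note the design choice the sketch tries to mirror: because the transformer outputs live in the same vector space as the frozen genetic embeddings, the logits reduce to inner products against the embedding weights, so only the transformer encoder receives gradient updates, consistent with the caption.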