Nat Commun. 2022 Aug 15;13:4654. doi: 10.1038/s41467-022-31985-y

Fig. 2. Multi-head residual attention network architecture, performance, and visualisations for human interpretation.


a The multi-head network architecture consists of a single shared Attention-56 network [58] backbone, which contains stacked attention modules and residual blocks, followed by four separate fully connected output heads after the flattening layer, one for each parameter. Each of these heads classifies its associated parameter as low, good, or high. Attention modules consist of a trunk branch containing residual blocks and a mask branch which performs down- and up-sampling. b Example attention masks at each module for the given input images. Each module output consists of many channels of masks; only a single sample channel is shown here. The masks show the regions the network is focussing on, such as the most recent extrusion, as shown by the output of module 2. c Confusion matrices of the final network, after the three stages of training, on our test dataset for each parameter. d Training and validation accuracy plots from training the network across three seeds, smoothed with an exponential moving average, on three datasets: single layer, full, and balanced. e Example data augmentations used during training to make the model more generalisable.
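The multi-head design in panel a, a single shared backbone whose flattened features feed four independent three-class heads, can be illustrated with a short sketch. The PyTorch snippet below is a minimal illustration under stated assumptions, not the authors' implementation: the stand-in backbone, the feature dimensionality, and the parameter names (flow_rate, lateral_speed, z_offset, hotend_temp) are hypothetical placeholders for the four printing parameters.

```python
import torch
import torch.nn as nn


class MultiHeadClassifier(nn.Module):
    """Shared feature backbone with one 3-class output head per parameter (sketch)."""

    def __init__(self, backbone: nn.Module, feature_dim: int,
                 parameters=("flow_rate", "lateral_speed", "z_offset", "hotend_temp")):
        super().__init__()
        self.backbone = backbone  # stand-in for an Attention-56-style feature extractor
        self.heads = nn.ModuleDict({
            name: nn.Linear(feature_dim, 3)  # classes: low / good / high
            for name in parameters
        })

    def forward(self, x):
        # Shared features are flattened once, then routed to every head.
        features = torch.flatten(self.backbone(x), start_dim=1)
        return {name: head(features) for name, head in self.heads.items()}


# Example usage with a toy backbone and a dummy batch of images.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(4),
)
model = MultiHeadClassifier(backbone, feature_dim=8 * 4 * 4)
logits = model(torch.randn(2, 3, 224, 224))
print({name: out.shape for name, out in logits.items()})  # each head: (batch, 3)
```

Keeping the heads as separate linear layers on top of one shared backbone mirrors the caption's description: the convolutional attention features are computed once, while each parameter receives its own classifier.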