Algorithm 1: Implementation of the ConvCNPs–SDE model
Inputs: ID dataset; CR and MR, the context rate and missing rate, respectively; ccnps, the ConvCNPs model of vNPs used to complete the ID dataset; the downsampling net (2D image classification tasks) or the upsampling net (1D regression tasks); the fully connected net; f, the drift net, and g, the diffusion net; t, the layer depth; the cross-entropy, log-likelihood, and binary cross-entropy loss functions.
Outputs: Means and Vars
for #training iterations do
1. Sample a minibatch of m data points from the ID dataset;
2. if the task is 1D regression:
3. Generate context points from the sampled target points according to CR;
4. Forward through the ConvCNPs model to obtain the predictive distribution Y_dist;
5. Forward through the upsampling net of the SDE-Net block to obtain the initial hidden state;
6. else (2D image classification task):
7. Forward through the ConvCNPs model to obtain the predictive distribution Y_dist;
8. Forward through the downsampling net of the SDE-Net block to obtain the initial hidden state;
9. for k = 0 to t − 1 do
10. Sample Gaussian noise and advance the hidden state by one step using the drift net f and the diffusion net g;
11. end for
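12. Forward the final hidden state through the fully connected layer of the SDE-Net block;
13. Update the downsampling or upsampling net, the fully connected net, and f by minimizing the task loss (cross-entropy for classification, log-likelihood for regression);
14. Update ccnps by minimizing its loss;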
15. Sample a minibatch of data from the ID dataset;
16. Sample a minibatch of data from the OOD dataset;
17. Forward both minibatches through the downsampling or upsampling net of the SDE-Net block;
18. Update g by minimizing the binary cross-entropy loss on the ID and OOD minibatches;
for #testing iterations do
19. Evaluate the ConvCNPs–SDE model:
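20. Sample a minibatch of m data points from the ID dataset;
21. Draw mask ~ Bernoulli(1 − MR), so each entry is kept with probability 1 − MR;
22. Apply the mask element-wise: masked minibatch = mask ∗ minibatch;
23. Complete the masked minibatch with ccnps;
24. Means, Vars = SDE-Net(completed minibatch);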
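Steps 9 to 11 integrate the SDE-Net hidden state over t layers. The following is a minimal sketch, assuming the Euler-Maruyama update of the original SDE-Net, h_{k+1} = h_k + f(h_k, t_k)Δt + g(h_0)√Δt Z_k with Z_k ~ N(0, I), and assuming a total integration time of 1 so that Δt = 1/t. The names `sde_forward`, `drift`, `diffusion`, and `total_time` are hypothetical, and plain Python lists stand in for tensors. This is an illustration, not the authors' implementation.

```python
import math
import random
from typing import Callable, List, Optional

Vector = List[float]


def sde_forward(
    h0: Vector,
    drift: Callable[[Vector, float], Vector],
    diffusion: Callable[[Vector], float],
    depth: int,
    total_time: float = 1.0,
    rng: Optional[random.Random] = None,
) -> Vector:
    """Euler-Maruyama integration of dh = f(h, t) dt + g(h0) dW over `depth` steps."""
    rng = rng if rng is not None else random.Random()
    dt = total_time / depth
    g0 = diffusion(h0)  # assumed SDE-Net style: g is evaluated once on the initial state
    h = list(h0)
    for k in range(depth):  # k = 0, ..., depth - 1
        t_k = k * dt
        f_k = drift(h, t_k)
        z_k = [rng.gauss(0.0, 1.0) for _ in h]  # Z_k ~ N(0, I)
        h = [
            h_i + f_i * dt + g0 * math.sqrt(dt) * z_i
            for h_i, f_i, z_i in zip(h, f_k, z_k)
        ]
    return h


# Toy drift f(h, t) = -h; with g = 0 each step multiplies h by (1 - dt).
def decay(h: Vector, t: float) -> Vector:
    return [-x for x in h]


deterministic = sde_forward([1.0], decay, lambda h0: 0.0, depth=4)  # [0.31640625]
noisy = sde_forward([1.0], decay, lambda h0: 0.5, depth=4, rng=random.Random(0))
```

With a zero diffusion the loop reduces to a deterministic residual network (Euler's method on the drift alone), so any spread in the outputs comes from g. This is why the algorithm trains g separately to signal uncertainty.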
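Steps 15 to 18 train only the diffusion net g, using one ID and one OOD minibatch. The sketch below shows that objective. It assumes the SDE-Net convention that g is pushed toward 0 on ID samples (low uncertainty) and toward 1 on OOD samples (high uncertainty), with g's outputs already squashed into (0, 1). The function names and the clipping constant `eps` are hypothetical. A real implementation would compute this on tensors and apply the gradient to g's parameters only, since step 18 updates g alone.

```python
import math
from typing import Sequence


def bce(p: float, target: float, eps: float = 1e-7) -> float:
    """Binary cross-entropy of a probability p against a 0/1 target."""
    p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def diffusion_loss(g_id: Sequence[float], g_ood: Sequence[float]) -> float:
    """Mean BCE with target 0 for g on ID samples plus target 1 for g on OOD samples."""
    loss_in = sum(bce(p, 0.0) for p in g_id) / len(g_id)
    loss_out = sum(bce(p, 1.0) for p in g_ood) / len(g_ood)
    return loss_in + loss_out


good = diffusion_loss([0.05, 0.1], [0.9, 0.95])  # g separates ID from OOD
bad = diffusion_loss([0.9, 0.95], [0.05, 0.1])   # g has the roles reversed
```

Minimizing this loss makes the diffusion term, and therefore the spread of the SDE outputs, small on data resembling the ID set and large elsewhere.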
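Step 3 and steps 21 to 22 both subsample data by a rate. In 1D regression training, CR sets how many target points become context. At test time, MR sets how many entries are hidden. The sketch below makes the following assumptions:

- Context points are drawn uniformly without replacement.
- Their count is round(CR × n), with at least one point.

The algorithm does not pin down this rule, and the helper names `sample_context` and `bernoulli_mask` are hypothetical. The mask follows step 21 directly: each entry is kept (mask = 1) with probability 1 − MR.

```python
import random
from typing import List, Sequence, Tuple


def sample_context(
    xs: Sequence[float], ys: Sequence[float], context_rate: float, rng: random.Random
) -> Tuple[List[float], List[float]]:
    """Select about CR * n of the n target points as context (uniform, no replacement)."""
    n = len(xs)
    n_ctx = min(n, max(1, round(context_rate * n)))  # hypothetical rounding rule
    idx = sorted(rng.sample(range(n), n_ctx))  # keep the original input ordering
    return [xs[i] for i in idx], [ys[i] for i in idx]


def bernoulli_mask(
    values: Sequence[float], missing_rate: float, rng: random.Random
) -> Tuple[List[float], List[float]]:
    """mask ~ Bernoulli(1 - MR); masked = mask * values (0 marks a missing entry)."""
    keep_prob = 1.0 - missing_rate
    mask = [1.0 if rng.random() < keep_prob else 0.0 for _ in values]
    masked = [m * v for m, v in zip(mask, values)]
    return mask, masked


rng = random.Random(42)
xs = [i / 10 for i in range(10)]
ys = [x * x for x in xs]
ctx_x, ctx_y = sample_context(xs, ys, context_rate=0.3, rng=rng)
mask, masked = bernoulli_mask(ys, missing_rate=0.2, rng=rng)
```

Multiplying by the mask writes 0 into missing entries, so a completion model cannot tell a hidden value from a genuine 0 without also seeing the mask. ConvCNP-style image completion therefore typically receives the mask alongside the masked values.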