Algorithm 1: Pseudocode of FM in a PyTorch-like style.

```python
import torch
import torch.nn.functional as F

# z_q, z_t: query/key embeddings and text embedding (BxC)
# queue_a: queue of N keys (CxN)
# tau_s, tau_t: temperatures for student/teacher (scalars)
# noise_std: standard deviation of the Gaussian noise (scalar)

noise_for_q = torch.randn_like(z_q) * noise_std  # Gaussian noise
noise_for_t = torch.randn_like(z_t) * noise_std

l_a = torch.mm(z_q + noise_for_q, queue_a)  # compute similarities (BxN)
l_b = torch.mm(z_q + noise_for_q, (z_t + noise_for_t).t())  # (BxC) x (CxB) -> BxB

loss_kl = loss_kld(l_b / tau_s, z_t / tau_t)

def loss_kld(inputs, targets):
    # KL divergence between softmax(targets) and softmax(inputs), averaged over the batch
    inputs, targets = F.log_softmax(inputs, dim=1), F.softmax(targets, dim=1)
    return F.kl_div(inputs, targets, reduction='batchmean')
```
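The `loss_kld` helper is the fully specified part of Algorithm 1. The sketch below is plain Python with no PyTorch dependency. It reproduces what `F.kl_div(F.log_softmax(inputs, 1), F.softmax(targets, 1), reduction='batchmean')` computes: the sum of p·(log p − log q) over every row, divided by the batch size. It also mirrors the Gaussian-noise and similarity steps. The helper names (`add_gaussian_noise`, `matmul`, `transpose`, `scale`), the sizes, and the values of `tau_s`, `tau_t`, and `noise_std` are hypothetical. The final call compares noisy similarities against noise-free ones purely as an illustration and is not the FM objective.

```python
import math
import random

def softmax(row):
    # Numerically stable softmax over one row of logits.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(row):
    # log_softmax(x)_j = x_j - logsumexp(x), computed stably.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]

def loss_kld(inputs, targets):
    # Mirrors F.kl_div(F.log_softmax(inputs, 1), F.softmax(targets, 1),
    # reduction='batchmean'): sum of p * (log p - log q) over all rows,
    # divided by the number of rows (the batch size).
    total = 0.0
    for x_row, t_row in zip(inputs, targets):
        log_q = log_softmax(x_row)
        p = softmax(t_row)
        # Zero-probability targets contribute nothing, as in PyTorch.
        total += sum(pj * (math.log(pj) - lq) for pj, lq in zip(p, log_q) if pj > 0.0)
    return total / len(inputs)

def add_gaussian_noise(x, noise_std, rng):
    # Hypothetical analogue of x + torch.randn_like(x) * noise_std.
    return [[v + rng.gauss(0.0, 1.0) * noise_std for v in row] for row in x]

def matmul(a, b):
    # (n x m) times (m x p) on nested lists, like torch.mm.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    # Swap rows and columns of a nested list.
    return [list(col) for col in zip(*a)]

def scale(x, s):
    # Divide every logit by a temperature s.
    return [[v / s for v in row] for row in x]

if __name__ == "__main__":
    rng = random.Random(0)
    B, C = 4, 8  # hypothetical batch size and embedding width
    z_q = [[rng.gauss(0.0, 1.0) for _ in range(C)] for _ in range(B)]
    z_t = [[rng.gauss(0.0, 1.0) for _ in range(C)] for _ in range(B)]
    tau_s, tau_t, noise_std = 0.1, 0.05, 0.1  # hypothetical values

    # Noise-free and noise-perturbed B x B similarity matrices.
    clean = matmul(z_q, transpose(z_t))
    noisy = matmul(add_gaussian_noise(z_q, noise_std, rng),
                   transpose(add_gaussian_noise(z_t, noise_std, rng)))

    # Illustrative pairing only: noisy logits scored against clean logits.
    print(loss_kld(scale(noisy, tau_s), scale(clean, tau_t)))
```

Subtracting the row maximum in `softmax` and `log_softmax` keeps the exponentials from overflowing when logits are divided by a small temperature, such as the hypothetical 0.05 above.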