
matrix_variate_multitask_model

HetergeneousMatrixVariateMean

Bases: MultitaskMean

Computes a mean that depends on the input.

The mean can be that of either of the two related Gaussian processes

Xdot = F(X)ᵀU

or

Y = F(X)ᵀ

The input is packed as

M, X, U = MXU

where M is the mask: a value of 1 selects the mean of Xdot = F(X)ᵀU, while 0 selects the mean of Y = F(X)ᵀ.
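
As a hedged sketch (the function, names, and shapes below are illustrative, not the library API), the mask simply decides which of the two mean formulas applies to each input row, with the masked rows assumed to come first:

import torch

def toy_masked_mean(F_mean, U, mask):
    # F_mean: (n, 1+m, d) matrix-variate mean F(X)ᵀ per input point
    # U:      (n, 1+m)    control-like multiplier (UH in the source below)
    # mask:   (n,)        sorted so that all 1s come before all 0s
    k = int(mask.sum())
    xdot_mean = (U[:k].unsqueeze(-2) @ F_mean[:k]).reshape(-1)  # F(X)ᵀU rows
    y_mean = F_mean[k:].reshape(-1)                             # flattened F(X)ᵀ rows
    return torch.cat([xdot_mean, y_mean])

F_mean = torch.randn(5, 2, 4)                  # matshape = (2, 4)
U = torch.randn(5, 2)
mask = torch.tensor([1., 1., 1., 0., 0.])
print(toy_masked_mean(F_mean, U, mask).shape)  # 3*4 + 2*2*4 = 28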

Source code in bayes_cbf/matrix_variate_multitask_model.py
class HetergeneousMatrixVariateMean(MultitaskMean):
    """
    Computes a mean depending on the input.

    Our mean can be that of either of the two related Gaussian processes

        Xdot = F(X)ᵀU

    or

        Y = F(X)ᵀ

    We take input in the form

        M, X, U = MXU

    where M is the mask: a value of 1 means we want Xdot = F(X)ᵀU, while 0
    means we want Y = F(X)ᵀ.
    """
    def __init__(self, mean_module, decoder, matshape, **kwargs):
        num_tasks = prod(matshape)
        super().__init__(mean_module, num_tasks, **kwargs)
        self.decoder = decoder
        self.matshape = matshape

    def mean1(self, UH, mu):
        # TODO: Make this a separate module
        XdotMean = UH.unsqueeze(-2) @ mu # D x n
        output = XdotMean.reshape(-1)
        return output

    def mean2(self, mu):
        # TODO: Make this a separate module
        return mu.reshape(-1)

    def forward(self, MXU):
        assert not torch.isnan(MXU).any()
        B = MXU.shape[:-1]

        # Decode the packed input into mask Ms, states (unused here) and controls UH
        Ms, _, UH = self.decoder.decode(MXU)
        assert Ms.size(-1) == 1
        Ms = Ms[..., 0]
        # The mask is assumed sorted (all 1s first); idxend is the index of the
        # first 0, i.e. the boundary between the Xdot rows and the Y rows
        idxs = torch.nonzero(Ms - Ms.new_ones(Ms.size()))
        idxend = torch.min(idxs) if idxs.numel() else Ms.size(-1)
        # Evaluate every base mean and stack them into the matrix-variate mean F(X)ᵀ
        mu = torch.cat([sub_mean(MXU).unsqueeze(-1)
                        for sub_mean in self.base_means], dim=-1)
        assert not torch.isnan(mu).any()
        mu = mu.reshape(-1, *self.matshape)
        output = None
        if idxend != 0:
            # assume sorted
            assert (Ms[..., idxend:] == 0).all()
            # Masked rows: mean of Xdot = F(X)ᵀU
            output = self.mean1(UH[..., :idxend, :], mu[:idxend, ...])

        if Ms.size(-1) != idxend:
            # Unmasked rows: mean of Y = F(X)ᵀ, flattened
            Fmean = self.mean2(mu[idxend:, ...])
            output = torch.cat([output, Fmean]) if output is not None else Fmean
        return output

    def state_dict(self):
        return dict(
            matshape=self.matshape,
            decoder=self.decoder.state_dict()
        )

    def load_state_dict(self, state_dict):
        self.matshape = state_dict.pop('matshape')
        self.decoder.load_state_dict(state_dict['decoder'])
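
For reference, the mask handling in forward above relies on the rows being sorted so that all masked (M = 1) points precede the unmasked ones; a small self-contained check of that idiom:

import torch

Ms = torch.tensor([1., 1., 1., 0., 0.])
idxs = torch.nonzero(Ms - Ms.new_ones(Ms.size()))
idxend = torch.min(idxs) if idxs.numel() else Ms.size(-1)
print(idxend)  # tensor(3): rows [:3] get the Xdot mean, rows [3:] the Y mean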