# Source code for malt.models.supervised_model

"""Wrapper around regressor and representation to hold the entire model. """
# =============================================================================
# IMPORTS
# =============================================================================
import abc
import torch
import gpytorch
from typing import Any
from .regressor import Regressor
from .representation import Representation

# =============================================================================
# BASE CLASSES
# =============================================================================
class SupervisedModel(torch.nn.Module, abc.ABC):
    """A supervised model composed of a representation and a regressor.

    The representation embeds the input; the regressor maps that embedding
    to its output (the predictive posterior) and also owns the loss.

    Parameters
    ----------
    representation : Representation
        Module to project small molecule graph to latent embeddings.
        Must expose an ``out_features`` attribute.
    regressor : Regressor
        Module to convert latent embeddings to likelihood parameters.
        Must expose an ``in_features`` attribute and a
        ``loss(representation, y)`` method.

    Raises
    ------
    ValueError
        If ``representation.out_features`` does not equal
        ``regressor.in_features``.

    Methods
    -------
    forward
        Make the predictive posterior for an input.
    loss
        Compute the regressor's loss for an input and its target.
    """

    def __init__(
        self,
        representation: "Representation",
        regressor: "Regressor",
    ) -> None:
        super().__init__()
        # Use an explicit exception rather than ``assert``: assertions are
        # stripped under ``python -O``, which would let incompatible
        # modules through and fail later with a confusing shape error.
        if representation.out_features != regressor.in_features:
            raise ValueError(
                "representation.out_features "
                f"({representation.out_features}) must equal "
                f"regressor.in_features ({regressor.in_features})."
            )
        self.representation = representation
        self.regressor = regressor

    def forward(self, x):
        """Make predictive posterior.

        Parameters
        ----------
        x
            Input accepted by ``self.representation``.

        Returns
        -------
        Whatever ``self.regressor`` returns when applied to the
        representation of ``x``.
        """
        representation = self.representation(x)
        return self.regressor(representation)

    def loss(self, x, y):
        """Default loss function.

        Embeds ``x`` with the representation and delegates to
        ``self.regressor.loss``.

        Parameters
        ----------
        x
            Input accepted by ``self.representation``.
        y
            Target passed through unchanged to ``self.regressor.loss``.

        Returns
        -------
        The value returned by ``self.regressor.loss(representation, y)``.
        """
        representation = self.representation(x)
        return self.regressor.loss(representation, y)