"""Graph -> representation."""
# =============================================================================
# IMPORTS
# =============================================================================
import abc
from typing import Optional
import torch
import dgl
import functools
from dgl.nn.pytorch import GraphConv
# =============================================================================
# BASE CLASSES
# =============================================================================
class Representation(torch.nn.Module, abc.ABC):
    """Base class for a representation.

    A representation projects a graph onto a fixed-dimensional space.

    Parameters
    ----------
    out_features : int
        Dimensionality of the output representation.

    Methods
    -------
    forward(g)
        Project a graph onto a fixed-dimensional space.
    """

    def __init__(self, out_features: int, *args, **kwargs) -> None:
        # NOTE: the original annotated __init__ as returning torch.Tensor;
        # a constructor always returns None.
        super().__init__()
        self.out_features = out_features

    @abc.abstractmethod
    def forward(self, g) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        g : dgl.DGLGraph
            Input (possibly batched) graph.

        Returns
        -------
        torch.Tensor
            Fixed-dimensional representation of the graph.
        """
        raise NotImplementedError
# =============================================================================
# MODULE CLASSES
# =============================================================================
class DGLRepresentation(Representation):
    """Graph representation built from a stack of DGL graph-convolution layers.

    Node features are embedded, passed through ``depth`` message-passing
    layers, pooled into a graph-level feature, and projected to
    ``out_features`` dimensions.

    Parameters
    ----------
    layer : type
        Callable constructing a DGL layer from ``(in_features, out_features)``.
        Defaults to ``GraphConv`` with ``allow_zero_in_degree=True``.
    in_features : int
        Input node-feature dimension.
    hidden_features : Optional[int]
        Hidden dimension; defaults to ``out_features`` when ``None``.
    out_features : int
        Output (graph-representation) dimension.
    depth : int
        Number of stacked graph-convolution layers.
    activation : callable
        Activation module applied after the input embedding, after each
        graph-convolution layer, and inside the output embedding.
    global_pool : str
        Name of the DGL node readout ("sum", "mean", ...); resolved to
        ``dgl.<global_pool>_nodes``.
    """

    def __init__(
        self,
        layer: type = functools.partial(GraphConv, allow_zero_in_degree=True),
        in_features: int = 74,  # TODO(yuanqing-wang): make this less awkward?
        hidden_features: Optional[int] = None,
        out_features: int = 128,
        depth: int = 3,
        activation: callable = torch.nn.SiLU(),
        global_pool: str = "sum",
    ):
        super(DGLRepresentation, self).__init__(out_features=out_features)
        if hidden_features is None:
            hidden_features = out_features

        # input embedding: node features -> hidden space
        self.embedding_in = torch.nn.Sequential(
            torch.nn.Linear(in_features, hidden_features),
            activation,
        )

        # Construct the message-passing stack. Layers are registered as
        # attributes gn0, gn1, ... so torch.nn.Module tracks their
        # parameters (and state_dict keys stay backward-compatible).
        for idx in range(depth):
            setattr(
                self,
                "gn%s" % idx,
                layer(hidden_features, hidden_features),
            )

        self.embedding_out = torch.nn.Sequential(
            torch.nn.Linear(hidden_features, hidden_features),
            activation,
        )
        self.ff = torch.nn.Linear(hidden_features, out_features)
        self.depth = depth
        # resolve e.g. "sum" -> dgl.sum_nodes
        self.global_pool = getattr(dgl, "%s_nodes" % global_pool)
        self.activation = activation

    def forward(self, g, field="h"):
        """Forward pass.

        Parameters
        ----------
        g : dgl.DGLGraph
            Input graph.
        field : str
            Node-data key under which input features are stored
            (default ``"h"``).

        Returns
        -------
        torch.Tensor
            The graph-level representation, shape ``(n_graphs, out_features)``.

        Examples
        --------
        >>> import malt
        >>> molecule = malt.Molecule("C")
        >>> representation = DGLRepresentation(out_features=8)
        >>> h = representation(molecule.g)
        >>> assert h.shape == (1, 8)
        """
        # make local copy so mutating ndata does not leak into the caller's graph
        g = g.local_var()
        h = g.ndata[field]
        h = self.embedding_in(h)
        # loop through the depth
        for idx in range(self.depth):
            h = getattr(self, "gn%s" % idx)(g, h)
            h = self.activation(h)
        h = self.embedding_out(h)
        g.ndata[field] = h
        # global pool: nodes -> one vector per graph
        h = self.global_pool(g, field)
        # final feedforward
        h = self.ff(h)
        return h