"""Datasets and their serving."""
# =============================================================================
# IMPORTS
# =============================================================================
import dgl
import malt
import torch
from malt.molecule import Molecule
from typing import Union, Iterable, Optional, List, Any, Callable
# =============================================================================
# MODULE CLASSES
# =============================================================================
class Dataset(torch.utils.data.Dataset):
    """A collection of Molecules compatible with training and optimization.

    Parameters
    ----------
    molecules : Optional[List[malt.Molecule]]
        A list of Molecules. Defaults to an empty list.

    Methods
    -------
    featurize_all()
        Featurize all molecules in the dataset.
    view()
        Generate a torch.utils.data.DataLoader from this Dataset.
    """

    # Lazily built SMILES -> molecule cache; reset to None whenever the
    # molecule list is changed through `apply` or `append`.
    _lookup = None
    _extra = None

    def __init__(self, molecules: Optional[List] = None) -> None:
        super(Dataset, self).__init__()
        if molecules is None:
            molecules = []
        assert isinstance(molecules, list)
        assert all(isinstance(molecule, Molecule) for molecule in molecules)
        self.molecules = molecules

    def __repr__(self):
        return "%s with %s molecules" % (self.__class__.__name__, len(self))

    def _construct_lookup(self):
        """Construct lookup table for molecules.

        If several molecules share a SMILES string, the last one wins.
        """
        self._lookup = {mol.smiles: mol for mol in self.molecules}

    @property
    def lookup(self):
        """Returns the mapping between the SMILES and the molecule."""
        if self._lookup is None:
            self._construct_lookup()
        return self._lookup

    def __contains__(self, molecule):
        """Check if a molecule is in the dataset.

        Membership is decided by the SMILES string of the molecule.

        Parameters
        ----------
        molecule : malt.Molecule

        Examples
        --------
        >>> molecule = Molecule("CC")
        >>> dataset = Dataset([molecule])
        >>> Molecule("CC") in dataset
        True
        >>> Molecule("C") in dataset
        False
        """
        return molecule.smiles in self.lookup

    def apply(self, function):
        """Apply a function to all molecules in the dataset.

        Parameters
        ----------
        function : Callable
            The function to be applied to all molecules in this dataset
            in place.

        Returns
        -------
        Dataset
            This dataset, with every molecule replaced by the result of
            ``function``.

        Examples
        --------
        >>> from ..molecule import Molecule
        >>> molecule = Molecule("CC")
        >>> dataset = Dataset([molecule])
        >>> fn = lambda molecule: Molecule(
        ...     smiles=molecule.smiles, metadata={"name": "john"},
        ... )
        >>> dataset = dataset.apply(fn)
        >>> dataset[0]["name"]
        'john'
        """
        self.molecules = [function(molecule) for molecule in self.molecules]
        # The molecules were replaced, so the cached lookup is stale.
        self._lookup = None
        return self

    def __eq__(self, other):
        """Determine if two objects are identical."""
        if not isinstance(other, self.__class__):
            return False
        return self.molecules == other.molecules

    def __len__(self):
        """Return the number of molecules in the dataset."""
        if self.molecules is None:
            return 0
        return len(self.molecules)

    def __getitem__(self, key: Any):
        """Get item from the dataset.

        Parameters
        ----------
        key : Any

        Notes
        -----
        * If the key is integer, return the single molecule indexed.
        * If the key is a string, return a dataset of all molecules with
          this SMILES.
        * If the key is a molecule, extract the SMILES string and return
          the stored molecule with that SMILES.
        * If the key is a tensor, flatten it and treat it as a list.
        * If the key is a list, return a dataset with molecules indexed by
          the elements in the list.
        * If the key is a slice, return a dataset of the sliced molecules.

        Raises
        ------
        RuntimeError
            If the dataset holds no molecule list, or the key type is not
            one of the above.
        """
        if self.molecules is None:
            raise RuntimeError("Empty Portfolio.")

        # Normalize tensors first so they share the list code path below
        # (previously the converted key was dropped and None was returned).
        if isinstance(key, torch.Tensor):
            key = key.detach().flatten().tolist()

        if isinstance(key, int):
            return self.molecules[key]
        if isinstance(key, str):  # NOTE(yuanqing-wang): Are we settled?
            return self.__class__(molecules=[self.lookup[key]])
        if isinstance(key, list):
            return self.__class__(
                molecules=[self.molecules[_idx] for _idx in key]
            )
        if isinstance(key, slice):
            return self.__class__(molecules=self.molecules[key])
        if isinstance(key, Molecule):
            return self.lookup[key.smiles]
        raise RuntimeError("The slice is not recognized.")

    def shuffle(self, seed=None):
        """Shuffle the dataset in place and return it.

        Parameters
        ----------
        seed : Optional
            If not None, the global ``random`` generator is re-seeded with
            this value before shuffling.
        """
        import random

        if seed is not None:
            random.seed(seed)
        random.shuffle(self.molecules)
        return self

    def split(self, partition):
        """Split the dataset according to some partition.

        Parameters
        ----------
        partition : Sequence[Optional[int, float]]
            Splitting partition. Entries are relative weights; each part
            receives ``int(len(self) * weight / sum(partition))`` molecules,
            so molecules left over by the truncation are not assigned to
            any part.

        Returns
        -------
        List[Dataset]
            List of datasets split according to the partition.

        Examples
        --------
        >>> dataset = Dataset([Molecule("CC"), Molecule("C")])
        >>> dataset0, dataset1 = dataset.split([1, 1])
        >>> dataset0[0].smiles
        'CC'
        """
        n_data = len(self)
        total = sum(partition)  # loop-invariant, computed once
        sizes = [int(n_data * x / total) for x in partition]
        ds = []
        idx = 0
        for p_size in sizes:
            ds.append(self[idx : idx + p_size])
            idx += p_size
        return ds

    def __add__(self, molecules):
        """Combine two datasets and return a new one.

        Parameters
        ----------
        molecules : Union[List[Molecule], Dataset]
            Molecules to be added to the dataset.

        Returns
        -------
        Dataset
            The combined dataset.

        Raises
        ------
        RuntimeError
            If ``molecules`` is neither a list nor a Dataset.

        Examples
        --------
        >>> dataset0 = Dataset([Molecule("C")])
        >>> dataset1 = Dataset([Molecule("CC")])
        >>> dataset = dataset0 + dataset1
        >>> len(dataset)
        2
        """
        if isinstance(molecules, list):
            return self.__class__(molecules=self.molecules + molecules)
        elif isinstance(molecules, Dataset):
            return self.__class__(
                molecules=self.molecules + molecules.molecules
            )
        else:
            raise RuntimeError("Addition only supports list and Dataset.")

    def __sub__(self, molecules):
        """Subtract a list of molecules from a dataset and return a new one.

        Molecules are matched by their SMILES strings.

        Parameters
        ----------
        molecules : Union[list[Molecule], Dataset]
            Molecules to be subtracted from the dataset.

        Returns
        -------
        Dataset
            The resulting dataset.

        Raises
        ------
        RuntimeError
            If ``molecules`` is neither a list nor a Dataset.

        Examples
        --------
        >>> dataset = Dataset([Molecule("CC"), Molecule("C")])
        >>> dataset -= [Molecule("C")]
        >>> len(dataset)
        1
        """
        if isinstance(molecules, list):
            molecules = self.__class__(molecules)
        elif not isinstance(molecules, Dataset):
            # Mirror __add__ instead of failing later with AttributeError.
            raise RuntimeError("Subtraction only supports list and Dataset.")
        excluded = molecules.lookup
        return self.__class__(
            [
                molecule
                for molecule in self.molecules
                if molecule.smiles not in excluded
            ]
        )

    def __iter__(self):
        """Alias of iter for molecules."""
        return iter(self.molecules)

    def append(self, molecule):
        """Append a molecule to the dataset.

        Alias of append for molecules.

        Note
        ----
        * This appends in place.

        Parameters
        ----------
        molecule : molecule
            The data molecule to be appended.
        """
        self.molecules.append(molecule)
        # Membership changed, so the cached lookup must be rebuilt.
        self._lookup = None
        return self

    def featurize_all(self):
        """Featurize all molecules in dataset."""
        # An explicit loop: a bare generator expression would never run.
        for molecule in self.molecules:
            molecule.featurize()
        return self

    @property
    def smiles(self):
        """Return the list of SMILES strings in the dataset."""
        return [molecule.smiles for molecule in self.molecules]

    @staticmethod
    def _batch(
        molecules=None, by=['g', 'y'],
        **kwargs,
    ):
        """Batches molecules by provided keys.

        Parameters
        ----------
        molecules : list of molecules
            The molecules to batch. Must be provided; this is a static
            method and has no dataset to fall back on.
        by : Union[Iterable, str]
            Attributes of molecule on which to batch. The key ``'g'``
            selects the molecule graph (featurizing it first when
            ``is_featurized()`` is false); any other key is read from
            ``molecule.metadata``.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        ret : Union[tuple, dgl.Graph, torch.Tensor]
            Batched data, in order of keys passed in `by` argument. When a
            single key is requested the batch itself is returned rather
            than a one-element tuple.
        """
        from collections import defaultdict

        ret = defaultdict(list)

        # guarantee keys are a re-iterable list
        by = [by] if isinstance(by, str) else list(by)

        # loop through molecules
        for molecule in molecules:
            for key in by:
                if key == 'g':
                    # featurize graphs
                    if not molecule.is_featurized():
                        molecule.featurize()
                    ret['g'].append(molecule.g)
                else:
                    ret[key].append(molecule.metadata[key])

        # collate batches
        for key in by:
            if key == 'g':
                ret['g'] = dgl.batch(ret['g'])
            else:
                ret[key] = torch.tensor(ret[key])

        # return batches
        ret = tuple(ret.values())
        if len(ret) < 2:
            ret = ret[0]
        return ret

    def erase_annotation(self):
        """Erase the metadata."""
        for molecule in self.molecules:
            molecule.erase_annotation()
        return self

    def clone(self):
        """Return a copy of self."""
        import copy

        return self.__class__(copy.deepcopy(self.molecules))

    def batch(self, *args, **kwargs):
        """Batch all molecules of this dataset; see ``_batch``."""
        return self._batch(self.molecules, *args, **kwargs)

    def view(
        self,
        collate_fn: Optional[Callable] = None,
        by: Union[Iterable, str] = ['g', 'y'],
        *args,
        **kwargs,
    ):
        """Provide a data loader from portfolio.

        Parameters
        ----------
        collate_fn : Optional[Callable]
            The function to gather data molecules. It is called with the
            list of molecules and the keyword argument ``by``. Defaults to
            ``Dataset._batch``.
        by : Union[Iterable, str]
            Keys forwarded to ``collate_fn``.
        *args, **kwargs
            Forwarded to ``torch.utils.data.DataLoader`` after the dataset
            (e.g. ``batch_size``, ``shuffle``).

        Returns
        -------
        torch.utils.data.DataLoader
            Resulting data loader.
        """
        from functools import partial

        if collate_fn is None:
            # provide default collate function
            collate_fn = self._batch

        # The dataset goes first positionally so that extra positional
        # arguments line up with DataLoader's remaining parameters instead
        # of colliding with `dataset`.
        return torch.utils.data.DataLoader(
            self.molecules,
            *args,
            collate_fn=partial(collate_fn, by=by),
            **kwargs,
        )
def from_pandas(
    dataframe,
    smiles_column: str,
    y_column: str,
):
    """Build a Dataset from two columns of a pandas DataFrame.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        The dataframe to read from.
    smiles_column : str
        The name of the column containing SMILES string.
    y_column : str
        The name of the column containing measurement.

    Returns
    -------
    Dataset
        One molecule per row, built from the SMILES string, with the
        measurement converted to ``float`` and assigned to its ``y``
        attribute.

    Examples
    --------
    >>> import pandas as pd
    >>> dataframe = pd.DataFrame.from_dict(
    ...     {"SMILES": ["C", "CC"], "Y": [1, 2]},
    ... )
    >>> dataset = from_pandas(dataframe, "SMILES", "Y")
    """
    # Pair each SMILES string with the measurement from the same row.
    rows = zip(dataframe[smiles_column], dataframe[y_column])

    collected = []
    for smiles_string, measurement in rows:
        entry = Molecule(smiles_string)
        entry.y = float(measurement)
        collected.append(entry)
    return Dataset(collected)