Source code for malt.agents.player

"""Decision maker. """

import abc
import torch
from typing import Union, Callable
from .agent import Agent
from .merchant import Merchant
from .assayer import Assayer
from malt.models.supervised_model import SupervisedModel
from malt.data.dataset import Dataset
import torch

class Player(Agent):
    """Base class for players.

    A player hands candidate data points to a merchant and to an assayer.
    The decision loop itself is left to subclasses via ``step``.

    Methods
    -------
    merchandize(dataset, merchant=None):
        Conduct merchandization.

    assay(dataset, assayer=None):
        Conduct assay.

    step(*args, **kwargs):
        Marked abstract; the base implementation raises
        ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__()

    def merchandize(
        self,
        dataset: Dataset,
        merchant: Union[Merchant, None] = None,
    ):
        """Pass ``dataset`` to a merchant and return what it returns.

        Parameters
        ----------
        dataset : Dataset
            Data points to merchandize.
        merchant : Union[Merchant, None]
            Merchant to use. When ``None``, ``self.merchant`` is used; note
            that this class's ``__init__`` does not assign that attribute,
            so it has to be provided elsewhere (e.g. by a subclass).

        Returns
        -------
        The result of ``merchant.merchandize(dataset)``.
        """
        # Only touch ``self.merchant`` when no explicit merchant was given.
        seller = self.merchant if merchant is None else merchant
        return seller.merchandize(dataset)

    def assay(
        self,
        dataset: Dataset,
        assayer: Union[Assayer, None] = None,
    ):
        """Pass ``dataset`` to an assayer and return what it returns.

        Parameters
        ----------
        dataset : Dataset
            Data points to assay.
        assayer : Union[Assayer, None]
            Assayer to use. When ``None``, ``self.assayer`` is used; note
            that this class's ``__init__`` does not assign that attribute,
            so it has to be provided elsewhere (e.g. by a subclass).

        Returns
        -------
        The result of ``assayer.assay(dataset)``.
        """
        # Only touch ``self.assayer`` when no explicit assayer was given.
        tester = self.assayer if assayer is None else assayer
        return tester.assay(dataset)

    @abc.abstractmethod
    def step(self, *args, **kwargs):
        """Run one decision step; subclasses must override this."""
        raise NotImplementedError
class ModelBasedPlayer(Player):
    """Player with model.

    Parameters
    ----------
    model : SupervisedModel
        Model to predict properties based on structure.

    policy : Callable
        Policy to rank candidates. It is called with the model output for
        the merchant's catalogue and must return an iterable of indices
        into that catalogue.

    trainer : Callable
        Function to train a model. It is called with the player itself and
        its return value becomes the player's new ``model``.

    merchant : Merchant
        Merchant that merchandizes candidates.

    assayer : Assayer
        Assayer that assays the candidates.

    portfolio : Union[Dataset, None]
        Initial knowledge about data points. Defaults to an empty
        ``Dataset``.

    Note
    ----
    1. Portfolio respects order and could be used to analyze acquisition
       trajectory.
    """

    def __init__(
        self,
        model: SupervisedModel,
        policy: Callable,
        trainer: Callable,
        merchant: Merchant,
        assayer: Assayer,
        portfolio: Union[Dataset, None] = None,
    ):
        super(ModelBasedPlayer, self).__init__()
        self.model = model
        self.policy = policy
        self.trainer = trainer
        self.merchant = merchant
        self.assayer = assayer

        # Build a fresh empty portfolio per instance rather than sharing one.
        if portfolio is None:
            portfolio = Dataset([])
        self.portfolio = portfolio

    def merchandize(
        self,
        dataset: Dataset,
    ):
        """Merchandize ``dataset`` through the player's own merchant.

        Parameters
        ----------
        dataset : Dataset
            Data points to merchandize.

        Returns
        -------
        The result of the merchant's ``merchandize`` call.
        """
        return super().merchandize(dataset=dataset)

    def assay(
        self,
        dataset: Dataset,
    ):
        """Assay ``dataset`` through the player's own assayer and record it.

        The assayed result is appended to ``self.portfolio`` with ``+=``
        before being returned.

        Parameters
        ----------
        dataset : Dataset
            Data points to assay.

        Returns
        -------
        The result of the assayer's ``assay`` call.
        """
        dataset = super().assay(dataset=dataset)
        self.portfolio += dataset
        return dataset

    def train(self):
        """Retrain and replace the model.

        Calls ``self.trainer(self)`` and stores the result as ``self.model``.

        Returns
        -------
        The newly stored model.
        """
        self.model = self.trainer(self)
        return self.model

    def prioritize(self):
        """Pick the next candidates from the merchant's catalogue.

        Returns
        -------
        ``None`` when the catalogue is empty; otherwise the catalogue
        indexed by the list of indices chosen by ``self.policy``.
        """
        # Take a single snapshot of the catalogue so that the emptiness
        # check, the model input and the final indexing all refer to the
        # same object; the indices returned by the policy are only
        # meaningful for the catalogue that was actually scored.
        catalogue = self.merchant.catalogue()
        if len(catalogue) == 0:
            return None

        # Switch the model to evaluation mode before predicting.
        self.model.eval()

        # NOTE(review): ``by=['g']`` presumably selects the graph field of
        # each point -- confirm against ``Dataset.batch``.
        posterior = self.model(
            catalogue.batch(by=['g']),
        )
        best = self.policy(posterior)
        return catalogue[list(best)]
class SequentialModelBasedPlayer(ModelBasedPlayer):
    """Model based player with step size equal one.

    Each call to ``step`` performs a single round of prioritize,
    merchandize, assay and train.

    Examples
    --------
    >>> import malt
    >>> player = SequentialModelBasedPlayer(
    ...    model=malt.models.supervised_model.SupervisedModel(
    ...        representation=malt.models.representation.DGLRepresentation(
    ...            out_features=128
    ...        ),
    ...        regressor=malt.models.regressor.NeuralNetworkRegressor(
    ...            in_features=128,
    ...        ),
    ...    ),
    ...    policy=malt.policy.UtilityFunction(),
    ...    trainer=malt.trainer.get_default_trainer(),
    ...    merchant=malt.agents.merchant.DatasetMerchant(
    ...        malt.data.collections.linear_alkanes(10),
    ...    ),
    ...    assayer=malt.agents.assayer.DatasetAssayer(
    ...        malt.data.collections.linear_alkanes(10),
    ...    )
    ... )

    >>> while True:
    ...     if player.step() is None:
    ...         break
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to the parent constructor.
        super().__init__(*args, **kwargs)

    def step(self):
        """Run one acquisition round.

        Returns
        -------
        ``None`` when ``prioritize`` yields nothing; otherwise the result
        of assaying the merchandized candidates. The model is retrained
        before returning in the latter case.
        """
        candidates = self.prioritize()
        if candidates is not None:
            purchased = self.merchandize(candidates)
            measured = self.assay(purchased)
            self.train()
            return measured
        return None