# =============================================================================
# IMPORTS
# =============================================================================
import torch
from typing import Union
# =============================================================================
# MODULE FUNCTIONS
# =============================================================================
[docs]def random(distribution: torch.distributions.Distribution):
return torch.randn(distribution.batch_shape)
[docs]def expectation(distribution: torch.distributions.Distribution):
return distribution.mean
[docs]def uncertainty(distribution: torch.distributions.Distribution):
return distribution.variance
[docs]def expected_improvement(
distribution: torch.distributions.Distribution,
y_best: Union[torch.Tensor, float] = 0.0,
n_samples: int = 64,
):
if isinstance(y_best, float):
y_best = torch.tensor(y_best, device=distribution.mean.device)
improvement = torch.nn.functional.relu(
distribution.sample(torch.Size([n_samples])) - y_best
)
return improvement.mean(axis=0)
[docs]def probability_of_improvement(
distribution: torch.distributions.Distribution,
y_best: Union[torch.Tensor, float] = 0.0,
):
if isinstance(y_best, float):
y_best = torch.tensor(y_best, device=distribution.mean.device)
return 1.0 - distribution.cdf(y_best)
[docs]def upper_confidence_boundary(
distribution: torch.distributions.Distribution,
percentage: Union[torch.Tensor, float] = 0.95,
):
percentage = torch.tensor(percentage)
return distribution.icdf(1 - (1 - percentage) / 2)