ESTrainer
Base class for Evolution Strategy trainers.
```python
ESTrainer(params, model, fitness_fn, population_size=50, learning_rate=0.01, sigma=0.1, device=None, seed=None)
```
Bases: `Optimizer`, `ABC`
This class provides a framework for implementing ES algorithms. Subclasses must implement the abstract methods `sample_perturbations` and `compute_update` to define a specific ES variant (e.g., CMA-ES, OpenAI ES, or Natural Evolution Strategies).
Example
```python
from eggroll_trainer import ESTrainer


class VanillaESTrainer(ESTrainer):
    def sample_perturbations(self, population_size):
        # Return perturbations for each member of the population
        pass

    def compute_update(self, perturbations, fitnesses):
        # Compute parameter update from perturbations and fitnesses
        pass
```
Initialize the ES trainer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `params` | | Iterable of parameters to optimize (for optimizer compatibility). Typically `model.parameters()`. | *required* |
| `model` | `Module` | PyTorch model to train. Its parameters are optimized. | *required* |
| `fitness_fn` | `Callable[[Module], float]` | Function that takes a model and returns a fitness score (higher is better). It is responsible for evaluating the model. | *required* |
| `population_size` | `int` | Number of perturbed models to evaluate per generation | `50` |
| `learning_rate` | `float` | Learning rate for parameter updates | `0.01` |
| `sigma` | `float` | Standard deviation of the parameter perturbations | `0.1` |
| `device` | `Optional[device]` | Device to run training on (defaults to the model's device) | `None` |
| `seed` | `Optional[int]` | Random seed for reproducibility | `None` |
Source code in eggroll_trainer/base.py
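For orientation, the standard OpenAI ES update shows how the constructor's hyperparameters usually interact, with n = `population_size`, σ = `sigma`, α = `learning_rate`, and F = `fitness_fn`. This is a reference formula assuming isotropic Gaussian perturbations, not necessarily what `ESTrainer` computes; the update of any subclass is whatever its `compute_update` returns.

```latex
\begin{aligned}
\epsilon_i &\sim \mathcal{N}(0, I), \qquad i = 1, \dots, n \\
\hat{g} &= \frac{1}{n\sigma} \sum_{i=1}^{n} F(\theta + \sigma\epsilon_i)\,\epsilon_i \\
\theta &\leftarrow \theta + \alpha\,\hat{g}
\end{aligned}
```

Because higher fitness is better, θ moves along +ĝ, i.e. gradient ascent on the expected fitness of the perturbed parameters.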
Functions
compute_update `abstractmethod`
Compute parameter update from perturbations and fitnesses.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `perturbations` | `Tensor` | Tensor of shape `(population_size, param_dim)` | *required* |
| `fitnesses` | `Tensor` | Tensor of shape `(population_size,)` containing fitness scores | *required* |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Tensor of shape `(param_dim,)` containing the parameter update |
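As one illustration of what `compute_update` might do, the sketch below uses centered-rank fitness shaping, a common alternative to standardizing raw fitnesses that makes the update insensitive to the scale of the fitness values. It uses plain Python lists in place of tensors so it runs without PyTorch; `centered_ranks` and this free-function `compute_update` are hypothetical stand-ins, not part of `eggroll_trainer`.

```python
from typing import List


def centered_ranks(fitnesses: List[float]) -> List[float]:
    """Map fitnesses to evenly spaced values in [-0.5, 0.5] by rank."""
    n = len(fitnesses)
    if n == 1:
        return [0.0]
    # Indices sorted from worst (lowest fitness) to best.
    order = sorted(range(n), key=lambda i: fitnesses[i])
    ranks = [0.0] * n
    for rank, idx in enumerate(order):
        ranks[idx] = rank / (n - 1) - 0.5
    return ranks


def compute_update(
    perturbations: List[List[float]], fitnesses: List[float]
) -> List[float]:
    """Rank-weighted average of perturbations; result has length param_dim."""
    weights = centered_ranks(fitnesses)
    n = len(perturbations)
    param_dim = len(perturbations[0])
    return [
        sum(w * p[j] for w, p in zip(weights, perturbations)) / n
        for j in range(param_dim)
    ]
```

Only the ordering of the fitnesses matters here, so a single outlier fitness cannot dominate the update.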
evaluate_fitness
Evaluate fitness of a model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Model to evaluate | *required* |
Returns:
| Type | Description |
|---|---|
| `float` | Fitness score (higher is better) |
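A `fitness_fn` only has to map a model to a float where higher is better; for a regression task, a common choice is the negative loss. The sketch below assumes a hypothetical toy dataset and represents the "model" as a plain `(w, b)` parameter pair instead of a `Module`, so it runs without PyTorch.

```python
from typing import List, Sequence, Tuple

# Hypothetical toy dataset sampled from y = 2x + 1.
DATA: List[Tuple[float, float]] = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]


def fitness_fn(params: Sequence[float]) -> float:
    """Negative mean squared error of a linear predictor, so higher is better."""
    w, b = params
    mse = sum((w * x + b - y) ** 2 for x, y in DATA) / len(DATA)
    return -mse
```

The perfect fit `(2.0, 1.0)` scores `0.0`, and every other parameter pair scores below it.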
get_best_model
Get a copy of the model with the best parameters found.
Returns:
| Type | Description |
|---|---|
| `Module` | A new model instance with the best parameters loaded |
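Because `get_best_model` returns a *new* model, the trainer presumably keeps a snapshot of the best parameters rather than a reference to the live ones. The sketch below illustrates that pattern with `copy.deepcopy`; `BestTracker` is a hypothetical helper, and parameters are plain lists rather than tensors.

```python
import copy
from typing import Any, Optional


class BestTracker:
    """Hypothetical helper: remembers the best-scoring parameters seen so far."""

    def __init__(self) -> None:
        self.best_fitness: float = float("-inf")
        self.best_params: Optional[Any] = None

    def update(self, params: Any, fitness: float) -> None:
        # Higher fitness is better. Store a deep copy so later in-place
        # updates to `params` cannot alter the saved snapshot.
        if fitness > self.best_fitness:
            self.best_fitness = fitness
            self.best_params = copy.deepcopy(params)

    def get_best(self) -> Any:
        # Hand out a fresh copy so callers cannot mutate the snapshot either.
        return copy.deepcopy(self.best_params)
```

With PyTorch, a comparable idiom is `copy.deepcopy(model)` followed by `load_state_dict(...)` with the saved best state; this is a common pattern, not a description of `ESTrainer` internals.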
reset
Reset trainer to initial state.
sample_perturbations `abstractmethod`
Sample perturbations for the population.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `population_size` | `int` | Number of perturbations to sample | *required* |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Tensor of shape `(population_size, param_dim)` containing perturbations |
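One common way to implement `sample_perturbations` is antithetic (mirrored) sampling: draw half of the population from a Gaussian and pair each draw with its negation, which reduces the variance of the update estimate. The sketch below is a plain-Python illustration, assuming an even `population_size` and a hypothetical `seed` argument; each inner list plays the role of one row of the `(population_size, param_dim)` tensor.

```python
import random
from typing import List


def sample_antithetic(
    population_size: int, param_dim: int, sigma: float, seed: int = 0
) -> List[List[float]]:
    """Return population_size rows of Gaussian noise (std sigma) as mirrored pairs."""
    if population_size % 2 != 0:
        raise ValueError("antithetic sampling needs an even population_size")
    rng = random.Random(seed)  # seeded generator for reproducibility
    half = [
        [rng.gauss(0.0, sigma) for _ in range(param_dim)]
        for _ in range(population_size // 2)
    ]
    # Interleave each draw with its negation: rows 2k and 2k+1 sum to zero.
    rows: List[List[float]] = []
    for eps in half:
        rows.append(eps)
        rows.append([-e for e in eps])
    return rows
```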
step
Perform one optimization step (generation).
This method provides compatibility with the PyTorch optimizer interface. For ES algorithms, the fitness function is supplied at initialization rather than per step, so the `closure` argument is ignored.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `closure` | | Optional callable (ignored for ES; `fitness_fn` is used instead) | `None` |
Returns:
| Type | Description |
|---|---|
| | Dictionary containing training metrics |
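The sketch below shows what one generation (one `step` call) typically involves, under stated assumptions: parameters are a plain list updated in place, candidates are `theta + sigma * eps_i` with standard-normal `eps_i`, fitnesses are standardized before weighting, and the returned metric keys `mean_fitness` and `max_fitness` are hypothetical. It is an illustration, not the `ESTrainer.step` implementation.

```python
import random
from typing import Callable, Dict, List


def es_step(
    params: List[float],
    fitness_fn: Callable[[List[float]], float],
    population_size: int,
    learning_rate: float,
    sigma: float,
    rng: random.Random,
) -> Dict[str, float]:
    """One ES generation: sample, evaluate, update params in place, report metrics."""
    dim = len(params)
    # 1. Sample standard-normal directions eps_i.
    eps = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(population_size)]
    # 2. Evaluate each perturbed candidate theta + sigma * eps_i.
    fits = [fitness_fn([p + sigma * e for p, e in zip(params, row)]) for row in eps]
    # 3. Standardize fitnesses, then form an OpenAI-ES style gradient estimate.
    mean = sum(fits) / population_size
    std = (sum((f - mean) ** 2 for f in fits) / population_size) ** 0.5
    weights = [(f - mean) / (std + 1e-8) for f in fits]
    for j in range(dim):
        grad_j = sum(w * row[j] for w, row in zip(weights, eps)) / (
            population_size * sigma
        )
        params[j] += learning_rate * grad_j  # ascent: higher fitness is better
    # 4. Report metrics for this generation.
    return {"mean_fitness": mean, "max_fitness": max(fits)}
```

For example, repeatedly calling `es_step` on `lambda p: -(p[0] - 3.0) ** 2` starting from `[0.0]` moves the parameter toward `3.0`.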
train
Train the model for multiple generations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_generations` | `int` | Number of generations to train | *required* |
| `verbose` | `bool` | Whether to print progress | `True` |
Returns:
| Type | Description |
|---|---|
| `Dict[str, Any]` | Dictionary containing final training state |
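`train` is essentially a loop around `step`. A minimal sketch, assuming `step` is any zero-argument callable returning a metrics dictionary; the returned keys `generations` and `history` are hypothetical, since the documentation only says the result contains the final training state.

```python
from typing import Any, Callable, Dict, List


def train(
    step: Callable[[], Dict[str, float]],
    num_generations: int,
    verbose: bool = True,
) -> Dict[str, Any]:
    """Call `step` once per generation and collect its metrics."""
    history: List[Dict[str, float]] = []
    for gen in range(num_generations):
        metrics = step()
        history.append(metrics)
        if verbose:
            # One progress line per generation.
            summary = ", ".join(f"{k}={v:.4f}" for k, v in metrics.items())
            print(f"generation {gen + 1}/{num_generations}: {summary}")
    # Hypothetical final-state layout.
    return {"generations": num_generations, "history": history}
```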
zero_grad
Zero gradients (no-op for ES algorithms).
This method exists for optimizer interface compatibility. ES algorithms don't use gradients, so this is a no-op.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `set_to_none` | `bool` | If True, set gradients to None (ignored for ES) | `False` |
Abstract Methods
Subclasses must implement:
sample_perturbations
```python
def sample_perturbations(self, population_size: int) -> Tensor:
    """
    Sample perturbations for the population.

    Args:
        population_size: Number of population members

    Returns:
        Tensor of shape (population_size, param_dim) containing
        perturbations for each population member
    """
    pass
```
compute_update
```python
def compute_update(
    self,
    perturbations: Tensor,
    fitnesses: Tensor,
) -> Tensor:
    """
    Compute parameter update from perturbations and fitnesses.

    Args:
        perturbations: Tensor of shape (population_size, param_dim)
        fitnesses: Tensor of shape (population_size,) with fitness scores

    Returns:
        Tensor of shape (param_dim,) containing parameter update
    """
    pass
```
Example
```python
import torch

from eggroll_trainer import ESTrainer


class CustomESTrainer(ESTrainer):
    def sample_perturbations(self, population_size):
        # Isotropic Gaussian noise scaled by sigma: shape (population_size, param_dim)
        param_dim = self.current_params.shape[0]
        return torch.randn(
            population_size,
            param_dim,
            device=self.device,
        ) * self.sigma

    def compute_update(self, perturbations, fitnesses):
        # Standardize fitnesses to zero mean, unit std; 1e-8 avoids division by zero
        weights = (fitnesses - fitnesses.mean()) / (
            fitnesses.std() + 1e-8
        )
        # Fitness-weighted average of perturbations: shape (param_dim,)
        return (weights[:, None] * perturbations).mean(dim=0)
```