Skip to content

ESTrainer

Base class for Evolution Strategy trainers.

ESTrainer

ESTrainer(params, model, fitness_fn, population_size=50, learning_rate=0.01, sigma=0.1, device=None, seed=None)

Bases: Optimizer, ABC

Base class for Evolution Strategy trainers.

This class provides a framework for implementing various ES algorithms. Subclasses should override the abstract methods to implement specific ES variants (e.g., CMA-ES, OpenAI ES, Natural ES, etc.).

Example

class VanillaESTrainer(ESTrainer):
    def sample_perturbations(self, population_size):
        # Return perturbations for each member of the population
        pass

def compute_update(self, perturbations, fitnesses):
    # Compute parameter update from perturbations and fitnesses
    pass

Initialize the ES trainer.

Parameters:

Name Type Description Default
params

Iterable of parameters to optimize (for optimizer compatibility). Typically model.parameters().

required
model Module

PyTorch model to train. Parameters will be optimized.

required
fitness_fn Callable[[Module], float]

Function that takes a model and returns a fitness score (higher is better). Should handle model evaluation.

required
population_size int

Number of perturbed models to evaluate per generation

50
learning_rate float

Learning rate for parameter updates

0.01
sigma float

Standard deviation for parameter perturbations

0.1
device Optional[device]

Device to run training on (defaults to model's device)

None
seed Optional[int]

Random seed for reproducibility

None
Source code in eggroll_trainer/base.py
def __init__(
    self,
    params,
    model: nn.Module,
    fitness_fn: Callable[[nn.Module], float],
    population_size: int = 50,
    learning_rate: float = 0.01,
    sigma: float = 0.1,
    device: Optional[torch.device] = None,
    seed: Optional[int] = None,
):
    """
    Initialize the ES trainer.

    Args:
        params: Iterable of parameters to optimize (for optimizer
                compatibility). Typically model.parameters().
        model: PyTorch model to train. Parameters will be optimized.
        fitness_fn: Function that takes a model and returns a fitness score
                    (higher is better). Should handle model evaluation.
        population_size: Number of perturbed models to evaluate per generation
        learning_rate: Learning rate for parameter updates
        sigma: Standard deviation for parameter perturbations
        device: Device to run training on (defaults to model's device)
        seed: Random seed for reproducibility
    """
    # Register hyperparameters with the torch.optim.Optimizer machinery so
    # param_groups behave like any other optimizer's.
    super().__init__(
        params,
        dict(
            learning_rate=learning_rate,
            population_size=population_size,
            sigma=sigma,
        ),
    )

    self.model = model
    self.fitness_fn = fitness_fn
    self.population_size = population_size
    self.learning_rate = learning_rate
    self.sigma = sigma

    # An explicit device wins and the model is moved onto it; otherwise the
    # trainer follows wherever the model's parameters already live.
    if device is not None:
        self.device = device
        self.model = self.model.to(self.device)
    else:
        self.device = next(model.parameters()).device

    # NOTE(review): this seeds torch's *global* RNG state, not a
    # trainer-local generator — other torch code in the process is affected.
    if seed is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # Flattened snapshot of the starting parameters; current_params is the
    # vector the trainer actually evolves between generations.
    self.initial_params = self._get_flat_params()
    self.current_params = self.initial_params.clone()

    # Bookkeeping accumulated across generations.
    self.generation = 0
    self.fitness_history: List[float] = []
    self.best_fitness: Optional[float] = None
    self.best_params: Optional[Tensor] = None

Functions

compute_update abstractmethod
compute_update(perturbations, fitnesses)

Compute parameter update from perturbations and fitnesses.

Parameters:

Name Type Description Default
perturbations Tensor

Tensor of shape (population_size, param_dim)

required
fitnesses Tensor

Tensor of shape (population_size,) containing fitness scores

required

Returns:

Type Description
Tensor

Tensor of shape (param_dim,) containing the parameter update

Source code in eggroll_trainer/base.py
@abc.abstractmethod
def compute_update(
    self,
    perturbations: Tensor,
    fitnesses: Tensor,
) -> Tensor:
    """
    Turn one generation's perturbations and scores into a parameter delta.

    Subclasses implement their ES variant's estimator here (e.g. the
    fitness-weighted average used by OpenAI-ES).

    Args:
        perturbations: Tensor of shape (population_size, param_dim)
        fitnesses: Tensor of shape (population_size,) containing fitness scores

    Returns:
        Tensor of shape (param_dim,) containing the parameter update
    """
    ...
evaluate_fitness
evaluate_fitness(model)

Evaluate fitness of a model.

Parameters:

Name Type Description Default
model Module

Model to evaluate

required

Returns:

Type Description
float

Fitness score (higher is better)

Source code in eggroll_trainer/base.py
def evaluate_fitness(self, model: nn.Module) -> float:
    """
    Score a model with the fitness function supplied at construction.

    Args:
        model: Model to evaluate

    Returns:
        Fitness score (higher is better)
    """
    score = self.fitness_fn(model)
    return score
get_best_model
get_best_model()

Get a copy of the model with the best parameters found.

Returns:

Type Description
Module

A new model instance with the best parameters loaded

Source code in eggroll_trainer/base.py
def get_best_model(self) -> nn.Module:
    """
    Get a copy of the model with the best parameters found.

    Returns:
        A new model instance with the best parameters loaded. If no best
        parameters have been recorded yet, the clone keeps the model's
        current parameters.
    """
    best_model = self._clone_model().to(self.device)

    if self.best_params is not None:
        # self.model is temporarily mutated to hold the best parameters so
        # its state_dict can be copied into the clone. The try/finally
        # guarantees the current parameters are restored even if
        # load_state_dict raises (the original could leave self.model
        # holding best_params on an exception).
        temp_params = self.current_params.clone()
        try:
            self._set_flat_params(self.best_params)
            best_model.load_state_dict(self.model.state_dict())
        finally:
            self._set_flat_params(temp_params)

    return best_model
reset
reset()

Reset trainer to initial state.

Source code in eggroll_trainer/base.py
def reset(self) -> None:
    """Restore the trainer to the state it had right after construction."""
    # Wipe the bookkeeping accumulated across generations.
    self.generation = 0
    self.fitness_history = []
    self.best_fitness = None
    self.best_params = None
    # Roll the live model back to a fresh copy of the initial parameters.
    self.current_params = self.initial_params.clone()
    self._set_flat_params(self.current_params)
sample_perturbations abstractmethod
sample_perturbations(population_size)

Sample perturbations for the population.

Parameters:

Name Type Description Default
population_size int

Number of perturbations to sample

required

Returns:

Type Description
Tensor

Tensor of shape (population_size, param_dim) containing perturbations

Source code in eggroll_trainer/base.py
@abc.abstractmethod
def sample_perturbations(self, population_size: int) -> Tensor:
    """
    Draw one perturbation vector per population member.

    Subclasses choose the noise distribution of their ES variant here.

    Args:
        population_size: Number of perturbations to sample

    Returns:
        Tensor of shape (population_size, param_dim) containing perturbations
    """
    ...
step
step(closure=None)

Perform one optimization step (generation).

This method provides compatibility with PyTorch optimizer interface. For ES algorithms, the fitness function is provided at initialization, not per-step. The closure parameter is ignored for ES.

Parameters:

Name Type Description Default
closure

Optional callable (ignored for ES, fitness_fn used instead)

None

Returns:

Type Description

Dictionary containing training metrics

Source code in eggroll_trainer/base.py
def step(self, closure=None):
    """
    Perform one optimization step (generation).

    This method provides compatibility with PyTorch optimizer interface.
    For ES algorithms, the fitness function is provided at initialization,
    not per-step. The closure parameter is ignored for ES.

    Args:
        closure: Optional callable (ignored for ES, fitness_fn used instead)

    Returns:
        Dictionary containing training metrics
    """
    # Sample perturbations for the whole population.
    perturbations = self.sample_perturbations(self.population_size)

    # Evaluate each perturbed parameter vector by loading it into the
    # model and scoring it.
    fitnesses = []
    for i in range(self.population_size):
        perturbed_params = self.current_params + self.sigma * perturbations[i]
        self._set_flat_params(perturbed_params)
        fitnesses.append(self.evaluate_fitness(self.model))

    fitnesses_tensor = torch.tensor(
        fitnesses,
        device=self.device,
        dtype=torch.float32,
    )

    # Track the best BEFORE applying the update: best_fitness was achieved
    # by a specific *perturbed* parameter vector, so that vector is what
    # must be saved. (The original code saved the post-update mean
    # parameters, which do not correspond to best_fitness at all.)
    best_idx = int(torch.argmax(fitnesses_tensor).item())
    best_fitness = fitnesses_tensor[best_idx].item()
    if self.best_fitness is None or best_fitness > self.best_fitness:
        self.best_fitness = best_fitness
        self.best_params = (
            self.current_params + self.sigma * perturbations[best_idx]
        ).clone()

    # Compute and apply the parameter update, then sync the model.
    update = self.compute_update(perturbations, fitnesses_tensor)
    self.current_params = self.current_params + self.learning_rate * update
    self._set_flat_params(self.current_params)

    # History bookkeeping.
    self.generation += 1
    mean_fitness = fitnesses_tensor.mean().item()
    self.fitness_history.append(mean_fitness)

    return {
        "generation": self.generation,
        "mean_fitness": mean_fitness,
        "best_fitness": best_fitness,
        "std_fitness": fitnesses_tensor.std().item(),
    }
train
train(num_generations, verbose=True)

Train the model for multiple generations.

Parameters:

Name Type Description Default
num_generations int

Number of generations to train

required
verbose bool

Whether to print progress

True

Returns:

Type Description
Dict[str, Any]

Dictionary containing final training state

Source code in eggroll_trainer/base.py
def train(self, num_generations: int, verbose: bool = True) -> Dict[str, Any]:
    """
    Run step() repeatedly for a fixed number of generations.

    Args:
        num_generations: Number of generations to train
        verbose: Whether to print progress

    Returns:
        Dictionary containing final training state
    """
    for _ in range(num_generations):
        metrics = self.step()
        if not verbose:
            continue
        report = (
            f"Generation {metrics['generation']}: "
            f"Mean Fitness = {metrics['mean_fitness']:.4f}, "
            f"Best Fitness = {metrics['best_fitness']:.4f}"
        )
        print(report)

    return {
        "generation": self.generation,
        "best_fitness": self.best_fitness,
        "fitness_history": self.fitness_history,
    }
zero_grad
zero_grad(set_to_none=False)

Zero gradients (no-op for ES algorithms).

This method exists for optimizer interface compatibility. ES algorithms don't use gradients, so this is a no-op.

Parameters:

Name Type Description Default
set_to_none bool

If True, set gradients to None (ignored for ES)

False
Source code in eggroll_trainer/base.py
def zero_grad(self, set_to_none: bool = False):
    """
    No-op gradient reset, present only for optimizer-interface parity.

    Evolution strategies are gradient-free, so there is nothing to clear.

    Args:
        set_to_none: If True, set gradients to None (ignored for ES)
    """
    return None

Abstract Methods

Subclasses must implement:

sample_perturbations

def sample_perturbations(self, population_size: int) -> Tensor:
    """
    Draw one perturbation vector per population member.

    Args:
        population_size: Number of population members

    Returns:
        Tensor of shape (population_size, param_dim) containing
        perturbations for each population member
    """
    ...

compute_update

def compute_update(
    self,
    perturbations: Tensor,
    fitnesses: Tensor,
) -> Tensor:
    """
    Turn perturbations and their fitness scores into a parameter delta.

    Args:
        perturbations: Tensor of shape (population_size, param_dim)
        fitnesses: Tensor of shape (population_size,) with fitness scores

    Returns:
        Tensor of shape (param_dim,) containing parameter update
    """
    ...

Example

from eggroll_trainer import ESTrainer
import torch

class CustomESTrainer(ESTrainer):
    def sample_perturbations(self, population_size):
        # Draw unit-variance Gaussian noise only. step() already scales
        # each perturbation by self.sigma when building a candidate
        # (current_params + sigma * perturbation), so multiplying by
        # self.sigma here as well — as the original example did — would
        # apply sigma twice (effective std of sigma**2).
        param_dim = self.current_params.shape[0]
        return torch.randn(
            population_size,
            param_dim,
            device=self.device,
        )

    def compute_update(self, perturbations, fitnesses):
        # Fitness shaping: standardize the scores, then take the
        # score-weighted average of the perturbation directions.
        weights = (fitnesses - fitnesses.mean()) / (
            fitnesses.std() + 1e-8
        )
        return (weights[:, None] * perturbations).mean(dim=0)