VanillaESTrainer

Vanilla Evolution Strategy trainer with full-rank Gaussian perturbations.

VanillaESTrainer

VanillaESTrainer(params, model, fitness_fn, population_size=50, learning_rate=0.01, sigma=0.1, device=None, seed=None)

Bases: ESTrainer

Vanilla Evolution Strategy trainer using Gaussian perturbations and fitness-weighted updates.

This is a basic implementation that can be used as a reference for creating more sophisticated ES algorithms.

Source code in eggroll_trainer/base.py
def __init__(
    self,
    params,
    model: nn.Module,
    fitness_fn: Callable[[nn.Module], float],
    population_size: int = 50,
    learning_rate: float = 0.01,
    sigma: float = 0.1,
    device: Optional[torch.device] = None,
    seed: Optional[int] = None,
):
    """
    Initialize the ES trainer.

    Args:
        params: Iterable of parameters to optimize (for optimizer compatibility).
               Typically model.parameters().
        model: PyTorch model to train. Parameters will be optimized.
        fitness_fn: Function that takes a model and returns a fitness score
                    (higher is better). Should handle model evaluation.
        population_size: Number of perturbed models to evaluate per generation
        learning_rate: Learning rate for parameter updates
        sigma: Standard deviation for parameter perturbations
        device: Device to run training on (defaults to model's device)
        seed: Random seed for reproducibility
    """
    defaults = dict(
        learning_rate=learning_rate,
        population_size=population_size,
        sigma=sigma,
    )
    super().__init__(params, defaults)

    self.model = model
    self.fitness_fn = fitness_fn
    self.population_size = population_size
    self.learning_rate = learning_rate
    self.sigma = sigma

    if device is None:
        self.device = next(model.parameters()).device
    else:
        self.device = device
        self.model = self.model.to(self.device)

    if seed is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # Extract initial parameters
    self.initial_params = self._get_flat_params()
    self.current_params = self.initial_params.clone()

    # Training history
    self.generation = 0
    self.fitness_history: List[float] = []
    self.best_fitness: Optional[float] = None
    self.best_params: Optional[Tensor] = None
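
The constructor, defined in eggroll_trainer/base.py, snapshots the model's weights as one flat vector through `_get_flat_params`. That helper's body is not shown on this page, but `sample_perturbations` reads `param_dim` from `self.current_params.shape[0]`, so the snapshot is a 1-D vector. The sketch below assumes the helper concatenates every parameter tensor in a fixed order. It uses NumPy arrays in place of tensors, and the names `flatten` and `unflatten` are hypothetical. In PyTorch, the same round trip is available as `torch.nn.utils.parameters_to_vector` and `torch.nn.utils.vector_to_parameters`.

```python
import numpy as np

# Hypothetical helpers: NumPy stand-ins for a flatten/unflatten round trip,
# assuming parameters are concatenated in a fixed order.

def flatten(params: list[np.ndarray]) -> np.ndarray:
    """Concatenate all parameter arrays into one 1-D vector."""
    return np.concatenate([p.ravel() for p in params])

def unflatten(flat: np.ndarray, shapes: list[tuple[int, ...]]) -> list[np.ndarray]:
    """Split a flat vector back into arrays with the given shapes."""
    out, offset = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        out.append(flat[offset:offset + size].reshape(shape))
        offset += size
    return out

# Example: a linear layer with a (3, 4) weight and a (3,) bias gives 15 values.
weight, bias = np.ones((3, 4)), np.zeros(3)
flat = flatten([weight, bias])
restored = unflatten(flat, [weight.shape, bias.shape])
```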

Functions

compute_update
compute_update(perturbations, fitnesses)

Compute fitness-weighted update.
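
The source below computes this update for a population of N perturbations ε_1, ..., ε_N with fitnesses F_1, ..., F_N as a fitness-standardized average of the noise. Here s_F is the sample standard deviation computed by `Tensor.std()`, which applies Bessel's correction by default. This page does not show how the base class turns u into a parameter step. A plain step θ ← θ + η·u with η = learning_rate is an assumption.

```latex
\bar F = \frac{1}{N}\sum_{i=1}^{N} F_i, \qquad
s_F = \sqrt{\frac{1}{N-1}\sum_{i=1}^{N}\left(F_i - \bar F\right)^2}

w_i = \frac{F_i - \bar F}{s_F + 10^{-8}}, \qquad
u = \frac{1}{N}\sum_{i=1}^{N} w_i\,\varepsilon_i
```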

Parameters:

  • perturbations (Tensor): tensor of shape (population_size, param_dim). Required.
  • fitnesses (Tensor): tensor of shape (population_size,) with fitness scores. Required.

Returns:

  • Tensor: tensor of shape (param_dim,) with the parameter update.

Source code in eggroll_trainer/vanilla.py
def compute_update(
    self,
    perturbations: Tensor,
    fitnesses: Tensor,
) -> Tensor:
    """
    Compute fitness-weighted update.

    Args:
        perturbations: Tensor of shape (population_size, param_dim)
        fitnesses: Tensor of shape (population_size,) with fitness scores

    Returns:
        Tensor of shape (param_dim,) with parameter update
    """
    # Center fitnesses around zero so above-average members get positive weight
    normalized_fitnesses = fitnesses - fitnesses.mean()

    # Scale to unit standard deviation, then average the perturbations:
    # members with above-average fitness pull the update toward their noise
    # direction, and below-average members push it away
    weights = normalized_fitnesses / (normalized_fitnesses.std() + 1e-8)
    update = (weights[:, None] * perturbations).mean(dim=0)

    return update
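
A NumPy transcription of `compute_update` makes its main property easy to check. Fitnesses are centered and divided by their standard deviation. As a result, adding a constant to every fitness, or multiplying all of them by a positive factor, leaves the update unchanged. Only the relative spread of the scores matters. `ddof=1` matches the default of `Tensor.std()`. The function below is an illustrative mirror, not the library's code.

```python
import numpy as np

def compute_update(perturbations: np.ndarray, fitnesses: np.ndarray) -> np.ndarray:
    """NumPy mirror of VanillaESTrainer.compute_update (illustrative only)."""
    centered = fitnesses - fitnesses.mean()
    # ddof=1 matches torch.Tensor.std(), which applies Bessel's correction by default.
    weights = centered / (centered.std(ddof=1) + 1e-8)
    # Average of the noise vectors, each weighted by its standardized fitness.
    return (weights[:, None] * perturbations).mean(axis=0)

rng = np.random.default_rng(0)
eps = rng.standard_normal((50, 10))    # (population_size, param_dim)
fit = rng.standard_normal(50)          # one score per population member

u = compute_update(eps, fit)
u_shifted = compute_update(eps, fit + 100.0)   # same scores plus a constant
u_scaled = compute_update(eps, 3.0 * fit)      # same scores times a positive factor
```

Because the weights are standardized, the magnitude of u does not depend on the scale of the fitness values returned by `fitness_fn`.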

sample_perturbations
sample_perturbations(population_size)

Sample Gaussian perturbations.

Parameters:

  • population_size (int): number of perturbations to sample. Required.

Returns:

  • Tensor: tensor of shape (population_size, param_dim) with Gaussian noise.

Source code in eggroll_trainer/vanilla.py
def sample_perturbations(self, population_size: int) -> Tensor:
    """
    Sample Gaussian perturbations.

    Args:
        population_size: Number of perturbations to sample

    Returns:
        Tensor of shape (population_size, param_dim) with Gaussian noise
    """
    param_dim = self.current_params.shape[0]
    return torch.randn(
        population_size,
        param_dim,
        device=self.device,
        dtype=self.current_params.dtype,
    )
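
`sample_perturbations` returns only unit-variance noise from `torch.randn`. `sigma` is applied elsewhere in the base class, which this page does not show. The sketch below assumes the common convention that member i is evaluated at θ + σ·ε_i. Under that convention, building the whole population of candidate parameter vectors is a single broadcast. NumPy stands in for PyTorch, and `make_candidates` is a hypothetical name.

```python
import numpy as np

def make_candidates(theta: np.ndarray, eps: np.ndarray, sigma: float) -> np.ndarray:
    """Return one perturbed parameter vector per row (assumed: theta + sigma * eps)."""
    # theta: (param_dim,), eps: (population_size, param_dim); broadcast over rows.
    return theta[None, :] + sigma * eps

rng = np.random.default_rng(42)
theta = np.zeros(8)                      # current flat parameters
eps = rng.standard_normal((50, 8))       # what sample_perturbations(50) would return
candidates = make_candidates(theta, eps, sigma=0.1)
```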

Usage

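import torch
import torch.nn as nn

from eggroll_trainer import VanillaESTrainer

# Toy setup for illustration only: a small regression model and a fitness
# function returning negative MSE (higher is better, as the trainer expects).
model = nn.Linear(4, 1)
x, y = torch.randn(64, 4), torch.randn(64, 1)

def fitness_fn(m: nn.Module) -> float:
    with torch.no_grad():
        return -nn.functional.mse_loss(m(x), y).item()

trainer = VanillaESTrainer(
    model.parameters(),
    model=model,
    fitness_fn=fitness_fn,
    population_size=50,
    learning_rate=0.01,
    sigma=0.1,
    seed=42,
)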

trainer.train(num_generations=100)
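
`train(num_generations=100)` presumably repeats a sample, evaluate, update cycle once per generation. The self-contained NumPy sketch below runs that cycle on a toy objective, so its behavior can be inspected without a model. It assumes that member i is scored at θ + σ·ε_i and that the step is θ ← θ + learning_rate·u. This page shows neither detail. `es_generation` and the quadratic `fitness` are hypothetical.

```python
import numpy as np

def fitness(theta: np.ndarray) -> float:
    # Hypothetical toy objective: maximized (value 0) at theta == 1 everywhere.
    return -float(np.sum((theta - 1.0) ** 2))

def es_generation(theta, rng, population_size=50, learning_rate=0.05, sigma=0.1):
    """One assumed ES generation: sample, evaluate, standardize, step."""
    eps = rng.standard_normal((population_size, theta.size))
    fits = np.array([fitness(theta + sigma * e) for e in eps])
    centered = fits - fits.mean()
    weights = centered / (centered.std(ddof=1) + 1e-8)
    update = (weights[:, None] * eps).mean(axis=0)
    return theta + learning_rate * update

rng = np.random.default_rng(42)
theta = np.zeros(5)
start = fitness(theta)                   # -5.0 at the origin
for _ in range(200):
    theta = es_generation(theta, rng)
end = fitness(theta)
```

Because the weights are standardized, each generation moves θ by roughly `learning_rate` along the estimated ascent direction, regardless of how steep the objective is.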

When to Use

  • Small models (< 10K parameters)
  • Understanding ES basics
  • Baseline comparisons

Characteristics

  • ✅ Simple and straightforward
  • ✅ Good for small models
  • ❌ Memory intensive for large models (each generation allocates a dense population_size × param_dim noise tensor)
  • ❌ Slower than EGGROLL for matrix parameters
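
The memory cost noted above follows directly from `sample_perturbations`, which allocates a dense population_size × param_dim tensor every generation. The back-of-the-envelope helper below shows how quickly that tensor grows. It is hypothetical, counts only the noise tensor, and assumes 4-byte float32 entries.

```python
def noise_tensor_bytes(population_size: int, param_dim: int, bytes_per_elem: int = 4) -> int:
    """Size in bytes of the dense (population_size, param_dim) perturbation tensor."""
    return population_size * param_dim * bytes_per_elem

small = noise_tensor_bytes(50, 10_000)        # 2_000_000 bytes, about 2 MB
large = noise_tensor_bytes(50, 10_000_000)    # 2_000_000_000 bytes, about 2 GB
```

With the default population_size=50, a 10M-parameter model already needs about 2 GB for the noise alone.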

See Also