EGGROLLTrainer

EGGROLL (Evolution Guided General Optimization via Low-rank Learning) trainer.

EGGROLLTrainer

EGGROLLTrainer(params, model, fitness_fn, population_size=256, learning_rate=0.01, sigma=0.1, rank=1, noise_reuse=0, group_size=0, freeze_nonlora=False, device=None, seed=None)

Bases: Optimizer

EGGROLL trainer implementing the full EGGROLL algorithm.

Unlike the base ESTrainer, which works with flattened parameters, EGGROLL works per layer with low-rank perturbations for efficiency.

Key features:

  • Low-rank perturbations: for matrices W ∈ R^(m×n), samples A ∈ R^(m×r), B ∈ R^(n×r) where r << min(m, n), forming the perturbation A @ B.T
  • Per-layer updates: handles each parameter tensor independently
  • Noise reuse: can reuse noise across multiple evaluations (antithetic sampling)
  • Group normalization: supports fitness normalization within groups
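The low-rank construction can be sketched in a few lines of PyTorch. The shapes, rank, and sigma below are illustrative choices, not the trainer's internals:

```python
import torch

# Instead of sampling a full m x n noise matrix, sample two thin factors
# and form their product: a rank-r perturbation with far fewer noise values.
m, n, r, sigma = 64, 32, 2, 0.1

A = torch.randn(m, r)                 # left factor,  m x r
B = torch.randn(n, r)                 # right factor, n x r
perturbation = sigma * (A @ B.T)      # rank-r perturbation, m x n

# Only r * (m + n) noise values are stored instead of m * n.
assert perturbation.shape == (m, n)
assert int(torch.linalg.matrix_rank(perturbation)) <= r
```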

Subclasses torch.optim.Optimizer for compatibility with PyTorch optimizer interface. Use model.parameters() as the first argument, similar to standard optimizers.

Initialize the EGGROLL trainer.

Parameters:

  • params: Iterable of parameters to optimize (for optimizer compatibility). Typically model.parameters(). Required.
  • model (Module): PyTorch model to train. Parameters will be optimized. Required.
  • fitness_fn (Callable[[Module], float]): Function that takes a model and returns a fitness score (higher is better). Should handle model evaluation. Required.
  • population_size (int, default: 256): Number of population members.
  • learning_rate (float, default: 0.01): Learning rate for parameter updates.
  • sigma (float, default: 0.1): Standard deviation for perturbations.
  • rank (int, default: 1): Rank of low-rank perturbations.
  • noise_reuse (int, default: 0): Number of evaluations to reuse noise (0 = no reuse, 2 = antithetic).
  • group_size (int, default: 0): Size of groups for fitness normalization (0 = global normalization).
  • freeze_nonlora (bool, default: False): If True, only apply LoRA updates to linear layers.
  • device (Optional[torch.device], default: None): Device to run on.
  • seed (Optional[int], default: None): Random seed.
Source code in eggroll_trainer/eggroll.py
def __init__(
    self,
    params,
    model: nn.Module,
    fitness_fn: Callable[[nn.Module], float],
    population_size: int = 256,
    learning_rate: float = 0.01,
    sigma: float = 0.1,
    rank: int = 1,
    noise_reuse: int = 0,
    group_size: int = 0,
    freeze_nonlora: bool = False,
    device: Optional[torch.device] = None,
    seed: Optional[int] = None,
):
    """
    Initialize the EGGROLL trainer.

    Args:
        params: Iterable of parameters to optimize (for optimizer compatibility).
               Typically model.parameters().
        model: PyTorch model to train. Parameters will be optimized.
        fitness_fn: Function that takes a model and returns a fitness score
                    (higher is better). Should handle model evaluation.
        population_size: Number of population members
        learning_rate: Learning rate for parameter updates
        sigma: Standard deviation for perturbations
        rank: Rank of low-rank perturbations (default: 1)
        noise_reuse: Number of evaluations to reuse noise (0 = no reuse, 2 = antithetic)
        group_size: Size of groups for fitness normalization (0 = global normalization)
        freeze_nonlora: If True, only apply LoRA updates to linear layers
        device: Device to run on
        seed: Random seed
    """
    defaults = dict(
        learning_rate=learning_rate,
        population_size=population_size,
        sigma=sigma,
        rank=rank,
        noise_reuse=noise_reuse,
        group_size=group_size,
        freeze_nonlora=freeze_nonlora,
    )
    super().__init__(params, defaults)

    self.model = model
    self.fitness_fn = fitness_fn
    self.population_size = population_size
    self.learning_rate = learning_rate
    self.sigma = sigma
    self.rank = rank
    self.noise_reuse = noise_reuse
    self.group_size = group_size
    self.freeze_nonlora = freeze_nonlora

    if device is None:
        self.device = next(model.parameters()).device
    else:
        self.device = device
        self.model = self.model.to(self.device)

    if seed is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # Extract parameter structure
    self.param_names = []
    self.param_shapes = []
    self.param_dims = []
    self.is_matrix = []  # True for 2D tensors (can use LoRA)

    for name, param in self.model.named_parameters():
        if param.requires_grad:
            self.param_names.append(name)
            self.param_shapes.append(param.shape)
            self.param_dims.append(param.numel())
            # Matrices (2D tensors) can use low-rank updates
            self.is_matrix.append(len(param.shape) == 2)

    # Initialize optimizer state (we'll use SGD-style updates)
    self.optimizer_state = {}

    # Training state
    self.generation = 0
    self.fitness_history: List[float] = []
    self.best_fitness: Optional[float] = None
    self.best_state_dict: Optional[Dict] = None

    # PRNG state for noise generation
    self.rng = torch.Generator(device=self.device)
    if seed is not None:
        self.rng.manual_seed(seed)

Functions

get_best_model
get_best_model()

Get a copy of the model with the best parameters found.

Returns:

  • Module: New model instance with the best parameters found.

Source code in eggroll_trainer/eggroll.py
def get_best_model(self) -> nn.Module:
    """
    Get a copy of the model with the best parameters found.

    Returns:
        New model instance with best parameters
    """
    best_model = copy.deepcopy(self.model)
    if self.best_state_dict is not None:
        best_model.load_state_dict(self.best_state_dict)
    return best_model
step
step(closure=None)

Perform one optimization step.

This method provides compatibility with PyTorch optimizer interface. For ES algorithms, the fitness function is provided at initialization, not per-step. The closure parameter is ignored for ES.

Parameters:

  • closure (default: None): Optional callable (ignored for ES; fitness_fn is used instead).

Returns:

  • Dictionary with training metrics.

Source code in eggroll_trainer/eggroll.py
def step(self, closure=None):
    """
    Perform one optimization step.

    This method provides compatibility with PyTorch optimizer interface.
    For ES algorithms, the fitness function is provided at initialization,
    not per-step. The closure parameter is ignored for ES.

    Args:
        closure: Optional callable (ignored for ES, fitness_fn used instead)

    Returns:
        Dictionary with training metrics
    """
    # Evaluate population
    raw_scores = []

    # Save original state
    original_state_dict = {name: param.data.clone() for name, param in self.model.named_parameters()}

    for thread_id in range(self.population_size):
        # Apply perturbations
        perturbed_params = self._apply_perturbations(thread_id, self.generation)

        # Set model parameters
        for name, value in perturbed_params.items():
            param_ref = dict(self.model.named_parameters())[name]
            param_ref.data.copy_(value)

        # Evaluate fitness
        fitness = self.fitness_fn(self.model)
        raw_scores.append(fitness)

    # Restore original parameters
    for name, value in original_state_dict.items():
        param_ref = dict(self.model.named_parameters())[name]
        param_ref.data.copy_(value)

    raw_scores_tensor = torch.tensor(
        raw_scores,
        device=self.device,
        dtype=torch.float32,
    )

    # Convert to normalized fitnesses
    fitnesses = self._convert_fitnesses(raw_scores_tensor)

    # Compute updates for each parameter
    updates = {}
    for name, param in self.model.named_parameters():
        if not param.requires_grad:
            continue

        idx = self.param_names.index(name)
        is_mat = self.is_matrix[idx]

        if is_mat:
            update = self._compute_lora_update(param, fitnesses, self.generation)
        else:
            update = self._compute_full_update(param, fitnesses, self.generation)

        updates[name] = update

    # Apply updates
    for name, param in self.model.named_parameters():
        if name in updates:
            # Scale by learning rate and population size
            scale = self.learning_rate * (self.population_size ** 0.5)
            param.data = param.data + scale * updates[name]

    # Track best fitness
    best_fitness = raw_scores_tensor.max().item()
    if self.best_fitness is None or best_fitness > self.best_fitness:
        self.best_fitness = best_fitness
        self.best_state_dict = copy.deepcopy(self.model.state_dict())

    # Update history
    self.generation += 1
    mean_fitness = raw_scores_tensor.mean().item()
    self.fitness_history.append(mean_fitness)

    return {
        "generation": self.generation,
        "mean_fitness": mean_fitness,
        "best_fitness": best_fitness,
        "std_fitness": raw_scores_tensor.std().item(),
    }
train
train(num_generations, verbose=True)

Train for multiple generations.

Parameters:

  • num_generations (int): Number of generations to train. Required.
  • verbose (bool, default: True): Whether to print progress.

Returns:

  • Dict[str, Any]: Dictionary with the final training state.

Source code in eggroll_trainer/eggroll.py
def train(self, num_generations: int, verbose: bool = True) -> Dict[str, Any]:
    """
    Train for multiple generations.

    Args:
        num_generations: Number of generations to train
        verbose: Whether to print progress

    Returns:
        Dictionary with final training state
    """
    for gen in range(num_generations):
        metrics = self.step()

        if verbose:
            print(
                f"Generation {metrics['generation']}: "
                f"Mean Fitness = {metrics['mean_fitness']:.4f}, "
                f"Best Fitness = {metrics['best_fitness']:.4f}"
            )

    return {
        "generation": self.generation,
        "best_fitness": self.best_fitness,
        "fitness_history": self.fitness_history,
    }
zero_grad
zero_grad(set_to_none=False)

Zero gradients (no-op for ES algorithms).

This method exists for optimizer interface compatibility. ES algorithms don't use gradients, so this is a no-op.

Parameters:

  • set_to_none (bool, default: False): If True, set gradients to None (ignored for ES).
Source code in eggroll_trainer/eggroll.py
def zero_grad(self, set_to_none: bool = False):
    """
    Zero gradients (no-op for ES algorithms).

    This method exists for optimizer interface compatibility.
    ES algorithms don't use gradients, so this is a no-op.

    Args:
        set_to_none: If True, set gradients to None (ignored for ES)
    """
    pass

Usage

from eggroll_trainer import EGGROLLTrainer

trainer = EGGROLLTrainer(
    model.parameters(),
    model=model,
    fitness_fn=fitness_fn,
    population_size=256,
    learning_rate=0.01,
    sigma=0.1,
    rank=1,
    noise_reuse=0,
    group_size=0,
    freeze_nonlora=False,
    seed=42,
)

trainer.train(num_generations=100)

Key Parameters

rank (int, default: 1)

Rank of low-rank perturbations. Controls memory/computation tradeoff:

  • rank=1: Minimum memory, fastest (recommended)
  • rank=2-4: Better expressivity, still efficient
  • rank>>1: Approaches full-rank (not recommended)
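The memory side of this tradeoff is easy to quantify. A rough comparison for one perturbation of a large square weight matrix (the sizes here are hypothetical): full-rank noise needs one value per weight, while a rank-r perturbation needs only the two thin factors.

```python
# Full-rank noise for an m x n matrix stores m * n values; a rank-r
# perturbation stores only the factors A (m x r) and B (n x r).
m, n = 4096, 4096
full_rank = m * n                       # 16,777,216 values
for r in (1, 2, 4):
    low_rank = r * (m + n)              # 8,192 values at rank 1
    print(f"rank={r}: {full_rank // low_rank}x fewer noise values")
```

At rank 1 this is a 2048x reduction for a 4096×4096 matrix, which is why rank=1 is the recommended default.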

noise_reuse (int, default: 0)

Number of evaluations to reuse noise:

  • 0: No reuse (standard)
  • 2: Antithetic sampling (use +ε and -ε)
  • >2: Multiple reuses (rarely needed)
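Antithetic sampling can be sketched as follows (a minimal illustration of the idea, not the trainer's internals): half the population is perturbed with +ε and the other half with -ε, so each noise draw is evaluated twice with opposite signs.

```python
import torch

# One row of noise per antithetic pair; the second half of the
# population reuses the same noise with the sign flipped.
population_size = 8
eps = torch.randn(population_size // 2, 5)
noise = torch.cat([eps, -eps], dim=0)   # [+e1..+e4, -e1..-e4]

assert noise.shape[0] == population_size
assert torch.equal(noise[:4], -noise[4:])
```

Pairing +ε with -ε cancels the zeroth-order term in the fitness difference, which typically reduces the variance of the gradient estimate.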

group_size (int, default: 0)

Size of groups for fitness normalization:

  • 0: Global normalization (all population members)
  • >0: Group-based normalization (can improve stability)
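Group-based normalization can be sketched like this (shapes and the epsilon are illustrative, not the trainer's internals): raw scores are standardized within each group of `group_size` members rather than across the whole population.

```python
import torch

# Standardize raw fitness scores within groups of 4 instead of globally.
raw_scores = torch.randn(16)
group_size = 4
groups = raw_scores.view(-1, group_size)
normalized = (groups - groups.mean(dim=1, keepdim=True)) / (
    groups.std(dim=1, keepdim=True) + 1e-8
)
fitnesses = normalized.flatten()

assert fitnesses.shape == raw_scores.shape
# Each group now has (approximately) zero mean.
assert torch.allclose(normalized.mean(dim=1), torch.zeros(4), atol=1e-5)
```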

freeze_nonlora (bool, default: False)

If True, only apply LoRA updates to 2D parameters (matrices):

  • False: Update all parameters (recommended)
  • True: Only update matrix parameters (biases frozen)
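Which parameters count as "matrix" parameters follows the same 2D check used in `__init__` above (`len(param.shape) == 2`): linear weights qualify for low-rank updates, while 1D biases do not. The model below is an illustrative example:

```python
import torch.nn as nn

# Split a model's parameters the way the trainer does: 2D tensors can
# receive low-rank (LoRA-style) updates; everything else cannot.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
matrix_params = [name for name, p in model.named_parameters() if p.dim() == 2]
other_params = [name for name, p in model.named_parameters() if p.dim() != 2]

print(matrix_params)  # ['0.weight', '2.weight']
print(other_params)   # ['0.bias', '2.bias']
```

With freeze_nonlora=True, only the first group would be updated and the biases would stay at their initial values.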

Characteristics

  • ✅ Up to ~100x speedup over full-rank ES for large models
  • ✅ Memory efficient
  • ✅ Handles large population sizes
  • ✅ Per-layer updates
  • ✅ Supports fitness normalization

See Also