eggroll-trainer
Python implementation of EGGROLL (Evolution Guided General Optimization via Low-rank Learning) for scalable evolution strategies training
I’ve been interested in evolution strategies for a while now—there’s something really appealing about black-box optimization that doesn’t need gradients. But the problem is that traditional ES gets prohibitively expensive when you’re dealing with large models. You end up needing (O(mn)) memory and computation per layer just to generate perturbations (E \in \mathbb{R}^{m \times n}), and that batched matrix multiplication kills you.
So when I saw the EGGROLL poast come out, I got pretty excited. The idea is simple but clever: instead of generating full-rank perturbations, you generate two smaller matrices (A \in \mathbb{R}^{m \times r}) and (B \in \mathbb{R}^{n \times r}) where (r \ll \min(m, n)), and then form your perturbation as (AB^\top). This drops your memory from (O(mn)) down to (O(r(m+n))) and your computation from (O(mn)) to (O(r(m+n))).
The cool part is that even though each perturbation is low-rank, when you average across a population of (N) workers, you still get high-rank updates. The math says it converges to the full-rank update at (O\left(\frac{1}{r}\right)) rate, which is pretty fast.
Here’s how standard ES works (OpenAI’s formulation):
[ \nabla_{\theta}\mathbb{E}{\epsilon\sim N(0,I)} F(\theta+\sigma\epsilon) = \frac{1}{\sigma}\mathbb{E}{\epsilon\sim N(0,I)}{F(\theta+\sigma\epsilon)\epsilon} ]
where (F) is your fitness function, (\theta) are the parameters you’re optimizing, (\sigma) is the noise scale, and (\epsilon) gets sampled independently for each parameter. The problem is that for a matrix layer with parameters (W \in \mathbb{R}^{m \times n}), you’re sampling a full (m \times n) matrix of noise.
With EGGROLL, you sample those two smaller matrices (A) and (B), and your forward pass becomes:
[ x \cdot W^\top + x \cdot B \cdot A^\top \cdot \sigma ]
This is way more efficient to batch, and according to the paper, you get something like a 100x speedup for billion-parameter models. It’s basically as fast as pure batch inference (less than 10% slower), which is wild.
I decided to implement this in Python because I wanted to actually use it. The library is pretty straightforward—you define your model, write a fitness function, and let it rip.
Here’s a simple example from the repo:
from eggroll_trainer import EGGROLLTrainer
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self, input_dim=10, hidden_dim=20, output_dim=1):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, output_dim)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
def create_fitness_fn(target_params=None, noise_std=0.1):
"""Fitness function that rewards models close to target parameters."""
if target_params is None:
model = SimpleModel()
target_params = {}
for name, param in model.named_parameters():
target_params[name] = param.data.clone()
def fitness_fn(model):
total_distance = 0.0
for name, param in model.named_parameters():
if name in target_params:
distance = torch.norm(param - target_params[name])
total_distance += distance.item()
noise = torch.randn(1).item() * noise_std
return -total_distance + noise
return fitness_fn
# Create model and fitness function
model = SimpleModel()
fitness_fn = create_fitness_fn()
# Initialize EGGROLL trainer
trainer = EGGROLLTrainer(
model=model,
fitness_fn=fitness_fn,
population_size=32,
learning_rate=0.01,
sigma=0.1,
rank=1, # Low-rank rank
)
# Train!
trainer.train(num_generations=1000)
For something more practical, there’s an MNIST example that trains a CNN classifier:
from eggroll_trainer import EGGROLLTrainer
from torchvision import datasets, transforms
import torch.nn.functional as F
class MNISTNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 64 * 7 * 7)
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
def create_fitness_fn(train_loader, device, batch_limit=20):
"""Fitness based on accuracy."""
cached_data = []
cached_targets = []
for batch_idx, (data, target) in enumerate(train_loader):
if batch_idx >= batch_limit:
break
cached_data.append(data.to(device))
cached_targets.append(target.to(device))
def fitness_fn(model):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for data, target in zip(cached_data, cached_targets):
output = model(data)
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total += target.size(0)
return correct / total if total > 0 else 0.0
return fitness_fn
# Load data
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
model = MNISTNet()
fitness_fn = create_fitness_fn(train_loader, device='cuda')
trainer = EGGROLLTrainer(
model=model,
fitness_fn=fitness_fn,
population_size=32,
rank=1,
learning_rate=0.01,
sigma=0.1
)
trainer.train(num_generations=100)
The thing I really like about this is that it basically eliminates the barrier between inference and training. If you can do batched inference and you can define a fitness function, you can optimize your system. No gradients needed, no differentiable objectives required. That opens up a lot of possibilities—non-differentiable components, RL scenarios, weird architectures that backprop doesn’t play nice with.
I’ve got more examples in the repo including comparisons with standard ES and a full comparison framework. Check out the original paper/website for the full details—they did some really cool experiments including training pure integer RNNs with no activation functions, which is only feasible because of how fast EGGROLL is.
The code is pretty modular—there’s a base class, the core EGGROLL implementation, and a simplified API if you just want to get started quickly. Contributions welcome if you want to add features or examples!
EGGROLL Algorithm: Given a neural network with parameters \(\theta\), fitness function \(F\), population size \(N\), rank \(r\), and perturbation scale \(\sigma\):
- Initialization: Initialize base parameters \(\theta_0\)
- For each generation \(t\):
- For each population member \(i \in \{1, \ldots, N\}\):
- For each matrix layer \(W \in \mathbb{R}^{m \times n}\):
- Sample \(A_i \in \mathbb{R}^{m \times r}\), \(B_i \in \mathbb{R}^{n \times r}\) from \(\mathcal{N}(0, I)\)
- Form low-rank perturbation: \(\Delta W_i = A_i B_i^\top\)
- Compute perturbed parameters: \(\theta_i = \theta_t + \sigma \cdot \Delta_i\)
- Evaluate fitness: \(f_i = F(\theta_i)\)
- Update: \(\theta_{t+1} = \theta_t + \alpha \cdot \frac{1}{N\sigma} \sum_{i=1}^N f_i \cdot \Delta_i\)
Complexity Analysis:
- Memory: \(O(r(m+n))\) per layer vs. \(O(mn)\) for full-rank ES
- Computation: \(O(r(m+n))\) per forward pass vs. \(O(mn)\) for full-rank ES
- Convergence: Low-rank updates converge to full-rank at rate \(O(1/r)\)
Key Properties:
- Maintains high-rank updates through population averaging despite low-rank perturbations
- Scales linearly with rank \(r\) rather than quadratically with layer dimensions
- Enables efficient batched computation compatible with modern GPU architectures
- No gradient computation required—pure black-box optimization