API Reference
Complete API documentation for Eggroll Trainer.
Modules
- ESTrainer - Base class for Evolution Strategy trainers
- VanillaESTrainer - Vanilla ES with full-rank perturbations
- EGGROLLTrainer - EGGROLL algorithm with low-rank perturbations
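VanillaESTrainer and EGGROLLTrainer differ mainly in the shape of the noise added to each weight matrix. The plain-Python sketch below is not the library's code. It compares two kinds of perturbation:

- A full-rank m × n Gaussian perturbation, which needs m·n random draws.
- A rank-r perturbation E = A Bᵀ / √r, built from an m × r factor A and an n × r factor B, which needs only (m + n)·r draws.

The factorization and the 1/√r scaling are assumptions for illustration, and the library's exact construction may differ. A small Gaussian-elimination helper checks the resulting ranks.

```python
import math
import random


def full_rank_perturbation(m, n, rng):
    """Full-rank noise: one independent N(0, 1) draw per weight (m * n draws)."""
    return [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(m)]


def low_rank_perturbation(m, n, r, rng):
    """Rank-r noise E = A @ B.T / sqrt(r), built from (m + n) * r draws (assumed scheme)."""
    a = [[rng.gauss(0.0, 1.0) for _ in range(r)] for _ in range(m)]  # m x r factor
    b = [[rng.gauss(0.0, 1.0) for _ in range(r)] for _ in range(n)]  # n x r factor
    scale = 1.0 / math.sqrt(r)
    return [
        [scale * sum(a[i][k] * b[j][k] for k in range(r)) for j in range(n)]
        for i in range(m)
    ]


def matrix_rank(mat, tol=1e-9):
    """Numerical rank via Gaussian elimination with partial pivoting."""
    rows = [row[:] for row in mat]
    n_rows, n_cols = len(rows), len(rows[0])
    rank = 0
    for col in range(n_cols):
        if rank == n_rows:
            break
        # Use the largest remaining entry in this column as the pivot.
        pivot = max(range(rank, n_rows), key=lambda i: abs(rows[i][col]))
        if abs(rows[pivot][col]) < tol:
            continue  # column is (numerically) dependent on earlier ones
        rows[rank], rows[pivot] = rows[pivot], rows[rank]
        for i in range(rank + 1, n_rows):
            factor = rows[i][col] / rows[rank][col]
            for j in range(col, n_cols):
                rows[i][j] -= factor * rows[rank][j]
        rank += 1
    return rank


rng = random.Random(0)
m, n, r = 8, 6, 2
full = full_rank_perturbation(m, n, rng)
low = low_rank_perturbation(m, n, r, rng)
print("full-rank draws:", m * n, "rank:", matrix_rank(full))
print("low-rank draws:", (m + n) * r, "rank:", matrix_rank(low))
```

For a 4096 × 4096 weight matrix with r = 1, the count is 16,777,216 draws versus 8,192. This illustrates why low-rank noise is cheaper to sample and store.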
Quick Import
Class Hierarchy
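The module list describes ESTrainer as the base class for Evolution Strategy trainers. This diagram assumes both concrete trainers inherit from it directly. That is inferred from the module descriptions, not confirmed by the source.

```text
ESTrainer               (base class for Evolution Strategy trainers)
├── VanillaESTrainer    (full-rank perturbations)
└── EGGROLLTrainer      (low-rank perturbations)
```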
Common Patterns
Basic Usage
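```python
from eggroll_trainer import EGGROLLTrainer

# Assumes `model` (e.g. a torch.nn.Module) and `fitness_fn` are defined elsewhere.
trainer = EGGROLLTrainer(
    model.parameters(),
    model=model,
    fitness_fn=fitness_fn,
    population_size=256,  # perturbed candidates evaluated per generation
    learning_rate=0.01,   # step size of each parameter update
    sigma=0.1,            # scale of the perturbation noise
)
trainer.train(num_generations=100)
best_model = trainer.get_best_model()
```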
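The three hyperparameters above map onto the standard ES update. The plain-Python sketch below is an illustration only. It runs a basic full-rank Gaussian ES, not EGGROLL's low-rank variant, on a toy problem:

- Each generation samples `population_size` perturbations.
- `sigma` scales the noise.
- `learning_rate` scales the step along the estimated gradient.

The function `es_step`, the toy `fitness_fn`, and the choices to standardize fitness and treat higher fitness as better are assumptions. They are not EGGROLLTrainer's internals.

```python
import random


def es_step(theta, fitness_fn, population_size, learning_rate, sigma, rng):
    """One generation of a basic Gaussian ES (generic sketch, hypothetical helper)."""
    dim = len(theta)
    # Sample one Gaussian perturbation per population member.
    eps = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(population_size)]
    # Evaluate fitness at theta + sigma * eps for each member.
    scores = [fitness_fn([t + sigma * e for t, e in zip(theta, row)]) for row in eps]
    # Standardize fitness so the step size does not depend on its raw scale.
    mean = sum(scores) / population_size
    std = (sum((s - mean) ** 2 for s in scores) / population_size) ** 0.5 or 1.0
    weights = [(s - mean) / std for s in scores]
    # Gradient estimate: fitness-weighted average of the noise, divided by sigma.
    grad = [
        sum(w * row[d] for w, row in zip(weights, eps)) / (population_size * sigma)
        for d in range(dim)
    ]
    # Move uphill along the estimate (fitness is maximized).
    return [t + learning_rate * g for t, g in zip(theta, grad)]


target = [3.0, -1.0, 0.5]


def fitness_fn(params):
    # Toy fitness: negative squared distance to `target` (higher is better).
    return -sum((p - q) ** 2 for p, q in zip(params, target))


rng = random.Random(0)
theta = [0.0, 0.0, 0.0]
for _ in range(100):  # num_generations
    theta = es_step(theta, fitness_fn, population_size=256,
                    learning_rate=0.01, sigma=0.1, rng=rng)
print([round(t, 2) for t in theta])
```

Standardizing the fitness values makes each step roughly `learning_rate / sigma` in size far from the optimum. That is why these two values are usually tuned together.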
Custom ES Algorithm
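```python
from eggroll_trainer import ESTrainer


class MyESTrainer(ESTrainer):
    def sample_perturbations(self, population_size):
        # Your implementation: return `population_size` parameter perturbations.
        raise NotImplementedError

    def compute_update(self, perturbations, fitnesses):
        # Your implementation: turn the perturbations and their fitness
        # scores into a parameter update.
        raise NotImplementedError
```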
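The two hooks above are where one ES variant differs from another. The plain-Python sketch below shows one possible way to fill them in, using two common ES techniques: mirrored (antithetic) sampling and rank-based fitness shaping.

So that it runs standalone, it does not subclass the real ESTrainer. Several details are hypothetical:

- the class name `MirroredESSketch`;
- the list-of-lists perturbation format;
- the assumption that `compute_update` returns an additive parameter delta;
- the assumption that higher fitness is better.

The library may expect tensors or a different return convention.

```python
import random


class MirroredESSketch:
    """Hypothetical stand-in showing one way to implement the two ES hooks."""

    def __init__(self, dim, sigma=0.1, learning_rate=0.05, seed=0):
        self.dim = dim
        self.sigma = sigma
        self.learning_rate = learning_rate
        self.rng = random.Random(seed)

    def sample_perturbations(self, population_size):
        # Mirrored sampling: every draw eps is paired with -eps,
        # which cancels odd-order noise in the gradient estimate.
        if population_size % 2:
            raise ValueError("population_size must be even for mirrored sampling")
        half = [[self.rng.gauss(0.0, 1.0) for _ in range(self.dim)]
                for _ in range(population_size // 2)]
        return half + [[-x for x in row] for row in half]

    def compute_update(self, perturbations, fitnesses):
        # Rank-based shaping: utilities depend only on the fitness ordering,
        # spread evenly over [-0.5, 0.5], so outliers cannot dominate the step.
        n = len(fitnesses)
        order = sorted(range(n), key=lambda i: fitnesses[i])
        utilities = [0.0] * n
        for rank, i in enumerate(order):
            utilities[i] = rank / (n - 1) - 0.5
        scale = self.learning_rate / (n * self.sigma)
        return [scale * sum(u * p[d] for u, p in zip(utilities, perturbations))
                for d in range(self.dim)]


# Usage on a toy problem (higher fitness is better).
target = [1.0, -2.0]


def fitness(params):
    return -sum((p - t) ** 2 for p, t in zip(params, target))


es = MirroredESSketch(dim=2)
theta = [0.0, 0.0]
for _ in range(200):
    eps = es.sample_perturbations(64)
    scores = [fitness([t + es.sigma * e for t, e in zip(theta, row)]) for row in eps]
    update = es.compute_update(eps, scores)
    theta = [t + u for t, u in zip(theta, update)]
print([round(t, 2) for t in theta])
```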
See Also
- User Guide - Usage guide
- Examples - Code examples
- Research - Algorithm details