
API Reference

Complete API documentation for Eggroll Trainer.

Modules

Quick Import

from eggroll_trainer import ESTrainer, VanillaESTrainer, EGGROLLTrainer

Class Hierarchy

ESTrainer (abstract)
├── VanillaESTrainer
└── EGGROLLTrainer

Common Patterns

Basic Usage

from eggroll_trainer import EGGROLLTrainer

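# `model` (a PyTorch-style module exposing .parameters()) and `fitness_fn`
# (your scoring callable) are assumed to be defined before this point.
trainer = EGGROLLTrainer(
    model.parameters(),    # parameters to optimize
    model=model,
    fitness_fn=fitness_fn,
    population_size=256,   # candidates evaluated per generation
    learning_rate=0.01,    # step size of each parameter update
    sigma=0.1,             # scale of the parameter perturbations
)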

trainer.train(num_generations=100)
best_model = trainer.get_best_model()
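In a generic evolution-strategies (ES) loop, parameters with these names usually mean the following: `population_size` is how many perturbed candidates are scored per generation, `sigma` scales the Gaussian noise, and `learning_rate` scales the update. The sketch below assumes those meanings. It is a hypothetical, library-free illustration of the standard Gaussian ES update (in the style of OpenAI-ES), not the code `EGGROLLTrainer` actually runs; `toy_fitness` and `es_train` are made-up names, and higher fitness is assumed to be better.

```python
import random


def toy_fitness(params):
    # Hypothetical objective: highest (0.0) when every parameter equals 3.0.
    return -sum((p - 3.0) ** 2 for p in params)


def es_train(params, fitness_fn, population_size=256, learning_rate=0.01,
             sigma=0.1, num_generations=100, seed=0):
    """Hypothetical generic Gaussian ES loop over a flat list of floats."""
    rng = random.Random(seed)
    params = list(params)
    dim = len(params)
    for _ in range(num_generations):
        # One Gaussian noise vector eps per population member.
        noise = [[rng.gauss(0.0, 1.0) for _ in range(dim)]
                 for _ in range(population_size)]
        # Score each perturbed candidate params + sigma * eps.
        fitnesses = [fitness_fn([p + sigma * e for p, e in zip(params, eps)])
                     for eps in noise]
        # Standardize fitnesses so the update does not depend on their raw scale.
        mean = sum(fitnesses) / population_size
        var = sum((f - mean) ** 2 for f in fitnesses) / population_size
        std = var ** 0.5 if var > 0 else 1.0
        scores = [(f - mean) / std for f in fitnesses]
        # Estimated ascent direction: score-weighted average of the noise, over sigma.
        for i in range(dim):
            grad_i = sum(s * eps[i] for s, eps in zip(scores, noise))
            params[i] += learning_rate * grad_i / (population_size * sigma)
    return params


best = es_train([0.0, 0.0], toy_fitness, num_generations=300)
```

Because fitnesses are standardized in this sketch, each step moves roughly `learning_rate / sigma` along the estimated ascent direction, which is one reason the two values are usually tuned together.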

Custom ES Algorithm

from eggroll_trainer import ESTrainer

class MyESTrainer(ESTrainer):
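    def sample_perturbations(self, population_size):
        # Your implementation: return one perturbation per population member.
        raise NotImplementedError

    def compute_update(self, perturbations, fitnesses):
        # Your implementation: turn the perturbations and their fitness
        # scores into a parameter update.
        raise NotImplementedError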
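One possible way to fill in the two hooks is vanilla Gaussian ES with mirrored (antithetic) sampling and rank-based fitness shaping, two common choices swapped in here for illustration. The sketch is hypothetical: it is a standalone class rather than an `ESTrainer` subclass, because the base class's constructor and the exact types these hooks must accept and return are not documented on this page. It assumes parameters are a flat list of floats, that `sample_perturbations` returns one noise vector per population member, and that `compute_update` returns the delta to add to the parameters. The names `SketchVanillaES`, `run_generation`, and `target_fitness` are illustrative only.

```python
import random


class SketchVanillaES:
    """Hypothetical stand-in for an ESTrainer subclass, using flat float lists."""

    def __init__(self, params, sigma=0.1, learning_rate=0.01, seed=0):
        self.params = list(params)
        self.sigma = sigma
        self.learning_rate = learning_rate
        self.rng = random.Random(seed)

    def sample_perturbations(self, population_size):
        # Mirrored sampling: each noise vector eps is paired with -eps.
        # Assumes population_size is even.
        half = [[self.rng.gauss(0.0, 1.0) for _ in self.params]
                for _ in range(population_size // 2)]
        return half + [[-e for e in eps] for eps in half]

    def compute_update(self, perturbations, fitnesses):
        # Rank-based shaping: map fitness ranks linearly onto [-0.5, 0.5].
        # Assumes at least two population members.
        n = len(fitnesses)
        order = sorted(range(n), key=lambda k: fitnesses[k])
        shaped = [0.0] * n
        for rank, k in enumerate(order):
            shaped[k] = rank / (n - 1) - 0.5
        # Shaped-weight average of the noise, scaled by learning_rate / sigma.
        scale = self.learning_rate / (n * self.sigma)
        return [scale * sum(w * eps[i] for w, eps in zip(shaped, perturbations))
                for i in range(len(self.params))]


def target_fitness(params):
    # Hypothetical objective with its maximum at params == [1.0, 1.0, 1.0].
    return -sum((p - 1.0) ** 2 for p in params)


def run_generation(trainer, fitness_fn, population_size=64):
    # Score each perturbed candidate, then apply the returned update.
    perturbations = trainer.sample_perturbations(population_size)
    fitnesses = [fitness_fn([p + trainer.sigma * e
                             for p, e in zip(trainer.params, eps)])
                 for eps in perturbations]
    delta = trainer.compute_update(perturbations, fitnesses)
    trainer.params = [p + d for p, d in zip(trainer.params, delta)]


trainer = SketchVanillaES([0.0, 0.0, 0.0], sigma=0.1, learning_rate=0.05)
for _ in range(200):
    run_generation(trainer, target_fitness)
```

Rank-based shaping makes the update invariant to any monotonic rescaling of the fitness function, at the cost of discarding how much better one candidate scored than another.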

See Also