Quick Start
Get started with Eggroll Trainer in just a few minutes!
Minimal Example
Here's a complete example that trains a simple model:
```python
import torch
import torch.nn as nn
from eggroll_trainer import EGGROLLTrainer

# 1. Define your model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# 2. Define a fitness function (higher is better)
def fitness_fn(model):
    # Example: evaluate on a fresh batch of random inputs
    x = torch.randn(32, 10)
    y_pred = model(x)
    # Toy objective: maximize the mean absolute output
    return y_pred.abs().mean().item()

# 3. Create the trainer
model = SimpleModel()
trainer = EGGROLLTrainer(
    model.parameters(),
    model=model,
    fitness_fn=fitness_fn,
    population_size=256,
    learning_rate=0.01,
    sigma=0.1,
    rank=1,
    seed=42,
)

# 4. Train!
trainer.train(num_generations=100)

# 5. Get the best model
best_model = trainer.get_best_model()
```
What You Can Build
Eggroll Trainer is also well suited to reinforcement learning in 3D environments.
Understanding the Code
Model
Any PyTorch nn.Module works! The trainer will optimize all trainable parameters.
Fitness Function
The fitness function takes a model and returns a scalar score; higher is better. To minimize a loss, return its negation (`-loss`).
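```python
def fitness_fn(model):
    # Evaluate the model on your task (compute_loss is a placeholder for your own loss code)
    loss = compute_loss(model)
    return -float(loss)  # Negate so that lower loss means higher fitness
```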
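Because every member of the population is scored by the same function, it can help to evaluate all candidates on one fixed batch (so differences in fitness come from the perturbations rather than from the data) and to skip gradient tracking, since only forward passes are needed. The sketch below wraps both ideas in a small factory; `make_fitness_fn`, the MSE objective, and the random batch are illustrative assumptions, not part of the Eggroll Trainer API.

```python
import torch
import torch.nn as nn

def make_fitness_fn(x, y):
    """Return a one-argument fitness_fn(model) bound to a fixed batch (x, y)."""
    def fitness_fn(model):
        with torch.no_grad():  # forward passes only; no autograd graph is built
            loss = nn.functional.mse_loss(model(x), y)
        return -loss.item()  # negate: lower loss means higher fitness
    return fitness_fn

# Example usage with a hypothetical regression batch
torch.manual_seed(0)
x, y = torch.randn(32, 10), torch.randn(32, 1)
fitness_fn = make_fitness_fn(x, y)
model = nn.Linear(10, 1)
score = fitness_fn(model)
```

The returned function has the one-argument signature the trainer's `fitness_fn` expects, so it can be passed in the same way as in the minimal example.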
Trainer Parameters
- population_size: Number of perturbed models evaluated per generation. Larger populations give a less noisy update but make each generation slower.
- learning_rate: Step size for parameter updates.
- sigma: Standard deviation of the perturbations.
- rank: Rank of the low-rank perturbations (1 is often sufficient).
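To make the roles of these four parameters concrete, here is a minimal sketch of one generation of a low-rank evolution-strategies update applied to a single weight matrix, written in plain PyTorch. It is not EGGROLLTrainer's actual implementation: `low_rank_es_step` is a hypothetical helper, and the details (Gaussian factors scaled by `1/sqrt(rank)`, z-scored fitness, no antithetic sampling) are assumptions chosen for clarity.

```python
import torch

def low_rank_es_step(W, fitness, population_size=256, sigma=0.1,
                     learning_rate=0.01, rank=1, generator=None):
    """Hypothetical sketch: one ES update of matrix W using rank-`rank` perturbations."""
    m, n = W.shape
    # Sample low-rank factors; E_i = A_i @ B_i^T / sqrt(rank) has unit-variance entries
    A = torch.randn(population_size, m, rank, generator=generator)
    B = torch.randn(population_size, n, rank, generator=generator)
    E = A @ B.transpose(1, 2) / rank ** 0.5  # shape (population_size, m, n)
    # Evaluate every perturbed candidate W + sigma * E_i
    scores = torch.tensor([fitness(W + sigma * E_i) for E_i in E])
    # Standardize fitness so the step size does not depend on the fitness scale
    scores = (scores - scores.mean()) / (scores.std() + 1e-8)
    # Fitness-weighted average of perturbations: a search-gradient ascent direction
    grad_estimate = (scores.view(-1, 1, 1) * E).mean(dim=0) / sigma
    return W + learning_rate * grad_estimate

# Toy usage: pull W toward a random target matrix
g = torch.Generator().manual_seed(0)
target = torch.randn(4, 4, generator=g)

def fitness(W):
    return -((W - target) ** 2).sum().item()  # higher = closer to target

W = torch.zeros(4, 4)
start = (W - target).norm().item()
for _ in range(100):
    W = low_rank_es_step(W, fitness, generator=g)
end = (W - target).norm().item()
print(f"distance to target: {start:.3f} -> {end:.3f}")
```

In this sketch only the `m x r` and `n x r` factors are sampled per population member, which is why a low rank keeps each perturbation cheap; `sigma` sets how far candidates are pushed away from `W`, and `learning_rate` scales the combined step.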
Running the Example
Save the code above to `quick_start.py` and run it with `python quick_start.py`.
You should see output like:
```text
Training EGGROLLTrainer...
Generation 0: Mean fitness = 0.1234
Generation 10: Mean fitness = 0.2345
...
Generation 100: Mean fitness = 0.4567
Training complete!
```
Next Steps
- Learn about Core Concepts
- See more Examples
- Read the User Guide
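Common Patterns
Classification Task
```python
import torch

def classification_fitness(model, data_loader):
    # Pass to the trainer with data_loader bound, e.g.
    # functools.partial(classification_fitness, data_loader=loader)
    correct = 0
    total = 0
    with torch.no_grad():  # forward passes only; no gradients are needed
        for x, y in data_loader:
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += len(y)
    accuracy = correct / total
    return accuracy  # Higher is better
```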
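Regression Task
```python
import torch
import torch.nn as nn

def regression_fitness(model, data_loader):
    total_loss = 0.0
    count = 0
    with torch.no_grad():  # forward passes only; no gradients are needed
        for x, y in data_loader:
            y_pred = model(x)
            loss = nn.functional.mse_loss(y_pred, y)
            total_loss += loss.item()
            count += 1
    avg_loss = total_loss / count  # mean of the per-batch losses
    return -avg_loss  # Convert loss to fitness
```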
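Parameter Matching
```python
import torch

def parameter_fitness(model, target_params):
    # target_params: 1-D tensor with one entry per parameter element, in model.parameters() order
    current_params = torch.cat([p.flatten() for p in model.parameters()])
    distance = (current_params - target_params).norm()
    return -distance.item()  # Minimize distance
```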
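`parameter_fitness` expects `target_params` as a single flat tensor in the same order as `model.parameters()`. One way to build it, sketched here with a hypothetical reference network, is PyTorch's `torch.nn.utils.parameters_to_vector`, which concatenates the flattened parameters in that order (equivalent to the `torch.cat` expression above). A lambda then binds the target so the trainer sees a one-argument function.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

def make_net():
    return nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))

torch.manual_seed(0)
reference = make_net()  # hypothetical network whose weights we want to match
candidate = make_net()  # independently initialized network being trained

# Flatten the reference weights; detach so the target carries no autograd history
target_params = parameters_to_vector(reference.parameters()).detach()

def parameter_fitness(model, target_params):
    current_params = parameters_to_vector(model.parameters())
    return -(current_params - target_params).norm().item()

# Bind the target so the result has the fitness_fn(model) signature
fitness_fn = lambda model: parameter_fitness(model, target_params)
```

With this setup, `fitness_fn(reference)` is zero (the maximum possible fitness) and any other set of weights scores below it.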
Tips
- Start with `population_size=256`: EGGROLL is efficient with large populations.
- Use `rank=1`: it is often sufficient and is the fastest setting.
- Tune `sigma`: start with 0.1 and adjust based on your problem scale.
- Monitor fitness: use `trainer.history` to track progress.
Getting Help
- Check the API Reference
- See Examples
- Open an issue on GitHub