diff --git a/README.md b/README.md
index 553b34c..1ec31f8 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@
 import torch
 from torch import nn
 from mixture_of_experts import MoE
 
-experts = MoE(
+moe = MoE(
     dim = 512,
     num_experts = 16,       # increase the experts (# parameters) of your model without increasing computation
     hidden_dim = 512 * 4,   # size of hidden dimension in each expert, defaults to 4 * dimension
@@ -30,7 +30,7 @@ experts = MoE(
 )
 
 inputs = torch.randn(4, 1024, 512)
-out, aux_loss = experts(inputs) # (4, 1024, 512), (1,)
+out, aux_loss = moe(inputs) # (4, 1024, 512), (1,)
 ```
 
 The above should suffice for a single machine, but if you want a heirarchical mixture of experts (2 levels), as used in the GShard paper, please follow the instructions below
@@ -39,13 +39,48 @@ The above should suffice for a single machine, but if you want a heirarchical mi
 ```python
 import torch
 from mixture_of_experts import HeirarchicalMoE
-experts = HeirarchicalMoE(
+moe = HeirarchicalMoE(
     dim = 512,
     num_experts = (4, 4),       # 4 gates on the first layer, then 4 experts on the second, equaling 16 experts
 )
 
 inputs = torch.randn(4, 1024, 512)
-out, aux_loss = experts(inputs) # (4, 1024, 512), (1,)
+out, aux_loss = moe(inputs) # (4, 1024, 512), (1,)
+```
+
+If you want a more sophisticated network for the experts, you can define your own and pass it into the `MoE` class as `experts`
+
+```python
+import torch
+from torch import nn
+from mixture_of_experts import MoE
+
+# a 3-layer MLP as the experts
+
+class Experts(nn.Module):
+    def __init__(self, dim, num_experts = 16):
+        super().__init__()
+        self.w1 = nn.Parameter(torch.randn(num_experts, dim, dim * 4))
+        self.w2 = nn.Parameter(torch.randn(num_experts, dim * 4, dim * 4))
+        self.w3 = nn.Parameter(torch.randn(num_experts, dim * 4, dim))
+        self.act = nn.LeakyReLU(inplace = True)
+
+    def forward(self, x):
+        hidden1 = self.act(torch.einsum('end,edh->enh', x, self.w1))
+        hidden2 = self.act(torch.einsum('end,edh->enh', hidden1, self.w2))
+        out = torch.einsum('end,edh->enh', hidden2, self.w3)
+        return out
+
+experts = Experts(512, num_experts = 16)
+
+moe = MoE(
+    dim = 512,
+    num_experts = 16,
+    experts = experts
+)
+
+inputs = torch.randn(4, 1024, 512)
+out, aux_loss = moe(inputs) # (4, 1024, 512), (1,)
 ```
 
 ## Citation
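The `aux_loss` returned alongside the output is the gating network's load-balancing loss, already scaled by `loss_coef`. A minimal sketch of how it might be folded into a training step (the optimizer, learning rate, and regression target here are illustrative assumptions, not part of the library):

```python
import torch
import torch.nn.functional as F
from mixture_of_experts import MoE

moe = MoE(dim = 512, num_experts = 16)
optimizer = torch.optim.Adam(moe.parameters(), lr = 3e-4)   # assumed optimizer and lr

inputs = torch.randn(4, 1024, 512)      # stand-in batch
targets = torch.randn(4, 1024, 512)     # stand-in regression target

out, aux_loss = moe(inputs)
loss = F.mse_loss(out, targets) + aux_loss  # aux_loss is pre-scaled by loss_coef
loss.backward()
optimizer.step()
optimizer.zero_grad()
```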
diff --git a/mixture_of_experts/mixture_of_experts.py b/mixture_of_experts/mixture_of_experts.py
index 2884f5a..0875c83 100644
--- a/mixture_of_experts/mixture_of_experts.py
+++ b/mixture_of_experts/mixture_of_experts.py
@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+from inspect import isfunction
 import torch.nn.functional as F
 
 # constants
@@ -9,6 +10,7 @@
 # helper functions
 
 def default(val, default_val):
+    default_val = default_val() if isfunction(default_val) else default_val
     return val if val is not None else default_val
 
 def top1(t):
@@ -26,7 +28,7 @@ def safe_one_hot(indexes, max_length):
     max_index = indexes.max() + 1
     return F.one_hot(indexes, max(max_index + 1, max_length))[..., :max_length]
 
-# classes
+# expert class
 
 class Experts(nn.Module):
     def __init__(self,
@@ -49,6 +51,8 @@ def forward(self, x):
         out = torch.einsum('...nh,...hd->...nd', hidden, self.w2)
         return out
 
+# gating network
+
 class Top2Gating(nn.Module):
     def __init__(
         self,
@@ -91,6 +95,9 @@ def forward(self, x, importance = None):
         raw_gates = torch.einsum('...bnd,...de->...bne', x, self.w_gating)
         raw_gates = raw_gates.softmax(dim=-1)
 
+        # FIND TOP 2 EXPERTS PER POSITION
+        # Find the top expert for each position. shape=[batch, group]
+
         gate_1, index_1 = top1(raw_gates)
         mask_1 = F.one_hot(index_1, num_gates).float()
         density_1_proxy = raw_gates
@@ -116,10 +123,16 @@ def forward(self, x, importance = None):
         gate_1 /= denom
         gate_2 /= denom
 
+        # BALANCING LOSSES
+        # shape = [batch, experts]
+        # We want to equalize the fraction of the batch assigned to each expert
         density_1 = mask_1.mean(dim=-2)
+        # Something continuous that is correlated with what we want to equalize.
         density_1_proxy = density_1_proxy.mean(dim=-2)
         loss = (density_1_proxy * density_1).mean() * float(num_gates ** 2)
 
+        # Depending on the policy in the hparams, we may drop out some of the
+        # second-place experts.
         if policy == "all":
             pass
         elif policy == "none":
@@ -131,17 +144,27 @@ def forward(self, x, importance = None):
             mask_2 *= (probs < (gate_2 / max(threshold, self.eps))).float().unsqueeze(-1)
         else:
             raise ValueError(f"Unknown policy {policy}")
-
+
+        # Each sequence sends (at most?) expert_capacity positions to each expert.
+        # Static expert_capacity dimension is needed for expert batch sizes
         expert_capacity = min(group_size, int((group_size * capacity_factor) / num_gates))
         expert_capacity = max(expert_capacity, MIN_EXPERT_CAPACITY)
         expert_capacity_f = float(expert_capacity)
 
+        # COMPUTE ASSIGNMENT TO EXPERTS
+        # [batch, group, experts]
+        # This is the position within the expert's mini-batch for this sequence
         position_in_expert_1 = cumsum_exclusive(mask_1) * mask_1
+        # Remove the elements that don't fit. [batch, group, experts]
         mask_1 *= (position_in_expert_1 < expert_capacity_f).float()
+        # [batch, experts]
+        # How many examples in this sequence go to this expert
         mask_1_count = mask_1.sum(dim=-2, keepdim=True)
+        # [batch, group] - mostly ones, but zeros where something didn't fit
         mask_1_flat = mask_1.sum(dim=-1)
-
+        # [batch, group]
         position_in_expert_1 = position_in_expert_1.sum(dim=-1)
+        # Weight assigned to first expert. [batch, group]
         gate_1 *= mask_1_flat
 
         position_in_expert_2 = cumsum_exclusive(mask_2) + mask_1_count
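For intuition, the `expert_capacity` computed in the hunk above caps how many positions each expert may accept per group. A small sketch of the arithmetic, using the README's example shapes (these concrete numbers are assumptions for illustration):

```python
MIN_EXPERT_CAPACITY = 4                # constant from the module

group_size = 1024                      # positions per group, as in the README example
capacity_factor = 1.25                 # capacity_factor_train default
num_gates = 16                         # num_experts

expert_capacity = min(group_size, int((group_size * capacity_factor) / num_gates))
expert_capacity = max(expert_capacity, MIN_EXPERT_CAPACITY)
print(expert_capacity)                 # 80 -> at most 80 positions routed to each expert
```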
@@ -152,6 +175,7 @@ def forward(self, x, importance = None):
         position_in_expert_2 = position_in_expert_2.sum(dim=-1)
         gate_2 *= mask_2_flat
 
+        # [batch, group, experts, expert_capacity]
         combine_tensor = (
             gate_1[..., None, None]
             * mask_1_flat[..., None, None]
@@ -166,6 +190,8 @@ def forward(self, x, importance = None):
         dispatch_tensor = combine_tensor.bool().to(combine_tensor)
         return dispatch_tensor, combine_tensor, loss
 
+# plain mixture of experts
+
 class MoE(nn.Module):
     def __init__(self,
         dim,
@@ -178,14 +204,15 @@ def __init__(self,
         second_threshold_eval = 0.2,
         capacity_factor_train = 1.25,
         capacity_factor_eval = 2.,
-        loss_coef = 1e-2):
+        loss_coef = 1e-2,
+        experts = None):
         super().__init__()
 
         self.num_experts = num_experts
 
         gating_kwargs = {'second_policy_train': second_policy_train, 'second_policy_eval': second_policy_eval, 'second_threshold_train': second_threshold_train, 'second_threshold_eval': second_threshold_eval, 'capacity_factor_train': capacity_factor_train, 'capacity_factor_eval': capacity_factor_eval}
         self.gate = Top2Gating(dim, num_gates = num_experts, **gating_kwargs)
-        self.experts = Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation)
+        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
         self.loss_coef = loss_coef
 
     def forward(self, inputs):
@@ -193,6 +220,7 @@ def forward(self, inputs):
         dispatch_tensor, combine_tensor, loss = self.gate(inputs)
         expert_inputs = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor)
 
+        # Now feed the expert inputs through the experts.
         orig_shape = expert_inputs.shape
         expert_inputs = expert_inputs.reshape(e, -1, d)
         expert_outputs = self.experts(expert_inputs)
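The `'bnd,bnec->ebcd'` einsum in `MoE.forward` gathers each position into its assigned expert's capacity slot. A toy sketch of that routing, with small assumed shapes:

```python
import torch

# assumed toy shapes: batch b=2, positions n=8, dim d=4, experts e=3, capacity c=5
b, n, d, e, c = 2, 8, 4, 3, 5
inputs = torch.randn(b, n, d)

# one-hot dispatch: route position 0 of batch 0 into slot 0 of expert 1
dispatch_tensor = torch.zeros(b, n, e, c)
dispatch_tensor[0, 0, 1, 0] = 1.

expert_inputs = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor)
assert torch.allclose(expert_inputs[1, 0, 0], inputs[0, 0])   # token landed in expert 1, slot 0
```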
@@ -201,6 +229,8 @@
         output = torch.einsum('ebcd,bnec->bnd', expert_outputs, combine_tensor)
         return output, loss * self.loss_coef
 
+# 2-level hierarchical mixture of experts
+
 class HeirarchicalMoE(nn.Module):
     def __init__(self,
         dim,
@@ -213,7 +243,8 @@ def __init__(self,
         second_threshold_eval = 0.2,
         capacity_factor_train = 1.25,
         capacity_factor_eval = 2.,
-        loss_coef = 1e-2):
+        loss_coef = 1e-2,
+        experts = None):
         super().__init__()
 
         assert len(num_experts) == 2, 'only 2 levels of heirarchy for experts allowed for now'
@@ -226,7 +257,7 @@ def __init__(self,
         self.gate_outer = Top2Gating(dim, num_gates = num_experts_outer, **gating_kwargs)
         self.gate_inner = Top2Gating(dim, num_gates = num_experts_inner, outer_expert_dims = (num_experts_outer,), **gating_kwargs)
 
-        self.experts = Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation)
+        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
         self.loss_coef = loss_coef
 
     def forward(self, inputs):
@@ -234,17 +265,25 @@ def forward(self, inputs):
         dispatch_tensor_outer, combine_tensor_outer, loss_outer = self.gate_outer(inputs)
         expert_inputs_outer = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor_outer)
 
+        # we construct an "importance" Tensor for the inputs to the second-level
+        # gating. The importance of an input is 1.0 if it represents the
+        # first-choice expert-group and 0.5 if it represents the second-choice expert
+        # group. This is used by the second-level gating.
         importance = combine_tensor_outer.permute(2, 0, 3, 1).sum(dim=-1)
         importance = 0.5 * ((importance > 0.5).float() + (importance > 0.).float())
 
         dispatch_tensor_inner, combine_tensor_inner, loss_inner = self.gate_inner(expert_inputs_outer, importance = importance)
         expert_inputs = torch.einsum('ebnd,ebnfc->efbcd', expert_inputs_outer, dispatch_tensor_inner)
 
+        # Now feed the expert inputs through the experts.
         orig_shape = expert_inputs.shape
         expert_inputs = expert_inputs.reshape(eo, ei, -1, d)
         expert_outputs = self.experts(expert_inputs)
         expert_outputs = expert_outputs.reshape(*orig_shape)
 
+        # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
+        # expert_output has shape [y0, x1, h, d, n]
+
         expert_outputs_outer = torch.einsum('efbcd,ebnfc->ebnd', expert_outputs, combine_tensor_inner)
         output = torch.einsum('ebcd,bnec->bnd', expert_outputs_outer, combine_tensor_outer)
         return output, (loss_outer + loss_inner) * self.loss_coef
diff --git a/setup.py b/setup.py
index b998d24..79761d4 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ setup(
   name = 'mixture-of-experts',
   packages = find_packages(),
-  version = '0.0.2',
+  version = '0.0.3',
   license='MIT',
   description = 'Sparsely-Gated Mixture of Experts for Pytorch',
   author = 'Phil Wang',
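As the `reshape(e, -1, d)` in `MoE.forward` shows, a custom `experts` module is called with a tensor of shape `(num_experts, n, dim)` and must return the same shape. A quick shape check against a hypothetical drop-in module (the class and its weights here are illustrative assumptions):

```python
import torch
from torch import nn

class ScaledExperts(nn.Module):                 # hypothetical custom experts
    def __init__(self, dim, num_experts = 16):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim) * 0.02)

    def forward(self, x):
        # x: (num_experts, n, dim) -> (num_experts, n, dim)
        return torch.einsum('end,edh->enh', x, self.w)

experts = ScaledExperts(512, num_experts = 16)
x = torch.randn(16, 80, 512)                    # 80 = the expert capacity from the sketch above
assert experts(x).shape == (16, 80, 512)
```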