add ability to pass in custom experts
lucidrains committed Jul 17, 2020
1 parent 0849693 commit 3c27fa1
Showing 3 changed files with 86 additions and 12 deletions.
43 changes: 39 additions & 4 deletions README.md
@@ -15,7 +15,7 @@ import torch
from torch import nn
from mixture_of_experts import MoE

experts = MoE(
moe = MoE(
dim = 512,
num_experts = 16, # increase the experts (# parameters) of your model without increasing computation
hidden_dim = 512 * 4, # size of hidden dimension in each expert, defaults to 4 * dimension
@@ -30,7 +30,7 @@ experts = MoE(
)

inputs = torch.randn(4, 1024, 512)
out, aux_loss = experts(inputs) # (4, 1024, 512), (1,)
out, aux_loss = moe(inputs) # (4, 1024, 512), (1,)
```

The above should suffice for a single machine, but if you want a hierarchical mixture of experts (2 levels), as used in the GShard paper, follow the instructions below
@@ -39,13 +39,48 @@ The above should suffice for a single machine, but if you want a heirarchical mi
import torch
from mixture_of_experts import HeirarchicalMoE

experts = HeirarchicalMoE(
moe = HeirarchicalMoE(
dim = 512,
num_experts = (4, 4), # 4 gates on the first layer, then 4 experts on the second, equaling 16 experts
)

inputs = torch.randn(4, 1024, 512)
out, aux_loss = experts(inputs) # (4, 1024, 512), (1,)
out, aux_loss = moe(inputs) # (4, 1024, 512), (1,)
```

If you want a more sophisticated network for the experts, you can define your own and pass it into the `MoE` class as `experts`

```python
import torch
from torch import nn
from mixture_of_experts import MoE

# a 3-layer MLP as the experts

class Experts(nn.Module):
def __init__(self, dim, num_experts = 16):
super().__init__()
self.w1 = nn.Parameter(torch.randn(num_experts, dim, dim * 4))
self.w2 = nn.Parameter(torch.randn(num_experts, dim * 4, dim * 4))
self.w3 = nn.Parameter(torch.randn(num_experts, dim * 4, dim))
self.act = nn.LeakyReLU(inplace = True)

def forward(self, x):
hidden1 = self.act(torch.einsum('end,edh->enh', x, self.w1))
hidden2 = self.act(torch.einsum('end,edh->enh', hidden1, self.w2))
out = torch.einsum('end,edh->enh', hidden2, self.w3)
return out

experts = Experts(512, num_experts = 16)

moe = MoE(
dim = 512,
num_experts = 16,
experts = experts
)

inputs = torch.randn(4, 1024, 512)
out, aux_loss = moe(inputs) # (4, 1024, 512), (1,)
```
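
Whatever you use for the experts, the returned `aux_loss` (already scaled by `loss_coef`) is meant to be added to your main objective. Below is a minimal training-step sketch; the optimizer, targets, and mean-squared-error objective are illustrative stand-ins, not part of this library.

```python
import torch
import torch.nn.functional as F
from mixture_of_experts import MoE

moe = MoE(dim = 512, num_experts = 16)
optimizer = torch.optim.Adam(moe.parameters(), lr = 3e-4)

inputs = torch.randn(4, 1024, 512)
targets = torch.randn(4, 1024, 512)   # stand-in target, purely for illustration

out, aux_loss = moe(inputs)
loss = F.mse_loss(out, targets) + aux_loss   # auxiliary balancing loss joins the task loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```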

## Citation
53 changes: 46 additions & 7 deletions mixture_of_experts/mixture_of_experts.py
@@ -1,5 +1,6 @@
import torch
from torch import nn
from inspect import isfunction
import torch.nn.functional as F

# constants
@@ -9,6 +10,7 @@
# helper functions

def default(val, default_val):
default_val = default_val() if isfunction(default_val) else default_val
return val if val is not None else default_val
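# the fallback may also be passed as a lambda or plain function, e.g. one that constructs the default Experts module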

def top1(t):
@@ -26,7 +28,7 @@ def safe_one_hot(indexes, max_length):
max_index = indexes.max() + 1
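# padding the one-hot width to at least max_length means indexes >= max_length never error; after the slice below, those positions simply become all-zero rows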
return F.one_hot(indexes, max(max_index + 1, max_length))[..., :max_length]

# classes
# expert class

class Experts(nn.Module):
def __init__(self,
@@ -49,6 +51,8 @@ def forward(self, x):
out = torch.einsum('...nh,...hd->...nd', hidden, self.w2)
return out

# gating network

class Top2Gating(nn.Module):
def __init__(
self,
@@ -91,6 +95,9 @@ def forward(self, x, importance = None):
raw_gates = torch.einsum('...bnd,...de->...bne', x, self.w_gating)
raw_gates = raw_gates.softmax(dim=-1)

# FIND TOP 2 EXPERTS PER POSITION
# Find the top expert for each position. shape=[batch, group]

gate_1, index_1 = top1(raw_gates)
mask_1 = F.one_hot(index_1, num_gates).float()
density_1_proxy = raw_gates
@@ -116,10 +123,16 @@ def forward(self, x, importance = None):
gate_1 /= denom
gate_2 /= denom

# BALANCING LOSSES
# shape = [batch, experts]
# We want to equalize the fraction of the batch assigned to each expert
density_1 = mask_1.mean(dim=-2)
# Something continuous that is correlated with what we want to equalize.
density_1_proxy = density_1_proxy.mean(dim=-2)
loss = (density_1_proxy * density_1).mean() * float(num_gates ** 2)
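# with perfectly uniform routing both densities are ~1 / num_gates per expert, so the loss is ~1; it grows toward num_gates as routing collapses onto a single expert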

# Depending on the policy in the hparams, we may drop out some of the
# second-place experts.
if policy == "all":
pass
elif policy == "none":
@@ -131,17 +144,27 @@ def forward(self, x, importance = None):
mask_2 *= (probs < (gate_2 / max(threshold, self.eps))).float().unsqueeze(-1)
else:
raise ValueError(f"Unknown policy {policy}")


# Each sequence sends at most expert_capacity positions to each expert.
# Static expert_capacity dimension is needed for expert batch sizes
expert_capacity = min(group_size, int((group_size * capacity_factor) / num_gates))
expert_capacity = max(expert_capacity, MIN_EXPERT_CAPACITY)
expert_capacity_f = float(expert_capacity)
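# e.g. with group_size = 1024, capacity_factor = 1.25 and 16 gates:
# expert_capacity = max(min(1024, int(1024 * 1.25 / 16)), MIN_EXPERT_CAPACITY) = 80, assuming MIN_EXPERT_CAPACITY <= 80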

# COMPUTE ASSIGNMENT TO EXPERTS
# [batch, group, experts]
# This is the position within the expert's mini-batch for this sequence
position_in_expert_1 = cumsum_exclusive(mask_1) * mask_1
# Remove the elements that don't fit. [batch, group, experts]
mask_1 *= (position_in_expert_1 < expert_capacity_f).float()
# [batch, experts]
# How many examples in this sequence go to this expert
mask_1_count = mask_1.sum(dim=-2, keepdim=True)
# [batch, group] - mostly ones, but zeros where something didn't fit
mask_1_flat = mask_1.sum(dim=-1)

# [batch, group]
position_in_expert_1 = position_in_expert_1.sum(dim=-1)
# Weight assigned to first expert. [batch, group]
gate_1 *= mask_1_flat

position_in_expert_2 = cumsum_exclusive(mask_2) + mask_1_count
@@ -152,6 +175,7 @@ def forward(self, x, importance = None):
position_in_expert_2 = position_in_expert_2.sum(dim=-1)
gate_2 *= mask_2_flat

# [batch, group, experts, expert_capacity]
combine_tensor = (
gate_1[..., None, None]
* mask_1_flat[..., None, None]
@@ -166,6 +190,8 @@ def forward(self, x, importance = None):
dispatch_tensor = combine_tensor.bool().to(combine_tensor)
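# dispatch_tensor is the 0 / 1 version of combine_tensor: it routes tokens to expert slots, while combine_tensor carries the gate weights used to recombine the expert outputs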
return dispatch_tensor, combine_tensor, loss

# plain mixture of experts

class MoE(nn.Module):
def __init__(self,
dim,
@@ -178,21 +204,23 @@ def __init__(self,
second_threshold_eval = 0.2,
capacity_factor_train = 1.25,
capacity_factor_eval = 2.,
loss_coef = 1e-2):
loss_coef = 1e-2,
experts = None):
super().__init__()

self.num_experts = num_experts

gating_kwargs = {'second_policy_train': second_policy_train, 'second_policy_eval': second_policy_eval, 'second_threshold_train': second_threshold_train, 'second_threshold_eval': second_threshold_eval, 'capacity_factor_train': capacity_factor_train, 'capacity_factor_eval': capacity_factor_eval}
self.gate = Top2Gating(dim, num_gates = num_experts, **gating_kwargs)
self.experts = Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation)
self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
self.loss_coef = loss_coef

def forward(self, inputs):
b, n, d, e = *inputs.shape, self.num_experts
dispatch_tensor, combine_tensor, loss = self.gate(inputs)
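# gather tokens into per-expert buffers: [batch, seq, dim] x [batch, seq, experts, capacity] -> [experts, batch, capacity, dim]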
expert_inputs = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor)

# Now feed the expert inputs through the experts.
orig_shape = expert_inputs.shape
expert_inputs = expert_inputs.reshape(e, -1, d)
expert_outputs = self.experts(expert_inputs)
@@ -201,6 +229,8 @@ def forward(self, inputs):
output = torch.einsum('ebcd,bnec->bnd', expert_outputs, combine_tensor)
return output, loss * self.loss_coef

# 2-level hierarchical mixture of experts

class HeirarchicalMoE(nn.Module):
def __init__(self,
dim,
@@ -213,7 +243,8 @@ def __init__(self,
second_threshold_eval = 0.2,
capacity_factor_train = 1.25,
capacity_factor_eval = 2.,
loss_coef = 1e-2):
loss_coef = 1e-2,
experts = None):
super().__init__()

assert len(num_experts) == 2, 'only 2 levels of heirarchy for experts allowed for now'
@@ -226,25 +257,33 @@ def __init__(self,
self.gate_outer = Top2Gating(dim, num_gates = num_experts_outer, **gating_kwargs)
self.gate_inner = Top2Gating(dim, num_gates = num_experts_inner, outer_expert_dims = (num_experts_outer,), **gating_kwargs)

self.experts = Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation)
self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
self.loss_coef = loss_coef

def forward(self, inputs):
b, n, d, eo, ei = *inputs.shape, self.num_experts_outer, self.num_experts_inner
dispatch_tensor_outer, combine_tensor_outer, loss_outer = self.gate_outer(inputs)
expert_inputs_outer = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor_outer)

# we construct an "importance" Tensor for the inputs to the second-level
# gating. The importance of an input is 1.0 if it represents the
# first-choice expert-group and 0.5 if it represents the second-choice expert
# group. This is used by the second-level gating.
importance = combine_tensor_outer.permute(2, 0, 3, 1).sum(dim=-1)
importance = 0.5 * ((importance > 0.5).float() + (importance > 0.).float())

dispatch_tensor_inner, combine_tensor_inner, loss_inner = self.gate_inner(expert_inputs_outer, importance = importance)
expert_inputs = torch.einsum('ebnd,ebnfc->efbcd', expert_inputs_outer, dispatch_tensor_inner)

# Now feed the expert inputs through the experts.
orig_shape = expert_inputs.shape
expert_inputs = expert_inputs.reshape(eo, ei, -1, d)
expert_outputs = self.experts(expert_inputs)
expert_outputs = expert_outputs.reshape(*orig_shape)

# NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
# expert_outputs has shape [outer experts, inner experts, batch, inner expert capacity, dim]

expert_outputs_outer = torch.einsum('efbcd,ebnfc->ebnd', expert_outputs, combine_tensor_inner)
output = torch.einsum('ebcd,bnec->bnd', expert_outputs_outer, combine_tensor_outer)
return output, (loss_outer + loss_inner) * self.loss_coef
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'mixture-of-experts',
packages = find_packages(),
version = '0.0.2',
version = '0.0.3',
license='MIT',
description = 'Sparsely-Gated Mixture of Experts for Pytorch',
author = 'Phil Wang',
