-
Hi, I'm trying to train a neural net with two sub-networks: a simple MLP with one hidden layer, and a projection layer (a linear layer). I'm optimising it with respect to a custom loss function, using `mlx.nn.value_and_grad` to get a function which returns the loss and gradients. When I call this function I get the following error:

```
ValueError: Primitive's vjp not implemented.
```

I'm not sure where to start looking to debug this; can anyone suggest where might be best to look first? The reason for the unusual network structure and loss is that I'm doing self-supervised learning and trying to learn a representation over two variables.

My model is as follows:

```python
class MLP(nn.Module):
    def __init__(
        self,
        num_layers: int,
        g_input_dim: int,
        g_hidden_dim: int,
        g_output_dim: int,
        f_input_dim: int,
        f_output_dim: int,
    ):
        super().__init__()
        layer_sizes = [g_input_dim] + [g_hidden_dim] * num_layers + [g_output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        self.projection = nn.Linear(f_input_dim, f_output_dim)

    def __call__(self, X, t, y, neg):
        input = mx.concatenate([X, t], axis=1)
        for l in self.layers[:-1]:
            input = mx.maximum(l(input), 0.0)
        output = self.layers[-1](input)
        toProj_ys = mx.concatenate([output, y], axis=1)
        output_ys = self.projection(toProj_ys)
        output_neg = mx.zeros((neg.shape[0], output_ys.shape[0], output_ys.shape[1]))
        for i in range(neg.shape[0]):
            iter = mx.array([neg[i].tolist() for _ in range(output.shape[0])])
            toProj_neg = mx.concatenate([output, iter], axis=1)
            output_neg[i] = self.projection(toProj_neg)
        return output_ys, output_neg
```

And my loss:

```python
def loss(model, x, t, y, neg):
    y_score, neg_score = model(x, t, y, neg)
    # Check below
    sum_neg_score = mx.sum(neg_score, axis=0)
    loss = -mx.logsumexp(y_score / sum_neg_score)
    return mx.mean(loss)
```

There are perhaps some things I may need to change to get decent performance, but I don't see anything glaring that would cause this error.
Replies: 3 comments
-
From looking at it, I think it is this line:

```python
output_neg[i] = self.projection(toProj_neg)
```

I believe this is implemented with a `Scatter`, which does not have a vjp yet. You can work around that by using `concatenate` for now.
-
Let me know if that's not the issue and I will reopen this.
-
That fixed it! Thank you very much.