From 807e3f9f52a20d2f5d5688f14bfca8f9c4157e2a Mon Sep 17 00:00:00 2001
From: KGrewal1 <45569241+KGrewal1@users.noreply.github.com>
Date: Mon, 23 Oct 2023 20:23:45 +0100
Subject: [PATCH] derivative for GELU (#1160)

* derivative for GELU

* add tests
---
 candle-core/src/backprop.rs     | 10 +++++++++-
 candle-core/tests/grad_tests.rs | 13 +++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index dfad5f6284..7488d93979 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -471,7 +471,15 @@ impl Tensor {
                     Op::Unary(_, UnaryOp::Round) => {
                         Err(Error::BackwardNotSupported { op: "round" })?
                     }
-                    Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+                    Op::Unary(arg, UnaryOp::Gelu) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        let cube = arg.powf(3.)?;
+                        let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
+                        let gelu_grad = (((0.5 * &tanh)?
+                            + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
+                            + 0.5)?;
+                        *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
+                    }
                     Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
                     Op::Unary(_, UnaryOp::GeluErf) => {
                         Err(Error::BackwardNotSupported { op: "gelu-erf" })?
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 2a70cfc4e8..bcfe639fc7 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -192,6 +192,19 @@ fn unary_grad(device: &Device) -> Result<()> {
         test_utils::to_vec1_round(grad_x, 2)?,
         [0.01, 0.42, 0.0, 0.98],
     );
+
+    // testing compared to pytorch nn.GELU(approximate = 'tanh')
+    let y = x.gelu()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [2.9964, 0.8412, 3.9999, 0.0839]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [1.0116, 1.0830, 1.0003, 0.6188],
+    );
     Ok(())
 }
 
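Note on the constants: candle's `gelu` is the tanh approximation of GELU, and the magic numbers in the new match arm are its derivative with the sqrt(2/pi) factors pre-multiplied. For reference, the derivation:

```latex
% gelu(x) = 0.5 x (1 + tanh u), with u the usual tanh-approximation argument.
\[
\operatorname{gelu}(x) = \tfrac{1}{2}\,x\,\bigl(1 + \tanh u\bigr),
\qquad
u = \sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)
  = 0.797885\,x + 0.0356774\,x^{3}.
\]
\[
\operatorname{gelu}'(x)
 = \tfrac{1}{2}\bigl(1 + \tanh u\bigr)
 + \tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2} u\bigr)\,u',
\qquad
u' = 0.797885 + 0.107032\,x^{2},
\]
so that
\[
\operatorname{gelu}'(x)
 = \tfrac{1}{2} + \tfrac{1}{2}\tanh u
 + \bigl(0.398942\,x + 0.0535161\,x^{3}\bigr)\bigl(1 - \tanh^{2} u\bigr),
\]
```

which is exactly the expression `gelu_grad` builds term by term, with `tanh` caching `tanh(u)` and `cube` caching `x^3`.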
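Beyond the PyTorch comparison baked into the test, the analytic gradient can be cross-checked against a central finite difference. A minimal standalone sketch, assuming `candle-core` as a dependency and reusing the test's inputs (this program is illustrative, not part of the patch):

```rust
use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Same inputs as the new test case in grad_tests.rs.
    let x = Var::new(&[3f32, 1., 4., 0.15], &device)?;

    // Analytic gradient via the new backward rule for UnaryOp::Gelu.
    let y = x.gelu()?;
    let grads = y.backward()?;
    let grad_x = grads.get(&x).expect("no grad for x");

    // Central finite difference: (gelu(x + h) - gelu(x - h)) / (2h).
    let h = 1e-3;
    let fd = (((x.as_tensor() + h)?.gelu()? - (x.as_tensor() - h)?.gelu()?)? / (2. * h))?;

    println!("analytic: {grad_x}");
    println!("numeric:  {fd}");
    Ok(())
}
```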
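With `h = 1e-3` the two printouts should agree to a few decimal places in f32, landing near the `[1.0116, 1.0830, 1.0003, 0.6188]` asserted in `grad_tests.rs`.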