diff --git a/Project.toml b/Project.toml index 617f6e56c..2830de211 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.56" +version = "0.10.57" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/chainrules.jl b/src/chainrules.jl index 3b52860dd..eebdf95b5 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -150,8 +150,11 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) return error( "Pullback on AbstractVector{<:AbstractVector}.\n" * - "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * - "To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`", + "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" * + "or because some external computation has acted on `ColVecs` to produce a vector of vectors." * + "In the former case, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`." * + "In the latter case, one needs to track down the `rrule` whose pullback returns a `Vector{Vector{T}}`," * + " rather than a `Tangent`, as the cotangent / gradient for `ColVecs` input, and circumvent it." ) end return ColVecs(X), ColVecs_pullback @@ -162,8 +165,9 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) return error( "Pullback on AbstractVector{<:AbstractVector}.\n" * - "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * - "To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`", + "This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" * + "or because some external computation has acted on `RowVecs` to produce a vector of vectors." * + "If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`", ) end return RowVecs(X), RowVecs_pullback