Avoid log(0) in KL divergence #12233
…merator and denominator and added a test case
If y_true is 0, then what should we return?
I left comments on your PR. Basically, we can use only the entries where y_true is nonzero, without adding epsilon.
y_true is an array rather than a single number here, so we can still use the remaining entries.
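A minimal sketch of the approach suggested above: drop the zero entries of y_true with a boolean mask instead of adding an epsilon. The function name kl_divergence_masked is hypothetical, and this is not necessarily the code that was merged in the PR.

```python
import numpy as np


def kl_divergence_masked(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """KL divergence D(y_true || y_pred), skipping entries where y_true == 0.

    Hypothetical helper for illustration only. Terms with y_true == 0 are
    dropped (the convention 0 * log(0 / q) = 0) instead of being smoothed
    with an epsilon.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.shape != y_pred.shape:
        raise ValueError("Input arrays must have the same shape.")

    # Keep only the entries that actually contribute to the sum.
    mask = y_true != 0
    p = y_true[mask]
    q = y_pred[mask]
    return float(np.sum(p * np.log(p / q)))


print(kl_divergence_masked(np.array([0.0, 0.5, 0.5]), np.array([0.2, 0.4, 0.4])))  # ≈ 0.2231, i.e. log(1.25)
```

An entry where y_pred is 0 but y_true is positive still yields inf (with a NumPy divide-by-zero warning), which matches the mathematical definition. SciPy's scipy.special.rel_entr applies the same element-wise convention if a library helper is preferred.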
Is my solution correct?
I think it's correct, and it also handles the case where y_pred is 0. Great job.
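For reference, the usual information-theory conventions for the edge cases of the KL sum are below, writing p_i for y_true[i] and q_i for y_pred[i]. Dropping the entries with p_i = 0 is exactly the convention 0 log(0/q) = 0, while a zero q_i under a positive p_i makes the divergence infinite. These are the textbook conventions, not necessarily how the PR treats y_pred = 0:

$$
D_{\mathrm{KL}}(P \,\|\, Q) = \sum_{i \,:\, p_i > 0} p_i \log \frac{p_i}{q_i},
\qquad
0 \log \frac{0}{q} = 0 \ (q \ge 0),
\qquad
p \log \frac{p}{0} = +\infty \ (p > 0).
$$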
Thanks, sure.
Repository commit
03a4251
Python version (python --version)
Python 3.10.15
Dependencies version (pip freeze)
absl-py==2.1.0
astunparse==1.6.3
beautifulsoup4==4.12.3
certifi==2024.8.30
charset-normalizer==3.4.0
contourpy==1.3.0
cycler==0.12.1
dill==0.3.9
dom_toml==2.0.0
domdf-python-tools==3.9.0
fake-useragent==1.5.1
flatbuffers==24.3.25
fonttools==4.54.1
gast==0.6.0
google-pasta==0.2.0
grpcio==1.67.0
h5py==3.12.1
idna==3.10
imageio==2.36.0
joblib==1.4.2
keras==3.6.0
kiwisolver==1.4.7
libclang==18.1.1
lxml==5.3.0
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.9.2
mdurl==0.1.2
ml-dtypes==0.3.2
mpmath==1.3.0
namex==0.0.8
natsort==8.4.0
numpy==1.26.4
oauthlib==3.2.2
opencv-python==4.10.0.84
opt_einsum==3.4.0
optree==0.13.0
packaging==24.1
pandas==2.2.3
patsy==0.5.6
pbr==6.1.0
pillow==11.0.0
pip==24.2
protobuf==4.25.5
psutil==6.1.0
Pygments==2.18.0
pyparsing==3.2.0
python-dateutil==2.9.0.post0
pytz==2024.2
qiskit==1.2.4
qiskit-aer==0.15.1
requests==2.32.3
requests-oauthlib==1.3.1
rich==13.9.2
rustworkx==0.15.1
scikit-learn==1.5.2
scipy==1.14.1
setuptools==74.1.2
six==1.16.0
soupsieve==2.6
sphinx-pyproject==0.3.0
statsmodels==0.14.4
stevedore==5.3.0
symengine==0.13.0
sympy==1.13.3
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.2
tensorflow-io-gcs-filesystem==0.37.1
termcolor==2.5.0
threadpoolctl==3.5.0
tomli==2.0.2
tweepy==4.14.0
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
Werkzeug==3.0.4
wheel==0.44.0
wrapt==1.16.0
xgboost==2.1.1
Expected behavior
The entries where `y_true` is `0` should be ignored in the summation (see Actual behavior).
Actual behavior
In Python/machine_learning/loss_functions.py, lines 662 to 663 in 03a4251, when `y_true` is `0` the output of `np.log` becomes `-inf`. Since `0 * -inf` is `nan` in floating point, the method returns `nan`.
Maybe it would be better to exclude those entries where `y_true` is `0`?
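A small sketch that reproduces the reported `nan`. It assumes the referenced lines compute the element-wise term `y_true * np.log(y_true / y_pred)` and sum it; the embedded code snippet did not survive in this copy of the issue, and the sample arrays are made up for illustration.

```python
import numpy as np

# Assumed form of the naive computation: y_true * log(y_true / y_pred), summed.
y_true = np.array([0.0, 0.5, 0.5])
y_pred = np.array([0.2, 0.4, 0.4])

# Silence the divide-by-zero (log(0)) and invalid (0 * -inf) warnings.
with np.errstate(divide="ignore", invalid="ignore"):
    terms = y_true * np.log(y_true / y_pred)

print(terms)          # first entry is nan: 0 * log(0) -> 0 * -inf -> nan
print(np.sum(terms))  # nan, because a single nan poisons the whole sum
```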