-
-
Notifications
You must be signed in to change notification settings - Fork 622
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Distributed ndcg #3054
base: master
Are you sure you want to change the base?
Distributed ndcg #3054
Changes from 6 commits
3714ebf
d673031
a34fe2c
a993929
3ce5a5d
6e0850b
3964133
3920b2e
6a4bc7f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from ignite.metrics.recsys.ndcg import NDCG | ||
|
||
__all__ = [ | ||
"NDCG", | ||
] |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,116 @@ | ||||
from typing import Callable, Optional, Sequence, Union | ||||
|
||||
import torch | ||||
|
||||
from ignite.exceptions import NotComputableError | ||||
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce | ||||
|
||||
__all__ = ["NDCG"] | ||||
|
||||
|
||||
def _tie_averaged_dcg( | ||||
y_pred: torch.Tensor, | ||||
y_true: torch.Tensor, | ||||
discount_cumsum: torch.Tensor, | ||||
device: Union[str, torch.device] = torch.device("cpu"), | ||||
) -> torch.Tensor: | ||||
_, inv, counts = torch.unique(-y_pred, return_inverse=True, return_counts=True) | ||||
ranked = torch.zeros(counts.shape[0]).to(device) | ||||
ranked.index_put_([inv], y_true, accumulate=True) | ||||
ranked /= counts | ||||
groups = torch.cumsum(counts, dim=-1) - 1 | ||||
discount_sums = torch.empty(counts.shape[0]).to(device) | ||||
discount_sums[0] = discount_cumsum[groups[0]] | ||||
discount_sums[1:] = torch.diff(discount_cumsum[groups]) | ||||
|
||||
return torch.sum(torch.mul(ranked, discount_sums)) | ||||
|
||||
|
||||
def _dcg_sample_scores( | ||||
y_pred: torch.Tensor, | ||||
y_true: torch.Tensor, | ||||
k: Optional[int] = None, | ||||
log_base: Union[int, float] = 2, | ||||
ignore_ties: bool = False, | ||||
device: Union[str, torch.device] = torch.device("cpu"), | ||||
) -> torch.Tensor: | ||||
discount = torch.log(torch.tensor(log_base)) / torch.log(torch.arange(y_true.shape[1]) + 2) | ||||
discount = discount.to(device) | ||||
|
||||
if k is not None: | ||||
discount[k:] = 0.0 | ||||
|
||||
if ignore_ties: | ||||
ranking = torch.argsort(y_pred, descending=True) | ||||
ranked = y_true[torch.arange(ranking.shape[0]).reshape(-1, 1), ranking].to(device) | ||||
discounted_gains = torch.mm(ranked, discount.reshape(-1, 1)) | ||||
|
||||
else: | ||||
discount_cumsum = torch.cumsum(discount, dim=-1) | ||||
discounted_gains = torch.tensor( | ||||
[_tie_averaged_dcg(y_p, y_t, discount_cumsum, device) for y_p, y_t in zip(y_pred, y_true)], device=device | ||||
) | ||||
|
||||
return discounted_gains | ||||
|
||||
|
||||
def _ndcg_sample_scores( | ||||
y_pred: torch.Tensor, | ||||
y_true: torch.Tensor, | ||||
k: Optional[int] = None, | ||||
log_base: Union[int, float] = 2, | ||||
ignore_ties: bool = False, | ||||
) -> torch.Tensor: | ||||
device = y_true.device | ||||
gain = _dcg_sample_scores(y_pred, y_true, k=k, log_base=log_base, ignore_ties=ignore_ties, device=device) | ||||
if not ignore_ties: | ||||
gain = gain.unsqueeze(dim=-1) | ||||
normalizing_gain = _dcg_sample_scores(y_true, y_true, k=k, log_base=log_base, ignore_ties=True, device=device) | ||||
all_relevant = normalizing_gain != 0 | ||||
normalized_gain = gain[all_relevant] / normalizing_gain[all_relevant] | ||||
return normalized_gain | ||||
|
||||
|
||||
class NDCG(Metric): | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, we need to write a docstring like here: ignite/ignite/metrics/accuracy.py Line 94 in 34a707e
Please read this section of contributing guide: https://github.com/pytorch/ignite/blob/master/CONTRIBUTING.md#writing-documentation, especially about |
||||
def __init__( | ||||
self, | ||||
output_transform: Callable = lambda x: x, | ||||
device: Union[str, torch.device] = torch.device("cpu"), | ||||
k: Optional[int] = None, | ||||
log_base: Union[int, float] = 2, | ||||
exponential: bool = False, | ||||
ignore_ties: bool = False, | ||||
): | ||||
if log_base == 1 or log_base <= 0: | ||||
raise ValueError(f"Argument log_base should positive and not equal one,but got {log_base}") | ||||
self.log_base = log_base | ||||
self.k = k | ||||
self.exponential = exponential | ||||
self.ignore_ties = ignore_ties | ||||
super(NDCG, self).__init__(output_transform=output_transform, device=device) | ||||
|
||||
@reinit__is_reduced | ||||
def reset(self) -> None: | ||||
self.num_examples = 0 | ||||
self.ndcg = torch.tensor(0.0, device=self._device) | ||||
|
||||
@reinit__is_reduced | ||||
def update(self, output: Sequence[torch.Tensor]) -> None: | ||||
y_pred, y_true = output[0].detach(), output[1].detach() | ||||
|
||||
y_pred = y_pred.to(torch.float32).to(self._device) | ||||
y_true = y_true.to(torch.float32).to(self._device) | ||||
|
||||
if self.exponential: | ||||
y_true = 2**y_true - 1 | ||||
|
||||
gain = _ndcg_sample_scores(y_pred, y_true, k=self.k, log_base=self.log_base, ignore_ties=self.ignore_ties) | ||||
self.ndcg += torch.sum(gain) | ||||
self.num_examples += y_pred.shape[0] | ||||
|
||||
@sync_all_reduce("ndcg", "num_examples") | ||||
def compute(self) -> float: | ||||
if self.num_examples == 0: | ||||
raise NotComputableError("NGCD must have at least one example before it can be computed.") | ||||
|
||||
return (self.ndcg / self.num_examples).item() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import numpy as np | ||
import pytest | ||
import torch | ||
from sklearn.metrics import ndcg_score | ||
from sklearn.metrics._ranking import _dcg_sample_scores | ||
|
||
import ignite.distributed as idist | ||
|
||
from ignite.exceptions import NotComputableError | ||
from ignite.metrics.recsys.ndcg import NDCG | ||
|
||
|
||
@pytest.fixture(params=[item for item in range(6)]) | ||
def test_case(request): | ||
return [ | ||
(torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]), torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])), | ||
( | ||
torch.tensor([[3.7, 3.7, 3.7, 3.7, 3.7], [3.7, 3.7, 3.7, 3.7, 3.9]]), | ||
torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]]), | ||
), | ||
][request.param % 2] | ||
|
||
|
||
@pytest.mark.parametrize("k", [None, 2, 3]) | ||
@pytest.mark.parametrize("exponential", [True, False]) | ||
@pytest.mark.parametrize("ignore_ties, replacement", [(True, False), (False, True), (False, False)]) | ||
def test_output(available_device, test_case, k, exponential, ignore_ties, replacement): | ||
device = available_device | ||
y_pred_distribution, y = test_case | ||
|
||
y_pred = torch.multinomial(y_pred_distribution, 5, replacement=replacement) | ||
|
||
y_pred = y_pred.to(device) | ||
y = y.to(device) | ||
|
||
ndcg = NDCG(k=k, device=device, exponential=exponential, ignore_ties=ignore_ties) | ||
ndcg.update([y_pred, y]) | ||
result_ignite = ndcg.compute() | ||
|
||
if exponential: | ||
y = 2**y - 1 | ||
|
||
result_sklearn = ndcg_score(y.cpu().numpy(), y_pred.cpu().numpy(), k=k, ignore_ties=ignore_ties) | ||
|
||
np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6) | ||
|
||
|
||
def test_reset(): | ||
y = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) | ||
y_pred = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]) | ||
ndcg = NDCG() | ||
ndcg.update([y_pred, y]) | ||
ndcg.reset() | ||
|
||
with pytest.raises(NotComputableError, match=r"NGCD must have at least one example before it can be computed."): | ||
ndcg.compute() | ||
|
||
|
||
def _ndcg_sample_scores(y, y_score, k=None, ignore_ties=False): | ||
gain = _dcg_sample_scores(y, y_score, k, ignore_ties=ignore_ties) | ||
normalizing_gain = _dcg_sample_scores(y, y, k, ignore_ties=True) | ||
all_irrelevant = normalizing_gain == 0 | ||
gain[all_irrelevant] = 0 | ||
gain[~all_irrelevant] /= normalizing_gain[~all_irrelevant] | ||
return gain | ||
|
||
|
||
@pytest.mark.parametrize("log_base", [2, 3, 10]) | ||
def test_log_base(log_base): | ||
def ndcg_score_with_log_base(y, y_score, *, k=None, sample_weight=None, ignore_ties=False, log_base=2): | ||
gain = _ndcg_sample_scores(y, y_score, k=k, ignore_ties=ignore_ties) | ||
return np.average(gain, weights=sample_weight) | ||
|
||
y = torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]) | ||
y_pred = torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]]) | ||
|
||
ndcg = NDCG(log_base=log_base) | ||
ndcg.update([y_pred, y]) | ||
|
||
result_ignite = ndcg.compute() | ||
result_sklearn = ndcg_score_with_log_base(y.numpy(), y_pred.numpy(), log_base=log_base) | ||
|
||
np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6) | ||
|
||
|
||
def test_update(test_case): | ||
y_pred, y = test_case | ||
|
||
y_pred = y_pred | ||
y = y | ||
|
||
y1_pred = torch.multinomial(y_pred, 5, replacement=True) | ||
y1_true = torch.multinomial(y, 5, replacement=True) | ||
|
||
y2_pred = torch.multinomial(y_pred, 5, replacement=True) | ||
y2_true = torch.multinomial(y, 5, replacement=True) | ||
|
||
y_pred_combined = torch.cat((y1_pred, y2_pred)) | ||
y_combined = torch.cat((y1_true, y2_true)) | ||
|
||
ndcg = NDCG() | ||
|
||
ndcg.update([y1_pred, y1_true]) | ||
ndcg.update([y2_pred, y2_true]) | ||
|
||
result_ignite = ndcg.compute() | ||
|
||
result_sklearn = ndcg_score(y_combined.numpy(), y_pred_combined.numpy()) | ||
|
||
np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6) | ||
|
||
|
||
@pytest.mark.parametrize("metric_device", ["cpu", "process_device"]) | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@pytest.mark.parametrize("num_epochs", [1, 2]) | ||
def test_distrib_integration(distributed, num_epochs, metric_device): | ||
from ignite.engine import Engine | ||
|
||
rank = idist.get_rank() | ||
torch.manual_seed(12 + rank) | ||
n_iters = 5 | ||
batch_size = 8 | ||
device = idist.device() | ||
if metric_device == "process_device": | ||
metric_device = device if device.type != "xla" else "cpu" | ||
|
||
# 10 items | ||
y = torch.rand((n_iters * batch_size, 10)).to(device) | ||
y_preds = torch.rand((n_iters * batch_size, 10)).to(device) | ||
|
||
def update(engine, i): | ||
return ( | ||
y_preds[i * batch_size : (i + 1) * batch_size, ...], | ||
y[i * batch_size : (i + 1) * batch_size, ...], | ||
) | ||
|
||
engine = Engine(update) | ||
NDCG(device=metric_device).attach(engine, "ndcg") | ||
|
||
data = list(range(n_iters)) | ||
engine.run(data=data, max_epochs=num_epochs) | ||
|
||
y_preds = idist.all_gather(y_preds) | ||
y = idist.all_gather(y) | ||
|
||
assert "ndcg" in engine.state.metrics | ||
res = engine.state.metrics["ndcg"] | ||
|
||
true_res = ndcg_score(y.cpu().numpy(), y_preds.cpu().numpy()) | ||
|
||
tol = 1e-3 if device.type == "xla" else 1e-4 # Isn't better to ask `distributed` about backend info? | ||
|
||
assert pytest.approx(res, abs=tol) == true_res | ||
|
||
|
||
@pytest.mark.parametrize("metric_device", [torch.device("cpu"), "process_device"]) | ||
def test_distrib_accumulator_device(distributed, metric_device): | ||
device = idist.device() | ||
if metric_device == "process_device": | ||
metric_device = torch.device(device if device.type != "xla" else "cpu") | ||
|
||
ndcg = NDCG(device=metric_device) | ||
|
||
y_pred = torch.rand((2, 10)).to(device) | ||
y = torch.rand((2, 10)).to(device) | ||
ndcg.update((y_pred, y)) | ||
|
||
dev = ndcg.ndcg.device | ||
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's check here: https://github.com/catalyst-team/catalyst/blob/master/catalyst/metrics/_ndcg.py if there is another way to implement this in a vectorized way, https://github.com/pytorch/ignite/pull/2632/files#r930048810
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@vfdev-5 while studying about this, I found out that sklearn is using for loop too. is there any particular reason why this for loop need to be changed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is purely perf reasons. For example, computing s1 below will be faster than s2: