Skip to content

Commit

Permalink
Fix Batcher iterator break when return_last_incomplete_batch and item…
Browse files Browse the repository at this point in the history
…s.is_empty (#2654) (#2655)
  • Loading branch information
hhllhhyyds authored Dec 24, 2024
1 parent 1be6b09 commit 11aa30b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions candle-datasets/src/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
match self.inner.inner.next() {
Some(item) => items.push(item),
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !items.is_empty() {
break;
}
return None;
Expand All @@ -102,7 +102,7 @@ impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
ys.push(y)
}
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() {
break;
}
return None;
Expand All @@ -127,7 +127,7 @@ impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> {
match self.inner.inner.next() {
Some(item) => items.push(item),
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !items.is_empty() {
break;
}
return None;
Expand All @@ -154,7 +154,7 @@ impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResu
}
Some(Err(err)) => errs.push(err),
None => {
if self.return_last_incomplete_batch {
if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() {
break;
}
return None;
Expand Down

0 comments on commit 11aa30b

Please sign in to comment.