diff --git a/mono/flatmap.go b/mono/flatmap.go index 982e7e9..d97b378 100644 --- a/mono/flatmap.go +++ b/mono/flatmap.go @@ -15,30 +15,41 @@ const ( statComplete = 2 ) +type flatMapStat int32 + +const ( + _ flatMapStat = iota + flatMapStatInnerReady + flatMapStatInnerComplete + flatMapStatError + flatMapStatComplete +) + type innerFlatMapSubscriber struct { parent *flatMapSubscriber } func (in *innerFlatMapSubscriber) OnError(err error) { - if atomic.CompareAndSwapInt32(&in.parent.stat, 0, statError) { + if atomic.CompareAndSwapInt32(&in.parent.stat, int32(flatMapStatInnerReady), int32(flatMapStatError)) { in.parent.actual.OnError(err) } } func (in *innerFlatMapSubscriber) OnNext(v Any) { - if atomic.LoadInt32(&in.parent.stat) != 0 { + if atomic.LoadInt32(&in.parent.stat) != int32(flatMapStatInnerReady) { + hooks.Global().OnNextDrop(v) return } in.parent.actual.OnNext(v) in.OnComplete() } -func (in *innerFlatMapSubscriber) OnSubscribe(ctx context.Context, s reactor.Subscription) { +func (in *innerFlatMapSubscriber) OnSubscribe(_ context.Context, s reactor.Subscription) { s.Request(reactor.RequestInfinite) } func (in *innerFlatMapSubscriber) OnComplete() { - if atomic.CompareAndSwapInt32(&in.parent.stat, 0, statComplete) { + if atomic.CompareAndSwapInt32(&in.parent.stat, int32(flatMapStatInnerReady), int32(flatMapStatInnerComplete)) { in.parent.actual.OnComplete() } } @@ -58,13 +69,13 @@ func (p *flatMapSubscriber) Cancel() { } func (p *flatMapSubscriber) OnComplete() { - if atomic.LoadInt32(&p.stat) == statComplete { + if atomic.CompareAndSwapInt32(&p.stat, 0, int32(flatMapStatComplete)) { p.actual.OnComplete() } } func (p *flatMapSubscriber) OnError(err error) { - if !atomic.CompareAndSwapInt32(&p.stat, 0, statError) { + if !atomic.CompareAndSwapInt32(&p.stat, 0, int32(flatMapStatError)) { hooks.Global().OnErrorDrop(err) return } @@ -72,7 +83,7 @@ func (p *flatMapSubscriber) OnError(err error) { } func (p *flatMapSubscriber) OnNext(v Any) { - if atomic.LoadInt32(&p.stat) != 0 { + if !atomic.CompareAndSwapInt32(&p.stat, 0, int32(flatMapStatInnerReady)) { hooks.Global().OnNextDrop(v) return } diff --git a/mono/flatmap_test.go b/mono/flatmap_test.go index d73a4d6..5124bf1 100644 --- a/mono/flatmap_test.go +++ b/mono/flatmap_test.go @@ -3,12 +3,14 @@ package mono_test import ( "context" "errors" + "sync/atomic" "testing" "time" "github.com/golang/mock/gomock" "github.com/jjeffcaii/reactor-go" "github.com/jjeffcaii/reactor-go/mono" + "github.com/jjeffcaii/reactor-go/scheduler" "github.com/stretchr/testify/assert" ) @@ -70,3 +72,64 @@ func TestFlatMap_MultipleEmit(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, res) } + +func TestFlatMapSubscriber_OnComplete(t *testing.T) { + completes := new(int32) + res, err := mono.Just(1). + FlatMap(func(any reactor.Any) mono.Mono { + return mono.Create(func(ctx context.Context, s mono.Sink) { + s.Success(2) + }).SubscribeOn(scheduler.Parallel()) + }). + DoOnComplete(func() { + atomic.AddInt32(completes, 1) + }). + Block(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 2, res, "bad result") + assert.Equal(t, int32(1), atomic.LoadInt32(completes), "completes should be 1") + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + s := mono.NewMockSubscriber(ctrl) + s.EXPECT().OnComplete().Times(1) + s.EXPECT().OnNext(gomock.Eq(2)).Times(1) + s.EXPECT().OnError(gomock.Any()).Times(0) + s.EXPECT().OnSubscribe(gomock.Any(), gomock.Any()).Do(mono.MockRequestInfinite).Times(1) + + mono.Just(1). + FlatMap(func(any reactor.Any) mono.Mono { + return mono.Just(2) + }). + SubscribeWith(context.Background(), s) +} + +func TestFlatMap_EmptySource(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + s := mono.NewMockSubscriber(ctrl) + s.EXPECT().OnComplete().Times(1) + s.EXPECT().OnNext(gomock.Any()).Times(0) + s.EXPECT().OnError(gomock.Any()).Times(0) + s.EXPECT().OnSubscribe(gomock.Any(), gomock.Any()).Do(mono.MockRequestInfinite).Times(1) + + mono.Empty(). + FlatMap(func(any reactor.Any) mono.Mono { + return mono.Just(1) + }). + SubscribeWith(context.Background(), s) + + s = mono.NewMockSubscriber(ctrl) + s.EXPECT().OnComplete().Times(1) + s.EXPECT().OnNext(gomock.Any()).Times(0) + s.EXPECT().OnError(gomock.Any()).Times(0) + s.EXPECT().OnSubscribe(gomock.Any(), gomock.Any()).Do(mono.MockRequestInfinite).Times(1) + + mono.Just(1). + FlatMap(func(value reactor.Any) mono.Mono { + assert.Equal(t, 1, value) + return mono.Empty() + }). + SubscribeWith(context.Background(), s) +} diff --git a/mono/just.go b/mono/just.go index 0fb6190..1631765 100644 --- a/mono/just.go +++ b/mono/just.go @@ -18,7 +18,7 @@ type justSubscriptionPool struct { func (j *justSubscriptionPool) get() *justSubscription { if exist, _ := j.inner.Get().(*justSubscription); exist != nil { - exist.n = 0 + exist.stat = 0 return exist } return &justSubscription{} @@ -30,7 +30,7 @@ func (j *justSubscriptionPool) put(s *justSubscription) { } s.actual = nil s.parent = nil - atomic.StoreInt32(&s.n, math.MinInt32) + atomic.StoreInt32(&s.stat, math.MinInt32) j.inner.Put(s) } @@ -47,21 +47,23 @@ func newMonoJust(v Any) *monoJust { type justSubscription struct { actual reactor.Subscriber parent *monoJust - n int32 + stat int32 } func (j *justSubscription) Request(n int) { + defer globalJustSubscriptionPool.put(j) + if n < 1 { + j.actual.OnError(errors.Errorf("positive request amount required but it was %d", n)) return } - if !atomic.CompareAndSwapInt32(&j.n, 0, statComplete) { + if !atomic.CompareAndSwapInt32(&j.stat, 0, statComplete) { return } defer func() { actual := j.actual - globalJustSubscriptionPool.put(j) rec := recover() if rec == nil { @@ -81,7 +83,7 @@ func (j *justSubscription) Request(n int) { } func (j *justSubscription) Cancel() { - if atomic.CompareAndSwapInt32(&j.n, 0, statCancel) { + if atomic.CompareAndSwapInt32(&j.stat, 0, statCancel) { j.actual.OnError(reactor.ErrSubscribeCancelled) } } diff --git a/mono/just_test.go b/mono/just_test.go index 2f98405..74329c8 100644 --- a/mono/just_test.go +++ b/mono/just_test.go @@ -67,7 +67,6 @@ func TestMonoJust_Request(t *testing.T) { sub := NewMockSubscriber(ctrl) onSubscribe := func(ctx context.Context, su reactor.Subscription) { - su.Request(0) su.Request(1) su.Request(1) } diff --git a/mono/mono.go b/mono/mono.go index 9da6b40..2cccf0a 100644 --- a/mono/mono.go +++ b/mono/mono.go @@ -14,7 +14,7 @@ type ( Disposable = reactor.Disposable ) -type FlatMapper func(reactor.Any) Mono +type FlatMapper func(value reactor.Any) Mono type Combinator func(values ...*reactor.Item) (reactor.Any, error) // Mono is a Reactive Streams Publisher with basic rx operators that completes successfully by emitting an element, or with an error. diff --git a/scheduler/elastic_test.go b/scheduler/elastic_test.go index 93abeee..28205e1 100644 --- a/scheduler/elastic_test.go +++ b/scheduler/elastic_test.go @@ -40,18 +40,19 @@ func TestNewElastic(t *testing.T) { } func TestElasticBounded(t *testing.T) { - const total = 1000 - var wg sync.WaitGroup - wg.Add(total) - start := time.Now() - worker := scheduler.ElasticBounded().Worker() - for range [total]struct{}{} { - err := worker.Do(func() { - time.Sleep(10 * time.Millisecond) - wg.Done() - }) - assert.NoError(t, err) - } - wg.Wait() - assert.Less(t, int64(time.Since(start)), int64(20*time.Millisecond), "bad result") + assert.NotPanics(t, func() { + const total = 1000 + var wg sync.WaitGroup + wg.Add(total) + worker := scheduler.ElasticBounded().Worker() + for range [total]struct{}{} { + err := worker.Do(func() { + time.Sleep(10 * time.Millisecond) + wg.Done() + }) + assert.NoError(t, err) + } + wg.Wait() + }) + }