diff --git a/internal/subscribers/block_subscriber.go b/internal/subscribers/block_subscriber.go index 223a9e8..75f05ac 100644 --- a/internal/subscribers/block_subscriber.go +++ b/internal/subscribers/block_subscriber.go @@ -2,66 +2,104 @@ package subscribers import ( "context" + "math" + "sync" + "sync/atomic" "github.com/jjeffcaii/reactor-go" "github.com/jjeffcaii/reactor-go/hooks" ) -type blockSubscriber struct { - done chan struct{} - c chan<- reactor.Item +var globalBlockSubscriberPool blockSubscriberPool + +func BorrowBlockSubscriber() *BlockSubscriber { + return globalBlockSubscriberPool.get() +} + +func ReturnBlockSubscriber(s *BlockSubscriber) { + globalBlockSubscriberPool.put(s) +} + +type blockSubscriberPool struct { + inner sync.Pool +} + +func (bp *blockSubscriberPool) get() *BlockSubscriber { + if exist, _ := bp.inner.Get().(*BlockSubscriber); exist != nil { + atomic.StoreInt32(&exist.done, 0) + return exist + } + return &BlockSubscriber{ + doneChan: make(chan struct{}, 1), + } } -func NewBlockSubscriber(done chan struct{}, c chan reactor.Item) reactor.Subscriber { - return blockSubscriber{ - done: done, - c: c, +func (bp *blockSubscriberPool) put(s *BlockSubscriber) { + if s == nil { + return } + s.Reset() + bp.inner.Put(s) } -func (b blockSubscriber) OnComplete() { - select { - case <-b.done: - default: - close(b.done) +type BlockSubscriber struct { + reactor.Item + doneChan chan struct{} + ctxChan chan struct{} + done int32 +} + +func (b *BlockSubscriber) Reset() { + b.V = nil + b.E = nil + b.ctxChan = nil + atomic.StoreInt32(&b.done, math.MinInt32) +} + +func (b *BlockSubscriber) Done() <-chan struct{} { + return b.doneChan +} + +func (b *BlockSubscriber) OnComplete() { + if atomic.CompareAndSwapInt32(&b.done, 0, 1) { + b.finish() } } -func (b blockSubscriber) OnError(err error) { - select { - case <-b.done: +func (b *BlockSubscriber) OnError(err error) { + if !atomic.CompareAndSwapInt32(&b.done, 0, 1) { hooks.Global().OnErrorDrop(err) - default: - select { - case b.c <- reactor.Item{E: err}: - default: - hooks.Global().OnErrorDrop(err) - } - close(b.done) + return + } + b.E = err + b.finish() +} + +func (b *BlockSubscriber) finish() { + if b.ctxChan != nil { + close(b.ctxChan) } + b.doneChan <- struct{}{} } -func (b blockSubscriber) OnNext(any reactor.Any) { - select { - case <-b.done: +func (b *BlockSubscriber) OnNext(any reactor.Any) { + if atomic.LoadInt32(&b.done) != 0 || b.V != nil || b.E != nil { hooks.Global().OnNextDrop(any) - default: - select { - case b.c <- reactor.Item{V: any}: - default: - hooks.Global().OnNextDrop(any) - } + return } + b.V = any } -func (b blockSubscriber) OnSubscribe(ctx context.Context, subscription reactor.Subscription) { +func (b *BlockSubscriber) OnSubscribe(ctx context.Context, subscription reactor.Subscription) { // workaround: watch context if ctx != context.Background() && ctx != context.TODO() { + ctxChan := make(chan struct{}) + b.ctxChan = ctxChan go func() { select { case <-ctx.Done(): b.OnError(reactor.NewContextError(ctx.Err())) - case <-b.done: + case <-ctxChan: } }() } diff --git a/internal/subscribers/block_subscriber_test.go b/internal/subscribers/block_subscriber_test.go new file mode 100644 index 0000000..8d930b2 --- /dev/null +++ b/internal/subscribers/block_subscriber_test.go @@ -0,0 +1,86 @@ +package subscribers + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/jjeffcaii/reactor-go" + "github.com/jjeffcaii/reactor-go/internal" + "github.com/stretchr/testify/assert" +) + +func TestBlockSubscriber(t *testing.T) { + fakeErr := errors.New("fake error") + + // complete test + s := BorrowBlockSubscriber() + go func() { + s.OnNext(1) + s.OnComplete() + }() + + <-s.Done() + + assert.NoError(t, s.E, "should not return error") + assert.Equal(t, 1, s.V, "bad result") + ReturnBlockSubscriber(s) + + // error test + s = BorrowBlockSubscriber() + s.OnError(fakeErr) + // omit + s.OnNext(2) + s.OnError(fakeErr) + s.OnComplete() + + <-s.Done() + + assert.Equal(t, fakeErr, s.E, "should be fake error") + assert.Nil(t, s.V) + ReturnBlockSubscriber(s) + + // empty test + s = BorrowBlockSubscriber() + s.OnComplete() + // omit + s.OnNext(2) + s.OnError(fakeErr) + + <-s.Done() + + assert.NoError(t, s.E, "should not return error") + assert.Nil(t, s.V) + ReturnBlockSubscriber(s) +} + +func TestReturnBlockSubscriber(t *testing.T) { + assert.NotPanics(t, func() { + ReturnBlockSubscriber(nil) + }) +} + +func TestBlockSubscriber_OnSubscribe(t *testing.T) { + s := BorrowBlockSubscriber() + s.OnSubscribe(context.Background(), internal.EmptySubscription) + ReturnBlockSubscriber(s) + + s = BorrowBlockSubscriber() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + s.OnSubscribe(ctx, internal.EmptySubscription) + <-s.Done() + assert.Error(t, s.E, "should return error") + assert.True(t, reactor.IsCancelledError(s.E), "should be cancelled error") + ReturnBlockSubscriber(s) + + s = BorrowBlockSubscriber() + ctx, cancel = context.WithCancel(context.Background()) + s.OnSubscribe(ctx, internal.EmptySubscription) + s.OnComplete() + time.Sleep(10 * time.Millisecond) + cancel() + + <-s.Done() +} diff --git a/mono/create.go b/mono/create.go index 95810a7..de4a16e 100644 --- a/mono/create.go +++ b/mono/create.go @@ -29,8 +29,8 @@ func (p *sinkPool) put(s *sink) { if s == nil { return } - s.actual = nil atomic.StoreInt32(&s.stat, math.MinInt32) + s.actual = nil p.inner.Put(s) } diff --git a/mono/wrapper_utils.go b/mono/wrapper_utils.go index 6573a01..77e0789 100644 --- a/mono/wrapper_utils.go +++ b/mono/wrapper_utils.go @@ -32,22 +32,16 @@ func IsSubscribeAsync(m Mono) bool { } func block(ctx context.Context, publisher reactor.RawPublisher) (Any, error) { - done := make(chan struct{}) - c := make(chan reactor.Item, 1) - b := subscribers.NewBlockSubscriber(done, c) - publisher.SubscribeWith(ctx, b) - <-done - defer close(c) - - select { - case result := <-c: - if result.E != nil { - return nil, result.E - } - return result.V, nil - default: - return nil, nil + s := subscribers.BorrowBlockSubscriber() + defer subscribers.ReturnBlockSubscriber(s) + + publisher.SubscribeWith(ctx, s) + <-s.Done() + + if s.E != nil { + return nil, s.E } + return s.V, nil } func unpackRawPublisher(source Mono) reactor.RawPublisher {