diff --git a/hooks/hooks.go b/hooks/hooks.go index 47b454a..b61048e 100644 --- a/hooks/hooks.go +++ b/hooks/hooks.go @@ -9,37 +9,37 @@ import ( var globalHooks = &Hooks{} type Hooks struct { + sync.RWMutex nextDrops []reactor.FnOnDiscard errDrops []reactor.FnOnError - locker sync.RWMutex } func (p *Hooks) OnNextDrop(t reactor.Any) { - p.locker.RLock() + p.RLock() + defer p.RUnlock() for _, fn := range p.nextDrops { fn(t) } - p.locker.RUnlock() } func (p *Hooks) OnErrorDrop(e error) { - p.locker.RLock() + p.RLock() + defer p.RUnlock() for _, fn := range p.errDrops { fn(e) } - p.locker.RUnlock() } func (p *Hooks) registerOnNextDrop(fn reactor.FnOnDiscard) { - p.locker.Lock() + p.Lock() + defer p.Unlock() p.nextDrops = append(p.nextDrops, fn) - p.locker.Unlock() } func (p *Hooks) registerOnErrorDrop(fn reactor.FnOnError) { - p.locker.Lock() + p.Lock() + defer p.Unlock() p.errDrops = append(p.errDrops, fn) - p.locker.Unlock() } func Global() *Hooks { diff --git a/mono/mono.go b/mono/mono.go index 3a4b7d2..e64e29f 100644 --- a/mono/mono.go +++ b/mono/mono.go @@ -27,6 +27,7 @@ type Mono interface { DoOnDiscard(reactor.FnOnDiscard) Mono SwitchIfEmpty(alternative Mono) Mono DelayElement(delay time.Duration) Mono + Timeout(timeout time.Duration) Mono } type Processor interface { diff --git a/mono/mono_test.go b/mono/mono_test.go index ac5fcf6..ed2b6d2 100644 --- a/mono/mono_test.go +++ b/mono/mono_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/jjeffcaii/reactor-go" + "github.com/jjeffcaii/reactor-go/hooks" "github.com/jjeffcaii/reactor-go/mono" "github.com/jjeffcaii/reactor-go/scheduler" "github.com/stretchr/testify/assert" @@ -377,3 +378,45 @@ func TestMapWithError(t *testing.T) { Block(context.Background()) assert.Equal(t, fakeErr, err, "should return error") } + +func TestTimeout(t *testing.T) { + fakeErr := errors.New("fake err") + dropped := new(int32) + errorDropped := new(int32) + hooks.OnNextDrop(func(v reactor.Any) { + atomic.AddInt32(dropped, 1) + }) + hooks.OnErrorDrop(func(e error) { + atomic.AddInt32(errorDropped, 1) + }) + _, err := mono. + Create(func(ctx context.Context, s mono.Sink) { + time.Sleep(200 * time.Millisecond) + s.Success("hello") + }). + Timeout(100 * time.Millisecond). + Block(context.Background()) + assert.Error(t, err) + time.Sleep(200 * time.Millisecond) + assert.Equal(t, int32(1), atomic.LoadInt32(dropped)) + + value, err := mono.Just("hello").Timeout(100 * time.Millisecond).Block(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "hello", value) + + _, err = mono.Error(fakeErr).Timeout(100 * time.Millisecond).Block(context.Background()) + assert.Equal(t, fakeErr, err) + + _, err = mono.Create(func(ctx context.Context, s mono.Sink) { + time.Sleep(100 * time.Millisecond) + s.Error(err) + }).Timeout(500 * time.Millisecond).Block(context.Background()) + assert.Equal(t, fakeErr, err) + + _, err = mono.Create(func(ctx context.Context, s mono.Sink) { + time.Sleep(100 * time.Millisecond) + s.Error(err) + }).Timeout(50 * time.Millisecond).Block(context.Background()) + assert.True(t, reactor.IsCancelledError(err)) + assert.Equal(t, int32(1), atomic.LoadInt32(errorDropped)) +} diff --git a/mono/mono_timeout.go b/mono/mono_timeout.go new file mode 100644 index 0000000..9838f9d --- /dev/null +++ b/mono/mono_timeout.go @@ -0,0 +1,80 @@ +package mono + +import ( + "context" + "time" + + "github.com/jjeffcaii/reactor-go" + "github.com/jjeffcaii/reactor-go/hooks" + "github.com/jjeffcaii/reactor-go/internal" +) + +type monoTimeout struct { + source reactor.RawPublisher + timeout time.Duration +} + +type timeoutSubscriber struct { + actual reactor.Subscriber + timeout time.Duration + done chan struct{} +} + +func (t *timeoutSubscriber) OnComplete() { + select { + case <-t.done: + default: + close(t.done) + t.actual.OnComplete() + } +} + +func (t *timeoutSubscriber) OnError(err error) { + select { + case <-t.done: + hooks.Global().OnErrorDrop(err) + default: + close(t.done) + t.actual.OnError(err) + } +} + +func (t *timeoutSubscriber) OnNext(any reactor.Any) { + select { + case <-t.done: + hooks.Global().OnNextDrop(any) + default: + t.actual.OnNext(any) + } +} + +func (t *timeoutSubscriber) OnSubscribe(ctx context.Context, subscription reactor.Subscription) { + timer := time.NewTimer(t.timeout) + go func() { + defer timer.Stop() + select { + case <-timer.C: + t.OnError(reactor.ErrSubscribeCancelled) + case <-t.done: + } + }() + t.actual.OnSubscribe(ctx, subscription) +} + +func (m *monoTimeout) SubscribeWith(ctx context.Context, subscriber reactor.Subscriber) { + subscriber = internal.ExtractRawSubscriber(subscriber) + ts := &timeoutSubscriber{ + actual: subscriber, + timeout: m.timeout, + done: make(chan struct{}), + } + subscriber = internal.NewCoreSubscriber(ts) + m.source.SubscribeWith(ctx, subscriber) +} + +func newMonoTimeout(source reactor.RawPublisher, timeout time.Duration) *monoTimeout { + return &monoTimeout{ + source: source, + timeout: timeout, + } +} diff --git a/mono/wrapper.go b/mono/wrapper.go index 1572b63..eaccb2a 100644 --- a/mono/wrapper.go +++ b/mono/wrapper.go @@ -71,6 +71,13 @@ func (p wrapper) DelayElement(delay time.Duration) Mono { return wrap(newMonoDelayElement(p.RawPublisher, delay, scheduler.Elastic())) } +func (p wrapper) Timeout(timeout time.Duration) Mono { + if timeout <= 0 { + return p + } + return wrap(newMonoTimeout(p.RawPublisher, timeout)) +} + func (p wrapper) Block(ctx context.Context) (value Any, err error) { done := make(chan struct{}) p.