Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add write timeout #2843

Merged
merged 10 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions lib/mongo/protocol/message.rb
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,7 @@ def self.deserialize(io,
# timeout option. For compatibility with whoever might call this
# method with some other IO-like object, pass options only when they
# are not empty.
read_options = {}
if timeout = options[:socket_timeout]
read_options[:timeout] = timeout
end
read_options = options.slice(:timeout, :socket_timeout)

if read_options.empty?
chunk = io.read(16)
Expand Down
7 changes: 4 additions & 3 deletions lib/mongo/server/connection_base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,10 @@ def deliver(message, context, options = {})
result = nil
begin
result = add_server_diagnostics do
socket.write(buffer.to_s)
socket.write(buffer.to_s, timeout: context.remaining_timeout_sec)
if message.replyable?
Protocol::Message.deserialize(socket, max_message_size, message.request_id, options)
check_timeout!(context)
Protocol::Message.deserialize(socket, max_message_size, message.request_id, options.merge(timeout: context.remaining_timeout_sec))
else
nil
end
Expand Down Expand Up @@ -288,7 +289,7 @@ def check_timeout!(context)
return if context.remaining_timeout_sec.nil?
time_to_execute = context.remaining_timeout_sec - server.minimum_round_trip_time
if time_to_execute <= 0
raise Mongo::Error:TimeoutError
raise Mongo::Error::TimeoutError
end
end
end
Expand Down
177 changes: 153 additions & 24 deletions lib/mongo/socket.rb
Original file line number Diff line number Diff line change
Expand Up @@ -192,27 +192,24 @@ def gets(*args)
# socket.read(4096)
#
# @param [ Integer ] length The number of bytes to read.
# @param [ Numeric ] timeout The timeout to use for each chunk read.
# @param [ Numeric ] socket_timeout The timeout to use for each chunk read,
# mutually exclusive to +timeout+.
# @param [ Numeric ] timeout The total timeout to the whole read operation,
# mutually exclusive to +socket_timeout+.
#
# @raise [ Mongo::SocketError ] If not all data is returned.
#
# @return [ Object ] The data from the socket.
#
# @since 2.0.0
def read(length, timeout: nil)
map_exceptions do
data = read_from_socket(length, timeout: timeout)
unless (data.length > 0 || length == 0)
raise IOError, "Expected to read > 0 bytes but read 0 bytes"
end
while data.length < length
chunk = read_from_socket(length - data.length, timeout: timeout)
unless (chunk.length > 0 || length == 0)
raise IOError, "Expected to read > 0 bytes but read 0 bytes"
end
data << chunk
end
data
def read(length, socket_timeout: nil, timeout: nil)
if !socket_timeout.nil? && !timeout.nil?
raise ArgumentError, 'Both timeout and socket_timeout cannot be set'
end
if !socket_timeout.nil? || timeout.nil?
read_without_timeout(length, socket_timeout)
else
read_with_timeout(length, timeout)
end
end

Expand All @@ -233,15 +230,16 @@ def readbyte
# Writes data to the socket instance.
#
# @param [ Array<Object> ] args The data to be written.
# @param [ Numeric ] timeout The total timeout to the whole write operation.
#
# @return [ Integer ] The length of bytes written to the socket.
#
# @raise [ Error::SocketError | Error::SocketTimeoutError ] When there is a network error during the write.
#
# @since 2.0.0
def write(*args)
def write(*args, timeout: nil)
map_exceptions do
do_write(*args)
do_write(*args, timeout: timeout)
end
end

Expand All @@ -265,18 +263,76 @@ def connectable?

private

def read_from_socket(length, timeout: nil)
# Reads the +length+ bytes from the socket, the read operation duration is
# limited to +timeout+ second.
#
# @param [ Integer ] length The number of bytes to read.
# @param [ Numeric ] timeout The total timeout to the whole read operation.
#
# @return [ Object ] The data from the socket.
def read_with_timeout(length, timeout)
deadline = Utils.monotonic_time + timeout
map_exceptions do
String.new.tap do |data|
while data.length < length
socket_timeout = deadline - Utils.monotonic_time
if socket_timeout <= 0
raise Mongo::Error::TimeoutError
end
chunk = read_from_socket(length - data.length, socket_timeout: socket_timeout, csot: true)
unless chunk.length > 0
raise IOError, "Expected to read > 0 bytes but read 0 bytes"
end
data << chunk
end
end
end
end

# Reads the +length+ bytes from the socket. The read operation may involve
# multiple socket reads, each read is limited to +timeout+ second,
# if the parameter is provided.
#
# @param [ Integer ] length The number of bytes to read.
# @param [ Numeric ] socket_timeout The timeout to use for each chunk read.
#
# @return [ Object ] The data from the socket.
def read_without_timeout(length, socket_timeout = nil)
map_exceptions do
String.new.tap do |data|
while data.length < length
chunk = read_from_socket(length - data.length, socket_timeout: socket_timeout)
unless chunk.length > 0
raise IOError, "Expected to read > 0 bytes but read 0 bytes"
end
data << chunk
end
end
end
end


# Reads the +length+ bytes from the socket. The read operation may involve
# multiple socket reads, each read is limited to +timeout+ second,
# if the parameter is provided.
#
# @param [ Integer ] length The number of bytes to read.
# @param [ Numeric ] :socket_timeout The timeout to use for each chunk read.
# @param [ true | false ] :csot Whether the CSOT timeout is set for the operation.
#
# @return [ Object ] The data from the socket.
def read_from_socket(length, socket_timeout: nil, csot: false)
# Just in case
if length == 0
return ''.force_encoding('BINARY')
end

_timeout = timeout || self.timeout
_timeout = socket_timeout || self.timeout
if _timeout
if _timeout > 0
deadline = Utils.monotonic_time + _timeout
elsif _timeout < 0
raise Errno::ETIMEDOUT, "Negative timeout #{_timeout} given to socket"
raise_timeout_error!("Negative timeout #{_timeout} given to socket", csot)
end
end

Expand Down Expand Up @@ -331,7 +387,7 @@ def read_from_socket(length, timeout: nil)
if deadline
select_timeout = deadline - Utils.monotonic_time
if select_timeout <= 0
raise Errno::ETIMEDOUT, "Took more than #{_timeout} seconds to receive data"
raise_timeout_error!("Took more than #{_timeout} seconds to receive data", csot)
end
end
pipe = options[:pipe]
Expand Down Expand Up @@ -373,11 +429,11 @@ def read_from_socket(length, timeout: nil)
if deadline
select_timeout = deadline - Utils.monotonic_time
if select_timeout <= 0
raise Errno::ETIMEDOUT, "Took more than #{_timeout} seconds to receive data"
raise_timeout_error!("Took more than #{_timeout} seconds to receive data", csot)
end
end
elsif rv.nil?
raise Errno::ETIMEDOUT, "Took more than #{_timeout} seconds to receive data (select call timed out)"
raise_timeout_error!("Took more than #{_timeout} seconds to receive data (select call timed out)", csot)
end
retry
end
Expand All @@ -402,9 +458,23 @@ def read_buffer_size
# sholud map exceptions.
#
# @param [ Array<Object> ] args The data to be written.
# @param [ Numeric ] :timeout The total timeout to the whole write operation.
#
# @return [ Integer ] The length of bytes written to the socket.
def do_write(*args, timeout: nil)
if timeout.nil?
write_without_timeout(*args)
else
write_with_timeout(*args, timeout: timeout)
end
end

# Writes data to to the socket.
#
# @param [ Array<Object> ] args The data to be written.
#
# @return [ Integer ] The length of bytes written to the socket.
def do_write(*args)
def write_without_timeout(*args)
# This method used to forward arguments to @socket.write in a
# single call like so:
#
Expand All @@ -428,6 +498,57 @@ def do_write(*args)
end
end

# Writes data to to the socket, the write duration is limited to +timeout+.
#
# @param [ Array<Object> ] args The data to be written.
# @param [ Numeric ] :timeout The total timeout to the whole write operation.
#
# @return [ Integer ] The length of bytes written to the socket.
def write_with_timeout(*args, timeout:)
raise ArgumentError, 'timeout cannot be nil' if timeout.nil?
raise_timeout_error!("Negative timeout #{timeout} given to socket", true) if timeout < 0

written = 0
args.each do |buf|
buf = buf.to_s
i = 0
while i < buf.length
chunk = buf[i...(i + WRITE_CHUNK_SIZE)]
written += write_chunk(chunk, timeout)
i += WRITE_CHUNK_SIZE
end
end
written
end

def write_chunk(chunk, timeout)
deadline = Utils.monotonic_time + timeout
written = 0
begin
written += @socket.write_nonblock(chunk[written..-1])
rescue IO::WaitWritable, Errno::EINTR
select_timeout = deadline - Utils.monotonic_time
rv = Kernel.select(nil, [@socket], nil, select_timeout)
if BSON::Environment.jruby?
# Ignore the return value of Kernel.select.
# On JRuby, select appears to return nil prior to timeout expiration
# (apparently due to a EAGAIN) which then causes us to fail the read
# even though we could have retried it.
# Check the deadline ourselves.
if deadline
select_timeout = deadline - Utils.monotonic_time
if select_timeout <= 0
raise_timeout_error!("Took more than #{timeout} seconds to receive data", true)
end
end
elsif rv.nil?
raise_timeout_error!("Took more than #{timeout} seconds to receive data (select call timed out)", true)
end
retry
end
written
end

def unix_socket?(sock)
defined?(UNIXSocket) && sock.is_a?(UNIXSocket)
end
Expand Down Expand Up @@ -482,5 +603,13 @@ def map_exceptions
def human_address
raise NotImplementedError
end

def raise_timeout_error!(message = nil, csot = false)
if csot
raise Mongo::Error::TimeoutError
else
raise Errno::ETIMEDOUT, message
end
end
end
end
Loading