diff --git a/lib/websockex/conn.ex b/lib/websockex/conn.ex index f33077b..881a104 100644 --- a/lib/websockex/conn.ex +++ b/lib/websockex/conn.ex @@ -23,7 +23,8 @@ defmodule WebSockex.Conn do cacerts: nil, insecure: true, resp_headers: [], - ssl_options: nil + ssl_options: nil, + socket_options: nil @type socket :: :gen_tcp.socket() | :ssl.sslsocket() @type header :: {field :: String.t(), value :: String.t()} @@ -47,6 +48,7 @@ defmodule WebSockex.Conn do - `:socket_recv_timeout` - Timeout in ms for receiving a HTTP response header from socket, default #{@socket_recv_timeout_default} ms. - `:ssl_options` - extra options for an SSL connection + - `:socket_options` - extra options for the TCP part of the connection [public_key]: http://erlang.org/doc/apps/public_key/using_public_key.html """ @@ -57,6 +59,7 @@ defmodule WebSockex.Conn do | {:socket_connect_timeout, non_neg_integer} | {:socket_recv_timeout, non_neg_integer} | {:ssl_options, [:ssl.tls_client_option()]} + | {:socket_options, [:gen_tcp.option()]} @type t :: %__MODULE__{ conn_mod: :gen_tcp | :ssl, @@ -95,7 +98,8 @@ defmodule WebSockex.Conn do socket_connect_timeout: Keyword.get(opts, :socket_connect_timeout, @socket_connect_timeout_default), socket_recv_timeout: Keyword.get(opts, :socket_recv_timeout, @socket_recv_timeout_default), - ssl_options: Keyword.get(opts, :ssl_options, nil) + ssl_options: Keyword.get(opts, :ssl_options, nil), + socket_options: Keyword.get(opts, :socket_options, nil), } end @@ -121,7 +125,7 @@ defmodule WebSockex.Conn do %URI{host: host, port: port, scheme: protocol} when is_nil(host) when is_nil(port) - when not (protocol in ["ws", "wss", "http", "https"]) -> + when protocol not in ["ws", "wss", "http", "https"] -> {:error, %WebSockex.URLError{url: url}} %URI{path: nil} = uri -> @@ -132,6 +136,24 @@ defmodule WebSockex.Conn do end end + @doc """ + Parses a URI host + Host can be "x.y.z.t" or "some.name.domain". If "x.y.z.t", the function + will return a valid :inet.ip_address() which __MODULE__.open_socket + accepts. This will prevent extra DNS operations which can time out + in some contexts + """ + @spec parse_host(String.t()) :: charlist() | :inet.ip_address() + def parse_host(host) do + host + |> to_charlist() + |> :inet.parse_address() + |> then(fn + {:error, :einval} -> to_charlist(host) + {:ok, addr} -> addr + end) + end + @doc """ Sends data using the `conn_mod` module. """ @@ -151,9 +173,9 @@ defmodule WebSockex.Conn do def open_socket(%{conn_mod: :gen_tcp} = conn) do case :gen_tcp.connect( - String.to_charlist(conn.host), + parse_host(conn.host), conn.port, - [:binary, active: false, packet: 0], + socket_connection_options(conn), conn.socket_connect_timeout ) do {:ok, socket} -> @@ -166,7 +188,7 @@ defmodule WebSockex.Conn do def open_socket(%{conn_mod: :ssl} = conn) do case :ssl.connect( - String.to_charlist(conn.host), + parse_host(conn.host), conn.port, ssl_connection_options(conn), conn.socket_connect_timeout @@ -317,6 +339,25 @@ defmodule WebSockex.Conn do end end + defp minimal_socket_connection_options() do + [ + mode: :binary, + active: false, + packet: 0 + ] + end + + + defp socket_connection_options(%{socket_options: socket_options}) when not is_nil(socket_options) do + minimal_socket_connection_options() + |> Keyword.merge(socket_options) + end + + defp socket_connection_options(%{socket_options: socket_options}) do + minimal_socket_connection_options() + end + + # Crazy SSL Stuff (It will be normal SSL stuff when I figure out Erlang's ssl) defp ssl_connection_options(%{ssl_options: ssl_options}) when not is_nil(ssl_options) do diff --git a/test/websockex/conn_test.exs b/test/websockex/conn_test.exs index bdf8687..f513149 100644 --- a/test/websockex/conn_test.exs +++ b/test/websockex/conn_test.exs @@ -7,7 +7,6 @@ defmodule WebSockex.ConnTest do on_exit(fn -> WebSockex.TestServer.shutdown(server_ref) end) uri = URI.parse(url) - conn = WebSockex.Conn.new(uri) {:ok, conn} = WebSockex.Conn.open_socket(conn) @@ -15,86 +14,97 @@ defmodule WebSockex.ConnTest do [url: url, uri: uri, conn: conn] end - test "new" do - tcp_conn = %WebSockex.Conn{ - host: "localhost", - port: 80, - path: "/ws", - query: nil, - conn_mod: :gen_tcp, - transport: :tcp, - extra_headers: [{"Pineapple", "Cake"}], - socket: nil, - socket_connect_timeout: 6000, - socket_recv_timeout: 5000 - } - - ssl_conn = %WebSockex.Conn{ - host: "localhost", - port: 443, - path: "/ws", - query: nil, - conn_mod: :ssl, - transport: :ssl, - extra_headers: [{"Pineapple", "Cake"}], - socket: nil, - socket_connect_timeout: 6000, - socket_recv_timeout: 5000 - } - - regular_url = "ws://localhost/ws" - regular_uri = URI.parse(regular_url) - - regular_opts = [ - extra_headers: [{"Pineapple", "Cake"}], - socket_connect_timeout: 123, - socket_recv_timeout: 456 - ] - - assert WebSockex.Conn.new(regular_uri, regular_opts) == %{ - tcp_conn - | socket_connect_timeout: 123, - socket_recv_timeout: 456 - } - - assert WebSockex.Conn.new(regular_url, regular_opts) == - WebSockex.Conn.new(regular_uri, regular_opts) - - conn_opts = [extra_headers: [{"Pineapple", "Cake"}]] - - ssl_url = "wss://localhost/ws" - ssl_uri = URI.parse(ssl_url) - assert WebSockex.Conn.new(ssl_uri, conn_opts) == ssl_conn - assert WebSockex.Conn.new(ssl_url, conn_opts) == WebSockex.Conn.new(ssl_uri, conn_opts) - - http_url = "http://localhost/ws" - http_uri = URI.parse(http_url) - assert WebSockex.Conn.new(http_uri, conn_opts) == tcp_conn - assert WebSockex.Conn.new(http_url, conn_opts) == WebSockex.Conn.new(http_uri, conn_opts) - - https_url = "https://localhost/ws" - https_uri = URI.parse(https_url) - assert WebSockex.Conn.new(https_uri, conn_opts) == ssl_conn - assert WebSockex.Conn.new(https_url, conn_opts) == WebSockex.Conn.new(https_uri, conn_opts) - - llama_url = "llama://localhost/ws" - llama_conn = URI.parse(llama_url) - - assert WebSockex.Conn.new(llama_conn, conn_opts) == - %WebSockex.Conn{ - host: "localhost", - port: nil, - path: "/ws", - query: nil, - conn_mod: nil, - transport: nil, - extra_headers: [{"Pineapple", "Cake"}], - socket: nil, - socket_connect_timeout: 6000, - socket_recv_timeout: 5000 + for host <- ["localhost", "127.0.0.1"] do + test "new, with host #{host}" do + localhost = unquote(host) + + localhost_or_addr = + case WebSockex.Conn.parse_host(localhost) do + addr when is_tuple(addr) -> addr + other -> to_string(other) + end + + tcp_conn = %WebSockex.Conn{ + host: localhost_or_addr, + port: 80, + path: "/ws", + query: nil, + conn_mod: :gen_tcp, + transport: :tcp, + extra_headers: [{"Pineapple", "Cake"}], + socket: nil, + socket_connect_timeout: 6000, + socket_recv_timeout: 5000 + } + + ssl_conn = %WebSockex.Conn{ + host: localhost_or_addr, + port: 443, + path: "/ws", + query: nil, + conn_mod: :ssl, + transport: :ssl, + extra_headers: [{"Pineapple", "Cake"}], + socket: nil, + socket_connect_timeout: 6000, + socket_recv_timeout: 5000 + } + + regular_url = "ws://" <> localhost <> "/ws" + regular_uri = URI.parse(regular_url) + + regular_opts = [ + extra_headers: [{"Pineapple", "Cake"}], + socket_connect_timeout: 123, + socket_recv_timeout: 456 + ] + + assert WebSockex.Conn.new(regular_uri, regular_opts) == %{ + tcp_conn + | socket_connect_timeout: 123, + socket_recv_timeout: 456, + host: localhost } - assert {:error, %WebSockex.URLError{}} = WebSockex.Conn.new(llama_url, conn_opts) + assert WebSockex.Conn.new(regular_url, regular_opts) == + WebSockex.Conn.new(regular_uri, regular_opts) + + conn_opts = [extra_headers: [{"Pineapple", "Cake"}]] + + ssl_url = "wss://" <> localhost <> "/ws" + ssl_uri = URI.parse(ssl_url) + assert WebSockex.Conn.new(ssl_uri, conn_opts) == %{ssl_conn | host: localhost} + assert WebSockex.Conn.new(ssl_url, conn_opts) == WebSockex.Conn.new(ssl_uri, conn_opts) + + http_url = "http://" <> localhost <> "/ws" + http_uri = URI.parse(http_url) + assert WebSockex.Conn.new(http_uri, conn_opts) == %{tcp_conn | host: localhost} + assert WebSockex.Conn.new(http_url, conn_opts) == WebSockex.Conn.new(http_uri, conn_opts) + + https_url = "https://" <> localhost <> "/ws" + https_uri = URI.parse(https_url) + assert WebSockex.Conn.new(https_uri, conn_opts) == %{ssl_conn | host: localhost} + assert WebSockex.Conn.new(https_url, conn_opts) == WebSockex.Conn.new(https_uri, conn_opts) + + llama_url = "llama://" <> localhost <> "/ws" + llama_conn = URI.parse(llama_url) + + assert WebSockex.Conn.new(llama_conn, conn_opts) == + %WebSockex.Conn{ + host: localhost, + port: nil, + path: "/ws", + query: nil, + conn_mod: nil, + transport: nil, + extra_headers: [{"Pineapple", "Cake"}], + socket: nil, + socket_connect_timeout: 6000, + socket_recv_timeout: 5000 + } + + assert {:error, %WebSockex.URLError{}} = WebSockex.Conn.new(llama_url, conn_opts) + end end test "parse_url" do @@ -117,6 +127,12 @@ defmodule WebSockex.ConnTest do assert WebSockex.Conn.parse_url(pathless_url) == {:ok, %{URI.parse(pathless_url) | path: "/"}} end + test "parse_host" do + assert WebSockex.Conn.parse_host("strawberry.cake") == 'strawberry.cake' + assert WebSockex.Conn.parse_host("1.2.3.4") == {1, 2, 3, 4} + assert WebSockex.Conn.parse_host("a.b.c.d") == 'a.b.c.d' + end + test "open_socket", context do %{host: host, port: port, path: path} = context.uri