Skip to content

Commit

Permalink
Merge pull request #17 from m-lab/sbs-sandbox
Browse files Browse the repository at this point in the history
Add prototype of next generation NDT protocol
  • Loading branch information
stephen-soltesz authored Aug 9, 2018
2 parents abd9d6c + cd64b9b commit 75fa28e
Show file tree
Hide file tree
Showing 7 changed files with 392 additions and 0 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ before_script:
- npm install --verbose ws minimist

script:
- go get -v ./cmd/ndt-cloud-client
- go test -v -coverprofile=ndt.cov github.com/m-lab/ndt-cloud
- $GOPATH/bin/goveralls -coverprofile=ndt.cov -service=travis-ci
47 changes: 47 additions & 0 deletions cmd/ndt-cloud-client/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package main

import (
"flag"
"os"
"os/signal"
"runtime"
"syscall"

"github.com/apex/log"
"github.com/m-lab/ndt-cloud/ndt7"
)

var disableTLS = flag.Bool("disable-tls", false, "Whether to disable TLS")
var duration = flag.Int("duration", 10, "Desired duration")
var hostname = flag.String("hostname", "localhost", "Host to connect to")
var port = flag.String("port", "3001", "Port to connect to")
var skipTLSVerify = flag.Bool("skip-tls-verify", false, "Skip TLS verify")

func main() {
flag.Parse()
settings := ndt7.Settings{}
settings.Hostname = *hostname
settings.InsecureNoTLS = *disableTLS
settings.InsecureSkipTLSVerify = *skipTLSVerify
settings.Port = *port
settings.Duration = *duration
clnt := ndt7.NewClient(settings)
ch := make(chan interface{}, 1)
defer close(ch)
sigs := make(chan os.Signal, 1)
defer close(sigs)
if runtime.GOOS != "windows" {
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
log.Warn("Got interrupt signal")
ch <- false
log.Warn("Delivered interrupt signal")
}()
}
err := clnt.Download()
if err != nil {
log.WithError(err).Warn("clnt.Download() failed")
os.Exit(1)
}
}
3 changes: 3 additions & 0 deletions ndt-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"syscall"
"time"

"github.com/m-lab/ndt-cloud/ndt7"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
Expand Down Expand Up @@ -642,6 +643,8 @@ func main() {
log.Fatal(http.ListenAndServe(*fMetricsAddr, mux))
}()

http.Handle(ndt7.DownloadURLPath, ndt7.DownloadHandler{})

http.HandleFunc("/", defaultHandler)
http.Handle("/static/", http.StripPrefix("/static", http.FileServer(http.Dir("html"))))
http.Handle("/ndt_protocol",
Expand Down
7 changes: 7 additions & 0 deletions ndt-server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"runtime"
"testing"

"github.com/m-lab/ndt-cloud/ndt7"

pipe "gopkg.in/m-lab/pipe.v3"
)

Expand All @@ -29,6 +31,7 @@ func Test_NDTe2e(t *testing.T) {

// Start a test server using the NdtServer as the entry point.
mux := http.NewServeMux()
mux.Handle(ndt7.DownloadURLPath, ndt7.DownloadHandler{})
mux.Handle("/ndt_protocol", http.HandlerFunc(NdtServer))
ts := httptest.NewTLSServer(mux)
defer ts.Close()
Expand Down Expand Up @@ -68,6 +71,10 @@ func Test_NDTe2e(t *testing.T) {
" --protocol=wss --acceptinvalidcerts --abort-c2s-early --tests=22 & " +
"sleep 25",
},
{
name: "Test the NDT7 protocol",
cmd: "ndt-cloud-client -skip-tls-verify -port " + u.Port(),
},
}

for _, testCmd := range tests {
Expand Down
114 changes: 114 additions & 0 deletions ndt7/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package ndt7

import (
"crypto/tls"
"encoding/json"
"net"
"net/http"
"net/url"
"strconv"
"time"

"github.com/apex/log"
"github.com/gorilla/websocket"
)

// Settings contains client settings. All settings are optional except for
// the Hostname, which cannot be autoconfigured at the moment.
type Settings struct {
// This structure embeds options defined in the spec.
Options
// InsecureSkipTLSVerify can be used to disable certificate verification.
InsecureSkipTLSVerify bool `json:"skip_tls_verify"`
// InsecureNoTLS can be used to force using cleartext.
InsecureNoTLS bool `json:"no_tls"`
// Hostname is the hostname of the NDT7 server to connect to.
Hostname string `json:"hostname"`
// Port is the port of the NDT7 server to connect to.
Port string `json:"port"`
}

// Client is a NDT7 client.
type Client struct {
dialer websocket.Dialer
url url.URL
}

// NewClient creates a new client.
func NewClient(settings Settings) Client {
cl := Client{}
cl.dialer.HandshakeTimeout = defaultTimeout
if settings.InsecureSkipTLSVerify {
config := tls.Config{InsecureSkipVerify: true}
cl.dialer.TLSClientConfig = &config
log.Warn("Disabling TLS cerificate verification (INSECURE!)")
}
if settings.InsecureNoTLS {
log.Warn("Using plain text WebSocket (INSECURE!)")
cl.url.Scheme = "ws"
} else {
cl.url.Scheme = "wss"
}
if settings.Port != "" {
ip := net.ParseIP(settings.Hostname)
if ip == nil || len(ip) == 4 {
cl.url.Host = settings.Hostname
cl.url.Host += ":"
cl.url.Host += settings.Port
} else if len(ip) == 16 {
cl.url.Host = "["
cl.url.Host += settings.Hostname
cl.url.Host += "]:"
cl.url.Host += settings.Port
} else {
panic("IP address that is neither 4 nor 16 bytes long")
}
} else {
cl.url.Host = settings.Hostname
}
query := cl.url.Query()
if settings.Duration > 0 {
query.Add("duration", strconv.Itoa(settings.Duration))
}
cl.url.RawQuery = query.Encode()
return cl
}

// defaultTimeout is the default value of the I/O timeout.
const defaultTimeout = 1 * time.Second

// Download runs a NDT7 download test.
func (cl Client) Download() error {
log.Info("Creating a WebSocket connection")
cl.url.Path = DownloadURLPath
headers := http.Header{}
headers.Add("Sec-WebSocket-Protocol", SecWebSocketProtocol)
conn, _, err := cl.dialer.Dial(cl.url.String(), headers)
if err != nil {
return err
}
conn.SetReadLimit(MinMaxMessageSize)
defer conn.Close()
log.Info("Starting download")
for {
conn.SetReadDeadline(time.Now().Add(defaultTimeout))
mtype, mdata, err := conn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure) {
return err
}
break
}
if mtype == websocket.TextMessage {
// Unmarshaling to verify that this message is correct JSON
measurement := Measurement{}
err := json.Unmarshal(mdata, &measurement)
if err != nil {
return err
}
log.Infof("Server measurement: %s", mdata)
}
}
log.Info("Download complete")
return nil
}
100 changes: 100 additions & 0 deletions ndt7/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package ndt7

import (
"crypto/rand"
"net/http"
"strconv"
"time"

"github.com/apex/log"
"github.com/gorilla/websocket"
)

// defaultDuration is the default duration of a subtest in nanoseconds.
const defaultDuration = 10 * time.Second

// maxDuration is the maximum duration of a subtest in seconds
const maxDuration = 30

// DownloadHandler handles a download subtest from the server side.
type DownloadHandler struct {
Upgrader websocket.Upgrader
}

// Handle handles the download subtest.
func (dl DownloadHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
log.Debug("Processing query string")
duration := defaultDuration
{
s := request.URL.Query().Get("duration")
if s != "" {
value, err := strconv.Atoi(s)
if err != nil || value < 0 || value > maxDuration {
log.Warn("The duration option has an invalid value")
writer.Header().Set("Connection", "Close")
writer.WriteHeader(http.StatusBadRequest)
return
}
duration = time.Second * time.Duration(value)
}
}
log.Debug("Upgrading to WebSockets")
if request.Header.Get("Sec-WebSocket-Protocol") != SecWebSocketProtocol {
log.Warn("Missing Sec-WebSocket-Protocol in request")
writer.Header().Set("Connection", "Close")
writer.WriteHeader(http.StatusBadRequest)
return
}
headers := http.Header{}
headers.Add("Sec-WebSocket-Protocol", SecWebSocketProtocol)
conn, err := dl.Upgrader.Upgrade(writer, request, headers)
if err != nil {
log.WithError(err).Warn("upgrader.Upgrade() failed")
return
}
conn.SetReadLimit(MinMaxMessageSize)
defer conn.Close()
log.Debug("Generating random buffer")
const bufferSize = 1 << 13
data := make([]byte, bufferSize)
rand.Read(data)
buffer, err := websocket.NewPreparedMessage(websocket.BinaryMessage, data)
if err != nil {
log.WithError(err).Warn("websocket.NewPreparedMessage() failed")
return
}
log.Debug("Start sending data to client")
ticker := time.NewTicker(MinMeasurementInterval)
defer ticker.Stop()
t0 := time.Now()
count := int64(0)
for running := true; running; {
select {
case t := <-ticker.C:
// TODO(bassosimone): here we should also include tcp_info data
measurement := Measurement{
Elapsed: t.Sub(t0).Nanoseconds(),
NumBytes: count,
}
conn.SetWriteDeadline(time.Now().Add(defaultTimeout))
if err := conn.WriteJSON(&measurement); err != nil {
log.WithError(err).Warn("Cannot send measurement message")
return
}
default: // Not ticking, just send more data
if time.Now().Sub(t0) >= duration {
running = false
break
}
conn.SetWriteDeadline(time.Now().Add(defaultTimeout))
if err := conn.WritePreparedMessage(buffer); err != nil {
log.WithError(err).Warn("cannot send data message")
return
}
count += bufferSize
}
}
log.Debug("Closing the WebSocket connection")
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(
websocket.CloseNormalClosure, ""), time.Now().Add(defaultTimeout))
}
Loading

0 comments on commit 75fa28e

Please sign in to comment.