diff --git a/cmd/marker/main.go b/cmd/marker/main.go index 581037a27..78ea0e46d 100644 --- a/cmd/marker/main.go +++ b/cmd/marker/main.go @@ -17,6 +17,7 @@ package main import ( "flag" "fmt" + "net" "os" "reflect" "strings" @@ -29,6 +30,12 @@ import ( "github.com/k8snetworkplumbingwg/ovs-cni/pkg/marker" ) +const ( + UnixSocketType = "unix" + TcpSocketType = "tcp" + SocketConnectionTimeout = time.Minute +) + func main() { nodeName := flag.String("node-name", "", "name of kubernetes node") ovsSocket := flag.String("ovs-socket", "", "address of openvswitch database connection") @@ -54,38 +61,16 @@ func main() { if *ovsSocket == "" { glog.Fatal("ovs-socket must be set") } - - var socketType, path string - ovsSocketTokens := strings.Split(*ovsSocket, ":") - if len(ovsSocketTokens) < 2 { - /* - * ovsSocket should consist of comma separated socket type and socket - * detail. If no socket type is specified, it is assumed to be a unix - * domain socket, for backwards compatibility. - */ - socketType = "unix" - path = *ovsSocket - } else { - socketType = ovsSocketTokens[0] - path = ovsSocketTokens[1] + socketType, address, err := parseOvsSocket(ovsSocket) + if err != nil { + glog.Fatalf("Failed to parse ovs socket: %v", err) } - - if socketType == "unix" { - for { - _, err := os.Stat(path) - if err == nil { - glog.Info("Found the OVS socket") - break - } else if os.IsNotExist(err) { - glog.Infof("Given ovs-socket %q was not found, waiting for the socket to appear", path) - time.Sleep(time.Minute) - } else { - glog.Fatalf("Failed opening the OVS socket with: %v", err) - } - } + if err = validateOvsSocketConnection(socketType, address); err != nil { + glog.Fatal("Failed to connect to ovs: %v", err) } + endpoint := fmt.Sprintf("%s:%s", socketType, address) - markerApp, err := marker.NewMarker(*nodeName, socketType+":"+path) + markerApp, err := marker.NewMarker(*nodeName, endpoint) if err != nil { glog.Fatalf("Failed to create a new marker object: %v", err) } @@ -137,3 +122,90 @@ func keepAlive(healthCheckFile string, healthCheckInterval int) { }, time.Duration(healthCheckInterval)*time.Second) } + +/* +takes an OVS socket string and returns the socket +type, address, and any parsing error. +*/ +func parseOvsSocket(ovsSocket *string) (string, string, error) { + var socketType, address string + ovsSocketTokens := strings.Split(*ovsSocket, ":") + if len(ovsSocketTokens) < 2 { + /* + * ovsSocket should consist of comma separated socket type and socket + * detail. If no socket type is specified, it is assumed to be a unix + * domain socket, for backwards compatibility. + */ + socketType = UnixSocketType + address = *ovsSocket + } else { + socketType = ovsSocketTokens[0] + if socketType == TcpSocketType { + if len(ovsSocketTokens) != 3 { + return "", "", fmt.Errorf("failed to parse OVS %s socket, must be in this format %s::", socketType, socketType) + } + address = fmt.Sprintf("%s:%s", ovsSocketTokens[1], ovsSocketTokens[2]) + } else { + // unix socket + socketType = UnixSocketType + address = ovsSocketTokens[1] + } + } + return socketType, address, nil +} + +func validateOvsSocketConnection(socketType, address string) error { + validator, err := getOvsSocketValidator(socketType) + if err != nil { + return err + } + return validator(address) +} + +func getOvsSocketValidator(socketType string) (func(string) error, error) { + switch socketType { + case UnixSocketType: + return validateOvsUnixConnection, nil + case TcpSocketType: + return validateOvsTcpConnection, nil + default: + return nil, fmt.Errorf("unsupported ovs socket type: %s", socketType) + } +} + +func validateOvsUnixConnection(address string) error { + for { + _, err := os.Stat(address) + if err == nil { + glog.Info("Found the OVS socket") + break + } else if os.IsNotExist(err) { + glog.Infof("Given ovs-socket %q was not found, waiting for the socket to appear", address) + time.Sleep(SocketConnectionTimeout) + } else { + return fmt.Errorf("failed opening the OVS socket with: %v", err) + } + } + return nil +} + +func validateOvsTcpConnection(address string) error { + conn, err := net.DialTimeout(TcpSocketType, address, SocketConnectionTimeout) + if err == nil { + glog.Info("Successfully connected to TCP socket") + conn.Close() + return nil + } + + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return fmt.Errorf("connection to %s timed out", address) + } else if opErr, ok := err.(*net.OpError); ok { + if opErr.Op == "dial" { + return fmt.Errorf("connection to %s failed: %v", address, err) + } else { + return fmt.Errorf("unexpected error when connecting to %s: %v", address, err) + } + } else { + return fmt.Errorf("unexpected error when connecting to %s: %v", address, err) + } +}