diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a28d5b7b..387e23b8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. * `sslcertfile` - path to an SSL certificate file, * `sslcafile` - path to a trusted certificate authorities (CA) file, * `sslciphers` - colon-separated list of SSL cipher suites the connection. +- `tt aeon connect`: add support to connect Aeon database. ### Changed diff --git a/cli/aeon/client.go b/cli/aeon/client.go new file mode 100644 index 000000000..ecac1f01c --- /dev/null +++ b/cli/aeon/client.go @@ -0,0 +1,199 @@ +package aeon + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "os" + "strings" + "time" + + "github.com/apex/log" + + "github.com/tarantool/go-prompt" + "github.com/tarantool/tt/cli/aeon/cmd" + "github.com/tarantool/tt/cli/aeon/pb" + "github.com/tarantool/tt/cli/connector" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" +) + +// Client structure with parameters for gRPC connection to Aeon. +type Client struct { + title string + conn *grpc.ClientConn + client pb.SQLServiceClient +} + +func makeAddress(ctx cmd.ConnectCtx) string { + if ctx.Network == connector.UnixNetwork { + if strings.HasPrefix(ctx.Address, "@") { + return "unix-abstract:" + (ctx.Address)[1:] + } + return "unix:" + ctx.Address + } + return ctx.Address +} + +func getCertificate(args cmd.Ssl) (tls.Certificate, error) { + if args.CertFile == "" && args.KeyFile == "" { + return tls.Certificate{}, nil + } + tls_cert, err := tls.LoadX509KeyPair(args.CertFile, args.KeyFile) + if err != nil { + return tls_cert, fmt.Errorf("could not load client key pair: %w", err) + } + return tls_cert, nil +} + +func getTlsConfig(args cmd.Ssl) (*tls.Config, error) { + if args.CaFile == "" { + return &tls.Config{ + ClientAuth: tls.NoClientCert, + }, nil + } + + ca, err := os.ReadFile(args.CaFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA file: %w", err) + } + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(ca) { + return nil, errors.New("failed to append CA data") + } + cert, err := getCertificate(args) + if err != nil { + return nil, fmt.Errorf("failed get certificate: %w", err) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientAuth: tls.RequireAndVerifyClientCert, + RootCAs: certPool, + }, nil +} + +func getDialOpts(ctx cmd.ConnectCtx) (grpc.DialOption, error) { + var creds credentials.TransportCredentials + if ctx.Transport == cmd.TransportSsl { + config, err := getTlsConfig(ctx.Ssl) + if err != nil { + return nil, fmt.Errorf("not tls config: %w", err) + } + creds = credentials.NewTLS(config) + } else { + creds = insecure.NewCredentials() + } + return grpc.WithTransportCredentials(creds), nil +} + +// NewAeonHandler create new grpc connection to Aeon server. +func NewAeonHandler(ctx cmd.ConnectCtx) (*Client, error) { + c := Client{title: ctx.Address} + target := makeAddress(ctx) + // var err error + opt, err := getDialOpts(ctx) + if err != nil { + return nil, fmt.Errorf("%w", err) + } + c.conn, err = grpc.NewClient(target, opt) + if err != nil { + return nil, fmt.Errorf("fail to dial: %w", err) + } + if err := c.ping(); err == nil { + log.Infof("Aeon responses at %q", target) + } else { + return nil, fmt.Errorf("can't ping to Aeon at %q: %w", target, err) + } + + c.client = pb.NewSQLServiceClient(c.conn) + return &c, nil +} + +func (c *Client) ping() error { + log.Infof("Start ping aeon server") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + diag := pb.NewDiagServiceClient(c.conn) + _, err := diag.Ping(ctx, &pb.PingRequest{}) + if err != nil { + log.Warnf("Aeon ping %s", err) + } + return err +} + +// Title implements console.Handler interface. +func (c *Client) Title() string { + return c.title +} + +// Validate implements console.Handler interface. +func (c *Client) Validate(input string) bool { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + check, err := c.client.SQLCheck(ctx, &pb.SQLRequest{Query: input}) + if err != nil { + log.Warnf("Aeon validate %s\nFor request: %q", err, input) + return false + } + + return check.Status == pb.SQLCheckStatus_SQL_QUERY_VALID +} + +// Execute implements console.Handler interface. +func (c *Client) Execute(input string) any { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + resp, err := c.client.SQL(ctx, &pb.SQLRequest{Query: input}) + if err != nil { + return err + } + return parseSQLResponse(resp) +} + +// Stop implements console.Handler interface. +func (c *Client) Close() { + c.conn.Close() +} + +// Complete implements console.Handler interface. +func (c *Client) Complete(input prompt.Document) []prompt.Suggest { + // TODO: waiting until there is support from Aeon side. + return nil +} + +// parseSQLResponse returns result as table in map. +// Where keys is name of columns. And body is array of values. +// On any issue return an error. +func parseSQLResponse(resp *pb.SQLResponse) any { + if resp.Error != nil { + return ResultError{resp.Error} + } + if resp.TupleFormat == nil { + return ResultType{} + } + res := ResultType{ + names: make([]string, len(resp.TupleFormat.Names)), + rows: make([]ResultRow, len(resp.Tuples)), + } + for i, n := range resp.TupleFormat.Names { + res.names[i] = n + res.rows[i] = make([]any, 0, len(resp.TupleFormat.Names)) + } + + for r, row := range resp.Tuples { + for _, v := range row.Fields { + val, err := decodeValue(v) + if err != nil { + return fmt.Errorf("tuple %d can't decode %s: %w", r, v.String(), err) + } + res.rows[r] = append(res.rows[r], val) + } + } + return res +} diff --git a/cli/aeon/cmd/connect.go b/cli/aeon/cmd/connect.go index cdb3b80a2..d84272ccd 100644 --- a/cli/aeon/cmd/connect.go +++ b/cli/aeon/cmd/connect.go @@ -16,4 +16,8 @@ type ConnectCtx struct { Ssl Ssl // Transport is a connection mode. Transport Transport + // Network is kind of transport layer. + Network string + // Address is a connection URL, unix socket address and etc. + Address string } diff --git a/cli/aeon/decode.go b/cli/aeon/decode.go new file mode 100644 index 000000000..240de1cf0 --- /dev/null +++ b/cli/aeon/decode.go @@ -0,0 +1,100 @@ +package aeon + +import ( + "fmt" + "time" + + "github.com/google/uuid" + "github.com/tarantool/go-tarantool/v2/datetime" + "github.com/tarantool/go-tarantool/v2/decimal" + "github.com/tarantool/tt/cli/aeon/pb" +) + +// decodeValue convert a value obtained from protobuf into a value that can be used as an +// argument to Tarantool functions. +// +// Copy from https://github.com/tarantool/aeon/blob/master/aeon/grpc/server/pb/decode.go +func decodeValue(val *pb.Value) (any, error) { + switch casted := val.Kind.(type) { + case *pb.Value_UnsignedValue: + return val.GetUnsignedValue(), nil + case *pb.Value_StringValue: + return val.GetStringValue(), nil + case *pb.Value_NumberValue: + return val.GetNumberValue(), nil + case *pb.Value_IntegerValue: + return val.GetIntegerValue(), nil + case *pb.Value_BooleanValue: + return val.GetBooleanValue(), nil + case *pb.Value_VarbinaryValue: + return val.GetVarbinaryValue(), nil + case *pb.Value_DecimalValue: + decStr := val.GetDecimalValue() + res, err := decimal.MakeDecimalFromString(decStr) + if err != nil { + return nil, err + } + return res, nil + case *pb.Value_UuidValue: + uuidStr := val.GetUuidValue() + res, err := uuid.Parse(uuidStr) + if err != nil { + return nil, err + } + return res, nil + case *pb.Value_DatetimeValue: + sec := casted.DatetimeValue.Seconds + nsec := casted.DatetimeValue.Nsec + t := time.Unix(sec, nsec) + if len(casted.DatetimeValue.Location) > 0 { + locStr := casted.DatetimeValue.Location + loc, err := time.LoadLocation(locStr) + if err != nil { + return nil, err + } + t = t.In(loc) + } + res, err := datetime.MakeDatetime(t) + if err != nil { + return nil, err + } + return res, nil + case *pb.Value_IntervalValue: + res := datetime.Interval{ + Year: casted.IntervalValue.Year, + Month: casted.IntervalValue.Month, + Week: casted.IntervalValue.Week, + Day: casted.IntervalValue.Day, + Hour: casted.IntervalValue.Hour, + Min: casted.IntervalValue.Min, + Sec: casted.IntervalValue.Sec, + Nsec: casted.IntervalValue.Nsec, + Adjust: datetime.Adjust(casted.IntervalValue.Adjust)} + return res, nil + case *pb.Value_ArrayValue: + array := val.GetArrayValue() + res := make([]any, len(array.Fields)) + for k, v := range array.Fields { + field, err := decodeValue(v) + if err != nil { + return nil, err + } + res[k] = field + } + return res, nil + case *pb.Value_MapValue: + res := make(map[any]any, len(casted.MapValue.Fields)) + for k, v := range casted.MapValue.Fields { + item, err := decodeValue(v) + if err != nil { + return nil, err + } + res[k] = item + } + return res, nil + case *pb.Value_NullValue: + return nil, nil + default: + return nil, fmt.Errorf("unsupported type for value") + } +} diff --git a/cli/aeon/results.go b/cli/aeon/results.go new file mode 100644 index 000000000..49a5ba6c5 --- /dev/null +++ b/cli/aeon/results.go @@ -0,0 +1,53 @@ +package aeon + +import ( + "fmt" + + "github.com/tarantool/tt/cli/aeon/pb" + "github.com/tarantool/tt/cli/console" + "github.com/tarantool/tt/cli/formatter" +) + +// ResultRow keeps values for one table row. +type ResultRow []any + +// ResultType is a custom type to format output with console.Formatter interface. +type ResultType struct { + names []string + rows []ResultRow +} + +// ResultError wraps pb.Error to implement console.Formatter interface. +type ResultError struct { + *pb.Error +} + +// asYaml prepare results for formatter.MakeOutput. +func (r ResultType) asYaml() string { + yaml := "---\n" + for _, row := range r.rows { + mark := "-" + for i, v := range row { + n := r.names[i] + yaml += fmt.Sprintf("%s %s: %v\n", mark, n, v) + mark = " " + } + } + return yaml +} + +// Format produce formatted string according required console.Format settings. +func (r ResultType) Format(f console.Format) string { + if len(r.names) == 0 { + return "" + } + output, err := formatter.MakeOutput(f.Mode, r.asYaml(), f.Opts) + if err != nil { + return fmt.Sprintf("can't format output: %s;\nResults:\n%v", err, r) + } + return output +} + +func (e *ResultError) Format(_ console.Format) string { + return fmt.Sprintf("---\nError: %s\n%q", e.Name, e.Msg) +} diff --git a/cli/aeon/results_export_test.go b/cli/aeon/results_export_test.go new file mode 100644 index 000000000..a0c93f60a --- /dev/null +++ b/cli/aeon/results_export_test.go @@ -0,0 +1,16 @@ +package aeon + +import "github.com/tarantool/tt/cli/aeon/pb" + +func NewResultType(names []string, rows []ResultRow) ResultType { + return ResultType{ + names: names, + rows: rows, + } +} + +func NewResultError(name string, msg string) ResultError { + return ResultError{&pb.Error{ + Name: name, + Msg: msg}} +} diff --git a/cli/aeon/results_test.go b/cli/aeon/results_test.go new file mode 100644 index 000000000..8e0c009c2 --- /dev/null +++ b/cli/aeon/results_test.go @@ -0,0 +1,125 @@ +package aeon_test + +import ( + "testing" + + "github.com/tarantool/tt/cli/aeon" + "github.com/tarantool/tt/cli/console" + "github.com/tarantool/tt/cli/formatter" +) + +func TestResultType_Format(t *testing.T) { + type result struct { + names []string + rows []aeon.ResultRow + } + + tests := []struct { + name string + data result + f console.Format + want string + }{ + { + name: "Table with string values", + data: result{ + names: []string{"field1", "field2"}, + rows: []aeon.ResultRow{ + {"value11", "value12"}, + {"value21", "value22"}, + }, + }, + f: console.FormatAsTable(), + want: `+---------+---------+ +| field1 | field2 | ++---------+---------+ +| value11 | value12 | ++---------+---------+ +| value21 | value22 | ++---------+---------+ +`, + }, + { + name: "Table with no string values", + data: result{ + names: []string{"field1", "field2"}, + rows: []aeon.ResultRow{ + {[]bool{true, false}, 123}, + {nil, 456.78}, + }, + }, + f: console.FormatAsTable(), + want: `+----------------+--------+ +| field1 | field2 | ++----------------+--------+ +| ["true false"] | 123 | ++----------------+--------+ +| | 456.78 | ++----------------+--------+ +`, + }, + { + name: "Format as Yaml", + data: result{ + names: []string{"field1", "field2"}, + rows: []aeon.ResultRow{ + {"value11", "value12"}, + {true, 3.14}, + }, + }, + f: console.Format{ + Mode: formatter.YamlFormat, + }, + want: `--- +- field1: value11 + field2: value12 +- field1: true + field2: 3.14 + +`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := aeon.NewResultType(tt.data.names, tt.data.rows) + if got := r.Format(tt.f); got != tt.want { + t.Errorf("ResultType.Format() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestResultError_Format(t *testing.T) { + type fields struct { + name string + msg string + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "Name and Message", + fields: fields{"Name of error", "Long error message string."}, + want: `--- +Error: Name of error +"Long error message string."`, + }, + { + name: "No message", + fields: fields{"Name of error", ""}, + want: `--- +Error: Name of error +""`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := aeon.NewResultError(tt.fields.name, tt.fields.msg) + if got := e.Format(console.Format{}); got != tt.want { + t.Errorf("ResultError.Format() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cli/cmd/aeon.go b/cli/cmd/aeon.go index 0e41f0d4d..31f9df5e9 100644 --- a/cli/cmd/aeon.go +++ b/cli/cmd/aeon.go @@ -5,24 +5,26 @@ import ( "fmt" "github.com/spf13/cobra" - aeon "github.com/tarantool/tt/cli/aeon/cmd" + aeon "github.com/tarantool/tt/cli/aeon" + aeoncmd "github.com/tarantool/tt/cli/aeon/cmd" "github.com/tarantool/tt/cli/cmdcontext" + "github.com/tarantool/tt/cli/console" "github.com/tarantool/tt/cli/modules" "github.com/tarantool/tt/cli/util" libconnect "github.com/tarantool/tt/lib/connect" ) -var aeonConnectCtx = aeon.ConnectCtx{ - Transport: aeon.TransportPlain, +var connectCtx = aeoncmd.ConnectCtx{ + Transport: aeoncmd.TransportPlain, } func newAeonConnectCmd() *cobra.Command { var aeonCmd = &cobra.Command{ Use: "connect URI", Short: "Connect to the aeon instance", - Long: "Connect to the aeon instance.\n\n" + - libconnect.EnvCredentialsHelp + "\n\n" + - `tt aeon connect user:pass@localhost:3013`, + Long: `Connect to the aeon instance. +tt aeon connect localhost:3013 +tt aeon connect unix://`, PreRunE: func(cmd *cobra.Command, args []string) error { err := aeonConnectValidateArgs(cmd, args) util.HandleCmdErr(cmd, err) @@ -34,16 +36,17 @@ func newAeonConnectCmd() *cobra.Command { internalAeonConnect, args) util.HandleCmdErr(cmd, err) }, + Args: cobra.ExactArgs(1), } - aeonCmd.Flags().StringVar(&aeonConnectCtx.Ssl.KeyFile, "sslkeyfile", "", + aeonCmd.Flags().StringVar(&connectCtx.Ssl.KeyFile, "sslkeyfile", "", "path to a private SSL key file") - aeonCmd.Flags().StringVar(&aeonConnectCtx.Ssl.CertFile, "sslcertfile", "", + aeonCmd.Flags().StringVar(&connectCtx.Ssl.CertFile, "sslcertfile", "", "path to a SSL certificate file") - aeonCmd.Flags().StringVar(&aeonConnectCtx.Ssl.CaFile, "sslcafile", "", + aeonCmd.Flags().StringVar(&connectCtx.Ssl.CaFile, "sslcafile", "", "path to a trusted certificate authorities (CA) file") - aeonCmd.Flags().Var(&aeonConnectCtx.Transport, "transport", - fmt.Sprintf("allowed %s", aeon.ListValidTransports())) + aeonCmd.Flags().Var(&connectCtx.Transport, "transport", + fmt.Sprintf("allowed %s", aeoncmd.ListValidTransports())) aeonCmd.RegisterFlagCompletionFunc("transport", aeonTransportCompletion) return aeonCmd @@ -51,8 +54,8 @@ func newAeonConnectCmd() *cobra.Command { func aeonTransportCompletion(cmd *cobra.Command, args []string, toComplete string) ( []string, cobra.ShellCompDirective) { - suggest := make([]string, 0, len(aeon.ValidTransport)) - for k, v := range aeon.ValidTransport { + suggest := make([]string, 0, len(aeoncmd.ValidTransport)) + for k, v := range aeoncmd.ValidTransport { suggest = append(suggest, string(k)+"\t"+v) } return suggest, cobra.ShellCompDirectiveDefault @@ -71,36 +74,59 @@ func NewAeonCmd() *cobra.Command { } func aeonConnectValidateArgs(cmd *cobra.Command, args []string) error { - if !cmd.Flags().Changed("transport") && (aeonConnectCtx.Ssl.KeyFile != "" || - aeonConnectCtx.Ssl.CertFile != "" || aeonConnectCtx.Ssl.CaFile != "") { - aeonConnectCtx.Transport = aeon.TransportSsl + connectCtx.Network, connectCtx.Address = libconnect.ParseBaseURI(args[0]) + + if !cmd.Flags().Changed("transport") && (connectCtx.Ssl.KeyFile != "" || + connectCtx.Ssl.CertFile != "" || connectCtx.Ssl.CaFile != "") { + connectCtx.Transport = aeoncmd.TransportSsl } checkFile := func(path string) bool { return path == "" || util.IsRegularFile(path) } - if aeonConnectCtx.Transport != aeon.TransportPlain { + if connectCtx.Transport != aeoncmd.TransportPlain { if cmd.Flags().Changed("sslkeyfile") != cmd.Flags().Changed("sslcertfile") { return errors.New("files Key and Cert must be specified both") } - if !checkFile(aeonConnectCtx.Ssl.KeyFile) { + if !checkFile(connectCtx.Ssl.KeyFile) { return fmt.Errorf("not valid path to a private SSL key file=%q", - aeonConnectCtx.Ssl.KeyFile) + connectCtx.Ssl.KeyFile) } - if !checkFile(aeonConnectCtx.Ssl.CertFile) { + if !checkFile(connectCtx.Ssl.CertFile) { return fmt.Errorf("not valid path to an SSL certificate file=%q", - aeonConnectCtx.Ssl.CertFile) + connectCtx.Ssl.CertFile) } - if !checkFile(aeonConnectCtx.Ssl.CaFile) { + if !checkFile(connectCtx.Ssl.CaFile) { return fmt.Errorf("not valid path to trusted certificate authorities (CA) file=%q", - aeonConnectCtx.Ssl.CaFile) + connectCtx.Ssl.CaFile) } } return nil } func internalAeonConnect(cmdCtx *cmdcontext.CmdCtx, args []string) error { + hist, err := console.DefaultHistoryFile() + if err != nil { + return fmt.Errorf("can't open history file: %w", err) + } + handler, err := aeon.NewAeonHandler(connectCtx) + if err != nil { + return err + } + opts := console.ConsoleOpts{ + Handler: handler, + History: &hist, + Format: console.FormatAsTable(), + } + c, err := console.NewConsole(opts) + if err != nil { + return fmt.Errorf("can't create aeon console: %w", err) + } + err = c.Run() + if err != nil { + return fmt.Errorf("can't start aeon console: %w", err) + } return nil } diff --git a/cli/console/console.go b/cli/console/console.go new file mode 100644 index 000000000..cf3e9fb4f --- /dev/null +++ b/cli/console/console.go @@ -0,0 +1,263 @@ +package console + +import ( + "bufio" + "errors" + "fmt" + "os" + "strings" + "syscall" + "unicode" + + "github.com/apex/log" + "golang.org/x/term" + + "github.com/tarantool/go-prompt" +) + +const ( + maxLivePrefixIndent = 15 + // See https://github.com/tarantool/tarantool/blob/b53cb2aeceedc39f356ceca30bd0087ee8de7c16/src/box/lua/console.c#L265 + tarantoolWordSeparators = "\t\r\n !\"#$%&'()*+,-/;<=>?@[\\]^`{|}~" +) + +var ( + controlLeftBytes = []byte{0x1b, 0x62} + controlRightBytes = []byte{0x1b, 0x66} +) + +// ConsoleOpts collection console options to create new console. +type ConsoleOpts struct { + // Handler is the implementation of command processor. + Handler Handler + // History if specified than save input commands with it. + History HistoryKeeper + // Format options set how to formatting result. + Format Format +} + +// Console implementation of active console handler. +type Console struct { + impl ConsoleOpts + internal Handler // internal Handler execute console's additional backslash commands. + input string + quit bool + prefix string + livePrefixEnabled bool + livePrefix string + delimiter string + prompt *prompt.Prompt +} + +// NewConsole creates a new console connected to the tarantool instance. +func NewConsole(opts ConsoleOpts) (Console, error) { + if opts.Handler == nil { + return Console{quit: true}, errors.New("no handler for commands has been set") + } + c := Console{ + impl: opts, + quit: false, + } + c.setPrefix() + return c, nil +} + +func (c *Console) runOnPipe() error { + pipe := bufio.NewScanner(os.Stdin) + log.Infof("Processing piped input") + for pipe.Scan() { + line := pipe.Text() + c.execute(line) + } + + err := pipe.Err() + if err == nil { + log.Info("EOF on pipe") + } else { + log.Warnf("Error on pipe %v", err) + } + return err +} + +// Run starts console. +func (c *Console) Run() error { + if c.quit { + return errors.New("can't run on stopped console") + } + if !term.IsTerminal(syscall.Stdin) { + return c.runOnPipe() + } + + log.Infof("Connected to %s\n", c.title()) + c.prompt = prompt.New( + c.execute, + c.complete, + c.getPromptOptions()..., + ) + c.prompt.Run() + + return nil +} + +// Close frees up resources used by the console. +func (c *Console) Close() { + c.impl.Handler.Close() + if c.impl.History != nil { + c.impl.History.Close() + } +} + +// executeEmbeddedCommand try process additional backslash commands. +func (c *Console) executeEmbeddedCommand(in string) bool { + if c.input == "" && c.internal != nil { + if c.internal.Execute(in) != nil { + if c.quit { + c.Close() + log.Infof("Quit from the console") + os.Exit(0) + } + return true + } + } + return false +} + +// cleanupDelimiter checks if the input statement ends with the string `c.delimiter`. +// If yes, it removes it. Returns true if the delimiter has been removed. +func (c *Console) cleanupDelimiter() bool { + if c.delimiter == "" { + return true + } + no_space := strings.TrimRightFunc(c.input, func(r rune) bool { + return unicode.IsSpace(r) + }) + no_delim := strings.TrimSuffix(no_space, c.delimiter) + if len(no_space) > len(no_delim) { + c.input = no_delim + return true + } + return false +} + +// addStmt adds a new part of the statement. +// It returns true if the statement is already completed. +func (c *Console) addStmt(part string) bool { + if c.input == "" { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + c.input = part + } + } else { + c.input += "\n" + part + } + + has_delim := c.cleanupDelimiter() + c.livePrefixEnabled = !(has_delim && c.impl.Handler.Validate(c.input)) + return !c.livePrefixEnabled +} + +// execute called from prompt to process input. +func (c *Console) execute(in string) { + if c.executeEmbeddedCommand(in) || !c.addStmt(in) { + return + } + + trimmed := strings.TrimSpace(c.input) + if c.impl.History != nil { + c.impl.History.AppendCommand(trimmed) + } + + if c.prompt != nil { + if err := c.prompt.PushToHistory(trimmed); err != nil { + log.Debug(err.Error()) + } + } + + results := c.impl.Handler.Execute(c.input) + if results == nil { + c.Close() + log.Infof("Connection closed") + os.Exit(0) + } + if err := c.impl.Format.print(results); err != nil { + log.Errorf("Unable to format output: %s", err) + log.Infof("Source results:\n%v", results) + } + + c.input = "" + c.livePrefixEnabled = false +} + +// title return console's title. +func (c *Console) title() string { + return c.impl.Handler.Title() +} + +// complete provide prompt suggestions. +func (c *Console) complete(input prompt.Document) []prompt.Suggest { + if c.input == "" && c.internal != nil { + return c.internal.Complete(input) + } + return c.impl.Handler.Complete(input) +} + +// setPrefix adjust console prefix string. +func (c *Console) setPrefix() { + c.prefix = fmt.Sprintf("%s> ", c.title()) + + livePrefixIndent := len(c.title()) + if livePrefixIndent > maxLivePrefixIndent { + livePrefixIndent = maxLivePrefixIndent + } + + c.livePrefix = fmt.Sprintf("%s> ", strings.Repeat(" ", livePrefixIndent)) +} + +// getPromptOptions prepare option for prompt. +func (c *Console) getPromptOptions() []prompt.Option { + options := []prompt.Option{ + prompt.OptionTitle(c.title()), + prompt.OptionPrefix(c.prefix), + prompt.OptionLivePrefix(func() (string, bool) { + return c.livePrefix, c.livePrefixEnabled + }), + + prompt.OptionSuggestionBGColor(prompt.DarkGray), + prompt.OptionPreviewSuggestionTextColor(prompt.DefaultColor), + + prompt.OptionCompletionWordSeparator(tarantoolWordSeparators), + + prompt.OptionAddASCIICodeBind( + // Move to one word left. + prompt.ASCIICodeBind{ + ASCIICode: controlLeftBytes, + Fn: prompt.GoLeftWord, + }, + // Move to one word right. + prompt.ASCIICodeBind{ + ASCIICode: controlRightBytes, + Fn: prompt.GoRightWord, + }, + ), + // Interrupt current unfinished expression. + prompt.OptionAddKeyBind( + prompt.KeyBind{ + Key: prompt.ControlC, + Fn: func(buf *prompt.Buffer) { + c.input = "" + c.livePrefixEnabled = false + fmt.Println("^C") + }, + }, + ), + + prompt.OptionDisableAutoHistory(), + prompt.OptionReverseSearch(), + } + + if c.impl.History != nil { + options = append(options, prompt.OptionHistory(c.impl.History.Commands())) + } + + return options +} diff --git a/cli/console/format.go b/cli/console/format.go new file mode 100644 index 000000000..d5d76caec --- /dev/null +++ b/cli/console/format.go @@ -0,0 +1,47 @@ +package console + +import ( + "fmt" + + "github.com/tarantool/tt/cli/formatter" +) + +// Format aggregate formatter options. +type Format struct { + // Mode specify how to formatting result. + Mode formatter.Format + // Opts options for Format. + Opts formatter.Opts +} + +// FormatAsTable return Format options for formatting outputs as table. +func FormatAsTable() Format { + return Format{ + Mode: formatter.TableFormat, + Opts: formatter.Opts{ + Graphics: true, + ColumnWidthMax: 0, + TableDialect: formatter.DefaultTableDialect, + }, + } +} + +func (f Format) print(data any) error { + fmt.Println("---") + if fo, ok := data.(Formatter); ok { + // First ensure that data object implemented Formatter interface. + fmt.Println(fo.Format(f)) + + } else if so, ok := data.(fmt.Stringer); ok { + // Then checking is it has String method. + fmt.Println(so.String()) + + } else if eo, ok := data.(error); ok { + // Then checking is it has Error method. + fmt.Println("Error:\n", eo.Error()) + + } else { + return fmt.Errorf("can't format type=%T", data) + } + return nil +} diff --git a/cli/console/formatter.go b/cli/console/formatter.go new file mode 100644 index 000000000..c536edb46 --- /dev/null +++ b/cli/console/formatter.go @@ -0,0 +1,7 @@ +package console + +// Formatter interface provide common interface for console Handlers to format execution results. +type Formatter interface { + // Format result data according fmt settings and return string for printing. + Format(fmt Format) string +} diff --git a/cli/console/handler.go b/cli/console/handler.go new file mode 100644 index 000000000..b14315e8d --- /dev/null +++ b/cli/console/handler.go @@ -0,0 +1,25 @@ +package console + +import "github.com/tarantool/go-prompt" + +// Handler is a auxiliary abstraction to isolate the console from +// the implementation of a particular instruction processor. +type Handler interface { + // Title return name of instruction processor instance. + Title() string + // Validate the input string. + Validate(input string) bool + // Complete checks the input and return available variants to continue typing. + Complete(input prompt.Document) []prompt.Suggest + // Execute accept input to perform actions defined by client implementation. + // The type of the resulting object can be anything, and no special processing is expected. + // It must provide one of the following interfaces: + // - Formatter (for the best case). + // - Stringer + // - error + // Otherwise, when displaying the object in the console, a message that the object + // cannot be rendered correctly will be displayed. + Execute(input string) any + // Close notify handler to terminate execution and close any opened streams. + Close() +} diff --git a/cli/console/history.go b/cli/console/history.go new file mode 100644 index 000000000..73174d52d --- /dev/null +++ b/cli/console/history.go @@ -0,0 +1,137 @@ +package console + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/tarantool/tt/cli/util" +) + +const ( + DefaultHistoryFileName = ".tarantool_history" + DefaultHistoryLines = 10000 +) + +// History implementation of active history handler. +type History struct { + filepath string + maxCommands int + commands []string + timestamps []int64 +} + +// NewHistory create/open specified file. +func NewHistory(file string, maxCommands int) (History, error) { + h := History{ + filepath: file, + maxCommands: maxCommands, + commands: make([]string, 0), + timestamps: make([]int64, 0), + } + err := h.load() + return h, err +} + +// DefaultHistoryFile create/open history file with default parameters. +func DefaultHistoryFile() (History, error) { + dir, err := util.GetHomeDir() + if err != nil { + return History{}, fmt.Errorf("failed to get home directory: %w", err) + } + file := filepath.Join(dir, DefaultHistoryFileName) + return NewHistory(file, DefaultHistoryLines) +} + +func (h *History) load() error { + if !util.IsRegularFile(h.filepath) { + return nil + } + rawLines, err := util.GetLastNLines(h.filepath, h.maxCommands) + if err != nil { + return err + } + + h.parseCells(rawLines) + return nil +} + +func (h *History) parseCells(lines []string) { + timeRecord := regexp.MustCompile(`^#\d+$`) + + // startPos is the first position of a timestamp. + startPos := -1 + for i, line := range lines { + if timeRecord.MatchString(line) { + startPos = i + break + } + } + if startPos == -1 { + // Read one line per command. + // Set the current timestamp for each command. + h.commands = lines + now := time.Now().Unix() + for range lines { + h.timestamps = append(h.timestamps, now) + } + return + } + + for startPos < len(lines) { + j := startPos + 1 + + // Move pointer to the next timestamp. + for j < len(lines) && !timeRecord.MatchString(lines[j]) { + j++ + } + + // Extract the current timestamp. + timestamp, err := strconv.ParseInt(lines[startPos][1:], 10, 0) + + if j != startPos+1 && err == nil { + h.timestamps = append(h.timestamps, timestamp) + h.commands = append(h.commands, strings.Join(lines[startPos+1:j], "\n")) + } + startPos = j + } +} + +// writeToFile writes console history to the file. +func (h *History) writeToFile() error { + buff := bytes.Buffer{} + for i, c := range h.commands { + buff.WriteString(fmt.Sprintf("#%d\n%s\n", h.timestamps[i], c)) + } + if err := os.WriteFile(h.filepath, buff.Bytes(), 0640); err != nil { + return fmt.Errorf("failed to write to history file: %s", err) + } + + return nil +} + +// AppendCommand insert new command to the history file. +// Implements HistoryKeeper.AppendCommand interface method. +func (h *History) AppendCommand(input string) { + h.commands = append(h.commands, input) + h.timestamps = append(h.timestamps, time.Now().Unix()) + if len(h.commands) > h.maxCommands { + h.commands = h.commands[1:] + h.timestamps = h.timestamps[1:] + } + h.writeToFile() +} + +// Commands implements HistoryKeeper.Commands interface method. +func (h *History) Commands() []string { + return h.commands +} + +// Close implements HistoryKeeper.Close interface method. +func (h *History) Close() { +} diff --git a/cli/console/history_keeper.go b/cli/console/history_keeper.go new file mode 100644 index 000000000..e077b61b5 --- /dev/null +++ b/cli/console/history_keeper.go @@ -0,0 +1,11 @@ +package console + +// HistoryKeeper introduce methods to keep command history in some external place. +type HistoryKeeper interface { + // AppendCommand add new entered command to storage. + AppendCommand(input string) + // Commands return list of saved commands. + Commands() []string + // Close method notifies the repository that there will be no new commands. + Close() +} diff --git a/cli/console/history_test.go b/cli/console/history_test.go new file mode 100644 index 000000000..f0147adc3 --- /dev/null +++ b/cli/console/history_test.go @@ -0,0 +1,106 @@ +package console_test + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tarantool/tt/cli/console" +) + +func TestNewHistory(t *testing.T) { + type args struct { + file string + maxCommands int + } + tests := []struct { + name string + args args + want []string + wantErr bool + }{ + { + name: "Empty file", + args: args{"testdata/history0.info", 10000}, + want: []string{}, + wantErr: false, + }, + { + name: "Not empty file", + args: args{"testdata/history1.info", 10000}, + want: []string{ + "box.cfg{}", + "box.schema.space.create(\"test\")", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hist, err := console.NewHistory(tt.args.file, tt.args.maxCommands) + if (err != nil) != tt.wantErr { + t.Errorf("NewHistory() error = %v, wantErr %v", err, tt.wantErr) + return + } + cmds := hist.Commands() + if !reflect.DeepEqual(cmds, tt.want) { + fmt.Print(cmds) + t.Errorf("NewHistory() = %v, want %v", cmds, tt.want) + } + }) + } +} + +func TestHistory_AppendCommand(t *testing.T) { + tests := []struct { + name string + max int + commands []string + }{ + { + "test 10", + 3, + []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "0"}, + }, + { + "test 3", + 3, + []string{"1", "2", "3"}, + }, + { + "test 1", + 3, + []string{"1"}, + }, + } + tmp, _ := os.MkdirTemp(os.TempDir(), "history_test*") + + // Write and ensure last command in buffer. + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + name := filepath.Join(tmp, fmt.Sprintf("history_%d.info", i)) + h, err := console.NewHistory(name, tt.max) + require.NoError(t, err) + for _, c := range tt.commands { + h.AppendCommand(c) + } + from := max(len(tt.commands)-tt.max, 0) + reflect.DeepEqual(tt.commands[from:], h.Commands()) + }) + } + + // Read previously created history data. + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + name := filepath.Join(tmp, fmt.Sprintf("history%d.info", i)) + h, err := console.NewHistory(name, tt.max) + require.NoError(t, err) + from := max(len(tt.commands)-tt.max, 0) + reflect.DeepEqual(tt.commands[from:], h.Commands()) + }) + } + os.RemoveAll(tmp) +} diff --git a/cli/console/testdata/history0.info b/cli/console/testdata/history0.info new file mode 100644 index 000000000..e69de29bb diff --git a/cli/console/testdata/history1.info b/cli/console/testdata/history1.info new file mode 100644 index 000000000..3e3eb774a --- /dev/null +++ b/cli/console/testdata/history1.info @@ -0,0 +1,4 @@ +#1724939703 +box.cfg{} +#1724939757 +box.schema.space.create("test") diff --git a/go.mod b/go.mod index 011c725ed..c7d80a89d 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/tarantool/cartridge-cli v0.0.0-20220605082730-53e6a5be9a61 github.com/tarantool/go-prompt v1.0.1 github.com/tarantool/go-tarantool v1.12.2 + github.com/tarantool/go-tarantool/v2 v2.2.0 github.com/tarantool/tt/lib/cluster v0.0.0 github.com/tarantool/tt/lib/integrity v0.0.0 github.com/vmihailenco/msgpack/v5 v5.3.5 diff --git a/go.sum b/go.sum index c8d7cfa8f..331099dfc 100644 --- a/go.sum +++ b/go.sum @@ -427,6 +427,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.4/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tarantool/go-iproto v1.1.0 h1:HULVOIHsiehI+FnHfM7wMDntuzUddO09DKqu2WnFQ5A= +github.com/tarantool/go-iproto v1.1.0/go.mod h1:LNCtdyZxojUed8SbOiYHoc3v9NvaZTB7p96hUySMlIo= github.com/tarantool/go-openssl v0.0.8-0.20230307065445-720eeb389195/go.mod h1:M7H4xYSbzqpW/ZRBMyH0eyqQBsnhAMfsYk5mv0yid7A= github.com/tarantool/go-openssl v1.1.1 h1:qOCSjUXRLxlnh0e2G6sH50B3d/gYpscbY/opFqsIfaE= github.com/tarantool/go-openssl v1.1.1/go.mod h1:M7H4xYSbzqpW/ZRBMyH0eyqQBsnhAMfsYk5mv0yid7A= @@ -436,6 +438,8 @@ github.com/tarantool/go-prompt v1.0.1 h1:88Yer6gCFylqGRrdWwikNFVbklRQsqKF7mycvGd github.com/tarantool/go-prompt v1.0.1/go.mod h1:9Vuvi60Bk+3yaXqgYaXNTpLbwPPaaEOeaUgpFW1jqTU= github.com/tarantool/go-tarantool v1.12.2 h1:u4g+gTOHNxbUDJv0EIUFkRurU/lTQSzWrz8o7bHVAqI= github.com/tarantool/go-tarantool v1.12.2/go.mod h1:QRiXv0jnxwgxHtr9ZmifSr/eRba76gTUBgp69pDMX1U= +github.com/tarantool/go-tarantool/v2 v2.2.0 h1:U7RDvWxPaPPecMppqVwfpTGnSJQ++Crg2l9cS/ztgp8= +github.com/tarantool/go-tarantool/v2 v2.2.0/go.mod h1:hKKeZeCP8Y8+U6ZFS32ot1jHV/n4WKVP4fjRAvQznMY= github.com/tj/assert v0.0.0-20171129193455-018094318fb0/go.mod h1:mZ9/Rh9oLWpLLDRpvE+3b7gP/C2YyLFYxNmcLnPTMe0= github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk= github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk=