Skip to content

Commit

Permalink
console: integrate aeon client with console
Browse files Browse the repository at this point in the history
Part of #1050
  • Loading branch information
dmyger committed Dec 23, 2024
1 parent 54ac3ae commit 4151819
Show file tree
Hide file tree
Showing 11 changed files with 410 additions and 58 deletions.
211 changes: 211 additions & 0 deletions cli/aeon/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
package aeon

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"strings"
"time"

"github.com/apex/log"

"github.com/tarantool/go-prompt"
"github.com/tarantool/tt/cli/aeon/cmd"
"github.com/tarantool/tt/cli/aeon/pb"
"github.com/tarantool/tt/cli/connector"
"github.com/tarantool/tt/cli/console"
"github.com/tarantool/tt/cli/formatter"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)

type ResultType struct {
data map[string][]any
count int
}

type Client struct {
title string
conn *grpc.ClientConn
client pb.AeonRouterServiceClient
}

func makeAddress(ctx cmd.ConnectCtx) string {
if ctx.Network == connector.UnixNetwork {
if strings.HasPrefix(ctx.Address, "@") {
return "unix-abstract:" + (ctx.Address)[1:]
}
return "unix:" + ctx.Address
}
return ctx.Address
}

func getCertificate(args cmd.Ssl) []tls.Certificate {
if args.CertFile == "" && args.KeyFile == "" {
return []tls.Certificate{}
}
tls_cert, err := tls.LoadX509KeyPair(args.CertFile, args.KeyFile)
if err != nil {
log.Fatalf("Could not load client key pair: %v", err)
}
return []tls.Certificate{tls_cert}
}

func getTlsConfig(args cmd.Ssl) *tls.Config {
if args.CaFile == "" {
return &tls.Config{
ClientAuth: tls.NoClientCert,
}
}

ca, err := os.ReadFile(args.CaFile)
if err != nil {
log.Fatalf("Failed to read CA file: %v", err)
}
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(ca) {
log.Fatal("Failed to append CA data")
}
return &tls.Config{
Certificates: getCertificate(args),
ClientAuth: tls.RequireAndVerifyClientCert,
RootCAs: certPool,
}
}

func getDialOpts(ctx cmd.ConnectCtx) grpc.DialOption {
var creds credentials.TransportCredentials
if ctx.Transport == cmd.TransportSsl {
creds = credentials.NewTLS(getTlsConfig(ctx.Ssl))
} else {
creds = insecure.NewCredentials()
}
return grpc.WithTransportCredentials(creds)
}

// NewAeonHandler create new grpc connection to Aeon server.
func NewAeonHandler(ctx cmd.ConnectCtx) *Client {
c := Client{title: ctx.Address}
target := makeAddress(ctx)
var err error
c.conn, err = grpc.NewClient(target, getDialOpts(ctx))
if err != nil {
log.Fatalf("Fail to dial: %v", err)
}
c.client = pb.NewAeonRouterServiceClient(c.conn)

if c.ping() {
log.Infof("Aeon responses at %q", target)
} else {
log.Fatalf("Can't ping to Aeon at %q", target)
}
return &c
}

func (c *Client) ping() bool {
log.Infof("Start ping aeon server")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

_, err := c.client.Ping(ctx, &pb.PingRequest{})
return err == nil
}

// Title implements console.Handler interface.
func (c *Client) Title() string {
return c.title
}

// Validate implements console.Handler interface.
func (c *Client) Validate(input string) bool {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

check, err := c.client.SQLCheck(ctx, &pb.SQLRequest{Query: input})
if err != nil {
return false
}

return check.Status == pb.SQLCheckStatus_SQL_QUERY_VALID
}

// Execute implements console.Handler interface.
func (c *Client) Execute(input string) any {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

resp, err := c.client.SQL(ctx, &pb.SQLRequest{Query: input})
if err != nil {
return err
}
return parseSQLResponse(resp)

}

// Stop implements console.Handler interface.
func (c *Client) Close() {
c.conn.Close()
}

// Complete implements console.Handler interface.
func (c *Client) Complete(input prompt.Document) []prompt.Suggest {
// TODO: waiting until there is support from Aeon side.
return nil
}

// parseSQLResponse returns result as table in map.
// Where keys is name of columns. And body is array of values.
// On any issue return an error.
func parseSQLResponse(resp *pb.SQLResponse) any {
if resp.Error != nil {
return fmt.Errorf("something wrong with SQL request: %s", resp.Error)
}
res := ResultType{
data: make(map[string][]any, len(resp.TupleFormat.Names)),
count: len(resp.Tuples),
}
// result := make(ResultType, len(resp.TupleFormat.Names))
rows := len(resp.Tuples)
for _, f := range resp.TupleFormat.Names {
res.data[f] = make([]any, 0, rows)
}

for r, row := range resp.Tuples {
for i, v := range row.Fields {
k := resp.TupleFormat.Names[i]
val, err := decodeValue(v)
if err != nil {
return fmt.Errorf("tuple %d can't decode %s: %w", r, v.String(), err)
}
res.data[k] = append(res.data[k], val)
}
}
return res
}

// asYaml prepare results for formatter.MakeOutput.
func (r ResultType) asYaml() string {
yaml := "---\n"
for i := range r.count {
mark := "-"
for k, v := range r.data {
if i < len(v) {
yaml += fmt.Sprintf("%s %s: %v\n", mark, k, v[i])
mark = " "
}
}
}
return yaml
}

// Format produce formatted string according required console.Format settings.
func (r ResultType) Format(f console.Format) string {
output, err := formatter.MakeOutput(f.Mode, r.asYaml(), f.Opts)
if err != nil {
return fmt.Sprintf("can't format output: %s;\nResults:\n%v", err, r)
}
return output
}
4 changes: 4 additions & 0 deletions cli/aeon/cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,8 @@ type ConnectCtx struct {
Ssl Ssl
// Transport is a connection mode.
Transport Transport
// Network is kind of transport layer.
Network string
// Address is a connection Url, unix socket address and etc.
Address string
}
102 changes: 102 additions & 0 deletions cli/aeon/decode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package aeon

import (
"fmt"
"time"

"github.com/google/uuid"
"github.com/tarantool/go-tarantool/v2/datetime"
"github.com/tarantool/go-tarantool/v2/decimal"
"github.com/tarantool/tt/cli/aeon/pb"
)

/*
decodeValue convert a value obtained from protobuf into a value that can be used as an
argument to Tarantool functions.
Copy from https://github.com/tarantool/aeon/blob/master/aeon/grpc/server/pb/decode.go
*/
func decodeValue(val *pb.Value) (any, error) {
switch casted := val.Kind.(type) {
case *pb.Value_UnsignedValue:
return val.GetUnsignedValue(), nil
case *pb.Value_StringValue:
return val.GetStringValue(), nil
case *pb.Value_NumberValue:
return val.GetNumberValue(), nil
case *pb.Value_IntegerValue:
return val.GetIntegerValue(), nil
case *pb.Value_BooleanValue:
return val.GetBooleanValue(), nil
case *pb.Value_VarbinaryValue:
return val.GetVarbinaryValue(), nil
case *pb.Value_DecimalValue:
decStr := val.GetDecimalValue()
res, err := decimal.MakeDecimalFromString(decStr)
if err != nil {
return nil, err
}
return res, nil
case *pb.Value_UuidValue:
uuidStr := val.GetUuidValue()
res, err := uuid.Parse(uuidStr)
if err != nil {
return nil, err
}
return res, nil
case *pb.Value_DatetimeValue:
sec := casted.DatetimeValue.Seconds
nsec := casted.DatetimeValue.Nsec
t := time.Unix(sec, nsec)
if len(casted.DatetimeValue.Location) > 0 {
locStr := casted.DatetimeValue.Location
loc, err := time.LoadLocation(locStr)
if err != nil {
return nil, err
}
t = t.In(loc)
}
res, err := datetime.MakeDatetime(t)
if err != nil {
return nil, err
}
return res, nil
case *pb.Value_IntervalValue:
res := datetime.Interval{
Year: casted.IntervalValue.Year,
Month: casted.IntervalValue.Month,
Week: casted.IntervalValue.Week,
Day: casted.IntervalValue.Day,
Hour: casted.IntervalValue.Hour,
Min: casted.IntervalValue.Min,
Sec: casted.IntervalValue.Sec,
Nsec: casted.IntervalValue.Nsec,
Adjust: datetime.Adjust(casted.IntervalValue.Adjust)}
return res, nil
case *pb.Value_ArrayValue:
array := val.GetArrayValue()
res := make([]any, len(array.Fields))
for k, v := range array.Fields {
field, err := decodeValue(v)
if err != nil {
return nil, err
}
res[k] = field
}
return res, nil
case *pb.Value_MapValue:
res := make(map[any]any, len(casted.MapValue.Fields))
for k, v := range casted.MapValue.Fields {
item, err := decodeValue(v)
if err != nil {
return nil, err
}
res[k] = item
}
return res, nil
case *pb.Value_NullValue:
return nil, nil
default:
return nil, fmt.Errorf("unsupported type for value")
}
}
Loading

0 comments on commit 4151819

Please sign in to comment.