diff --git a/go.mod b/go.mod
index 739f84a9..36c7be58 100644
--- a/go.mod
+++ b/go.mod
@@ -44,6 +44,8 @@ require (
 	k8s.io/kube-aggregator v0.31.1
 	k8s.io/kube-openapi v0.0.0-20240411171206-dc4e619f62f3
 	k8s.io/kubernetes v1.31.1
+	k8s.io/utils v0.0.0-20240711033017-18e509b52bc8
+	modernc.org/sqlite v1.29.10
 	sigs.k8s.io/controller-runtime v0.19.0
 )
 
@@ -137,12 +139,10 @@ require (
 	k8s.io/component-base v0.31.1 // indirect
 	k8s.io/klog/v2 v2.130.1 // indirect
 	k8s.io/kms v0.31.1 // indirect
-	k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 // indirect
 	modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect
 	modernc.org/libc v1.49.3 // indirect
 	modernc.org/mathutil v1.6.0 // indirect
 	modernc.org/memory v1.8.0 // indirect
-	modernc.org/sqlite v1.29.10 // indirect
 	modernc.org/strutil v1.2.0 // indirect
 	modernc.org/token v1.1.0 // indirect
 	sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.30.3 // indirect
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 4a7829f0..a9711e77 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -82,7 +82,7 @@ type Options struct {
 	AggregationSecretName string
 	ClusterRegistry       string
 	ServerVersion         string
-	// SQLCache enables the SQLite-based lasso caching mechanism
+	// SQLCache enables the SQLite-based caching mechanism
 	SQLCache bool
 
 	// ExtensionAPIServer enables an extension API server that will be served
diff --git a/pkg/sqlcache/Readme.md b/pkg/sqlcache/Readme.md
new file mode 100644
index 00000000..99da82a2
--- /dev/null
+++ b/pkg/sqlcache/Readme.md
@@ -0,0 +1,157 @@
+# SQL Cache
+
+## Sections
+- [ListOptions Informer](#listoptions-informer)
+  - [List Options](#list-options)
+  - [ListOptions Factory](#listoptions-factory)
+  - [ListOptions Indexer](#listoptions-indexer)
+  - [SQL Store](#sql-store)
+  - [Partitions](#partitions)
+- [How to Use](#how-to-use)
+- [Technical Information](#technical-information)
+  - [SQL Tables](#sql-tables)
+  - [SQLite Driver](#sqlite-driver)
+  - [Connection Pooling](#connection-pooling)
+  - [Encryption Defaults](#encryption-defaults)
+  - [Indexed Fields](#indexed-fields)
+  - [ListOptions Behavior](#listoptions-behavior)
+  - [Writing Secure Queries](#writing-secure-queries)
+  - [Troubleshooting SQLite](#troubleshooting-sqlite)
+
+
+
+## ListOptions Informer
+The main user-facing feature of the SQL cache is the ListOptions Informer. The ListOptionsInformer provides listing
+functionality, like any other informer, but with a wider array of options, configured through informer.ListOptions.
+
+### List Options
+ListOptions includes the following (a short configuration sketch follows the list):
+* Match filters for indexed fields. A filter specifies the value a given field must have for an object to be included
+in the list. Filters can match on equality or inequality, and can require either partial or exact (strict) matches.
+Filters can be OR'd and AND'd with one another. Filters only work on fields that have been indexed.
+* Primary and secondary sort fields. Up to two fields can be chosen to sort on, each in ascending or descending order.
+The default is to sort on metadata.namespace in ascending order first, and then on metadata.name.
+* Page size, to specify how many items to include in a response.
+* Page number, to specify an offset. For example, a page size of 50 and a page number of 2 will return items starting
+at index 50. The index depends on the sort order. Page numbers start at 1.
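+
+As a rough sketch of how these options fit together: the `Sort` and `Pagination` field names below follow the defaults
+documented under [ListOptions Behavior](#listoptions-behavior), while the surrounding type and constant names
+(`informer.Sort`, `informer.Pagination`, `informer.ASC`) are assumptions for illustration.
+
+```go
+// Request the second page of 50 items, sorted by namespace and then by name.
+opts := &informer.ListOptions{
+	Sort: informer.Sort{
+		PrimaryField:   []string{"metadata", "namespace"},
+		PrimaryOrder:   informer.ASC,
+		SecondaryField: []string{"metadata", "name"},
+		SecondaryOrder: informer.ASC,
+	},
+	Pagination: informer.Pagination{
+		PageSize: 50, // items per page
+		Page:     2,  // start at item index 50
+	},
+}
+```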
+
+### ListOptions Factory
+The ListOptions Factory helps manage multiple ListOptions informers. A user can call Factory.InformerFor() to create a
+new ListOptions informer if one does not exist, or to retrieve an existing one.
+
+### ListOptions Indexer
+Like all other informers, the ListOptions informer uses an indexer to cache objects of the informer's type. A few
+features set the ListOptions Indexer apart from other indexers:
+* it uses an on-disk store instead of an in-memory store.
+* it accepts list options backed by SQL queries, for extended search/filter/sort capability.
+* it encrypts stored data with AES-GCM using a key hierarchy.
+
+### SQL Store
+The SQL store is the main interface for interacting with the database. It backs the indexer and provides all the
+functionality required by the cache.Store interface.
+
+### Partitions
+Partitions are constraints for the ListOptionsInformer's ListByOptions() method that are separate from ListOptions.
+Partitions are strict conditions that dictate which namespaces or names can be searched. They overrule ListOptions and
+are intended to be used as a way of enforcing RBAC.
+
+## How to Use
+```go
+package main
+
+import (
+	"k8s.io/client-go/dynamic"
+
+	"github.com/rancher/steve/pkg/sqlcache/informer"
+	"github.com/rancher/steve/pkg/sqlcache/informer/factory"
+)
+
+func main() {
+	// Note: config, gvk, apiOp, partitions, and namespace are placeholders that a real caller would construct.
+	cacheFactory, err := factory.NewCacheFactory()
+	if err != nil {
+		panic(err)
+	}
+	// config should be some rest config created from a kubeconfig.
+	// There are other ways to create a config, and any client that conforms to
+	// k8s.io/client-go/dynamic.ResourceInterface will work.
+	client, err := dynamic.NewForConfig(config)
+	if err != nil {
+		panic(err)
+	}
+
+	fields := [][]string{{"metadata", "name"}, {"metadata", "namespace"}}
+	opts := &informer.ListOptions{}
+	// gvk should be of type k8s.io/apimachinery/pkg/runtime/schema.GroupVersionKind
+	c, err := cacheFactory.CacheFor(fields, client, gvk)
+	if err != nil {
+		panic(err)
+	}
+
+	// continueToken is an offset that can be used in Resume on a subsequent request to continue
+	// to the next page.
+	list, continueToken, err := c.ListByOptions(apiOp.Context(), opts, partitions, namespace)
+	if err != nil {
+		panic(err)
+	}
+}
+```
+
+## Technical Information
+
+### SQL Tables
+Three tables are created for the ListOptions informer:
+* object table - contains the objects themselves, including all of their fields, as blobs. These blobs may be encrypted.
+* fields table - contains specific fields of interest for objects. These are specified when the informer is created and
+are the fields that can be filtered or ordered on.
+* indices table - stores the indexes that have been created and each object's values for each index. This backs the
+generic indexer that provides the functionality needed to conform to cache.Indexer.
+
+### SQLite Driver
+There are multiple SQLite drivers this package could have used. One of the most popular SQLite drivers for Go, if not
+the most popular, is [mattn/go-sqlite3](https://github.com/mattn/go-sqlite3). That driver is not used here because it
+requires enabling cgo when compiling, and at the moment steve's main consumer, rancher, does not compile with cgo. We
+did not want the SQL informer to be the sole reason for switching to cgo. Instead, modernc.org's driver, which is
+written in pure Go, is used. Side-by-side comparisons can be found indicating that the cgo version is, as expected,
+more performant. If a switch is ever deemed worthwhile, the driver can easily be swapped by replacing the blank import
+in `pkg/sqlcache/db` from `_ "modernc.org/sqlite"` to `_ "github.com/mattn/go-sqlite3"`.
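+
+As a concrete illustration, the swap would amount to changing the blank driver import (and, because the two drivers
+register under different names, the driver name passed to `sql.Open` would also need to change from `"sqlite"` to
+`"sqlite3"`):
+
+```go
+package db
+
+import (
+	// pure-Go driver currently in use
+	_ "modernc.org/sqlite"
+	// cgo-based alternative, should a switch ever be warranted:
+	// _ "github.com/mattn/go-sqlite3"
+)
+```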
+
+### Connection Pooling
+While working with Go's `database/sql` package, it is important to understand how sql.Open() and related methods manage
+connections. Open starts a connection pool; that is to say, after calling Open once there may be anywhere from zero to
+many connections attached to a single sql.DB. `database/sql` manages this connection pool under the hood. In most cases
+an application only needs one sql.DB, although applications sometimes use two: one for writes and one for reads. To
+read more about the `sql` package's connection pooling, see [Managing connections](https://go.dev/doc/database/manage-connections).
+
+Because of connection pooling, and because steve potentially has many goroutines accessing the same connection pool, we
+have to be careful with writes. Exclusively using SQL transactions for writes helps ensure safety. To read more about
+SQL transactions, see SQLite's [Transaction docs](https://www.sqlite.org/lang_transaction.html).
+
+### Encryption Defaults
+By default, only specified types are encrypted. These types are hard-coded and defined by defaultEncryptedResourceTypes
+in `pkg/sqlcache/informer/factory/informer_factory.go`. To enable encryption for all types, set the environment
+variable `CATTLE_ENCRYPT_CACHE_ALL` to "true".
+
+The key size used is 256 bits. Data encryption keys are stored in the object table and are rotated every 150,000 writes.
+
+### Indexed Fields
+Filtering and sorting only work on indexed fields. These fields are defined when using `CacheFor`. Objects will have
+the following indexes by default:
+* fields in informer.defaultIndexedFields
+* fields passed to InformerFor()
+
+### ListOptions Behavior
+Defaults:
+* Sort.PrimaryField: `metadata.namespace`
+* Sort.SecondaryField: `metadata.name`
+* Sort.PrimaryOrder: `ASC` (ascending)
+* Sort.SecondaryOrder: `ASC` (ascending)
+* All filters have partial matching set to false by default
+
+There are some uncommon ways of using ListOptions where the result is difficult to predict. Below is a non-exhaustive
+list of such cases and the resulting behavior:
+* Setting Pagination.Page but not Pagination.PageSize causes Page to be ignored.
+* Setting only Sort.SecondaryField sorts as though it were Sort.PrimaryField. Sort.SecondaryOrder is still applied and
+Sort.PrimaryOrder is ignored.
+
+### Writing Secure Queries
+Values should be supplied to SQL queries using placeholders; see [Avoiding SQL Injection Risk](https://go.dev/doc/database/sql-injection).
+Any other portion of a query that may be user supplied, such as column names, should be carefully validated against a
+fixed set of acceptable values.
+
+### Troubleshooting SQLite
+A useful tool for troubleshooting the database files is the `sqlite3` command line shell. Another useful tool is the
+GoLand SQLite plugin. Both of these tools can be used with the database files.
diff --git a/pkg/sqlcache/db/client.go b/pkg/sqlcache/db/client.go
new file mode 100644
index 00000000..dffbb6dd
--- /dev/null
+++ b/pkg/sqlcache/db/client.go
@@ -0,0 +1,374 @@
+/*
+Package db offers a client struct and functions to interact with a database connection. It provides encryption,
+decryption, and a way to reset the database.
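+
+A typical read path looks roughly like the following sketch (error handling omitted; the table and column names here
+are placeholders for illustration, not the real schema):
+
+	stmt := client.Prepare(`SELECT object, nonce, keyid FROM "example" WHERE key = ?`) // values go through placeholders
+	rows, _ := client.QueryForRows(ctx, stmt, key)
+	objs, _ := client.ReadObjects(rows, reflect.TypeOf(example{}), true)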
+*/ +package db + +import ( + "bytes" + "context" + "database/sql" + "encoding/gob" + "fmt" + "io/fs" + "os" + "reflect" + "sync" + + "github.com/pkg/errors" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" + + // needed for drivers + _ "modernc.org/sqlite" +) + +const ( + // InformerObjectCacheDBPath is where SQLite's object database file will be stored relative to process running steve + InformerObjectCacheDBPath = "informer_object_cache.db" + + informerObjectCachePerms fs.FileMode = 0o600 +) + +// Client is a database client that provides encrypting, decrypting, and database resetting. +type Client struct { + conn Connection + connLock sync.RWMutex + encryptor Encryptor + decryptor Decryptor +} + +// Connection represents a connection pool. +type Connection interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) + Exec(query string, args ...any) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + Close() error +} + +// Closable Closes an underlying connection and returns an error on failure. +type Closable interface { + Close() error +} + +// Rows represents sql rows. It exposes method to navigate the rows, read their outputs, and close them. +type Rows interface { + Next() bool + Err() error + Close() error + Scan(dest ...any) error +} + +// QueryError encapsulates an error while executing a query +type QueryError struct { + QueryString string + Err error +} + +// Error returns a string representation of this QueryError +func (e *QueryError) Error() string { + return "while executing query: " + e.QueryString + " got error: " + e.Err.Error() +} + +// Unwrap returns the underlying error +func (e *QueryError) Unwrap() error { + return e.Err +} + +// TXClient represents a sql transaction. The TXClient must manage rollbacks as rollback functionality is not exposed. +type TXClient interface { + StmtExec(stmt transaction.Stmt, args ...any) error + Exec(stmt string, args ...any) error + Commit() error + Stmt(stmt *sql.Stmt) transaction.Stmt + Cancel() error +} + +// Encryptor encrypts data with a key which is rotated to avoid wear-out. +type Encryptor interface { + // Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure error is returned instead. + Encrypt([]byte) ([]byte, []byte, uint32, error) +} + +// Decryptor decrypts data previously encrypted by Encryptor. +type Decryptor interface { + // Decrypt accepts a chunk of encrypted data, the nonce used to encrypt it and the ID of the used key (as it rotates). It returns the decrypted data or an error. + Decrypt([]byte, []byte, uint32) ([]byte, error) +} + +// NewClient returns a Client. If the given connection is nil then a default one will be created. +func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (*Client, error) { + client := &Client{ + encryptor: encryptor, + decryptor: decryptor, + } + if c != nil { + client.conn = c + return client, nil + } + err := client.NewConnection() + if err != nil { + return nil, err + } + + return client, nil +} + +// Prepare prepares the given string into a sql statement on the client's connection. 
+func (c *Client) Prepare(stmt string) *sql.Stmt { + c.connLock.RLock() + defer c.connLock.RUnlock() + prepared, err := c.conn.Prepare(stmt) + if err != nil { + panic(errors.Errorf("Error preparing statement: %s\n%v", stmt, err)) + } + return prepared +} + +// QueryForRows queries the given stmt with the given params and returns the resulting rows. The query wil be retried +// given a sqlite busy error. +func (c *Client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + + return stmt.QueryContext(ctx, params...) +} + +// CloseStmt will call close on the given Closable. It is intended to be used with a sql statement. This function is meant +// to replace stmt.Close which can cause panics when callers unit-test since there usually is no real underlying connection. +func (c *Client) CloseStmt(closable Closable) error { + return closable.Close() +} + +// ReadObjects Scans the given rows, performs any necessary decryption, converts the data to objects of the given type, +// and returns a slice of those objects. +func (c *Client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + + var result []any + for rows.Next() { + data, err := c.decryptScan(rows, shouldDecrypt) + if err != nil { + return nil, closeRowsOnError(rows, err) + } + singleResult, err := fromBytes(data, typ) + if err != nil { + return nil, closeRowsOnError(rows, err) + } + result = append(result, singleResult.Elem().Interface()) + } + err := rows.Err() + if err != nil { + return nil, closeRowsOnError(rows, err) + } + + err = rows.Close() + if err != nil { + return nil, err + } + + return result, nil +} + +// ReadStrings scans the given rows into strings, and then returns the strings as a slice. +func (c *Client) ReadStrings(rows Rows) ([]string, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + + var result []string + for rows.Next() { + var key string + err := rows.Scan(&key) + if err != nil { + return nil, closeRowsOnError(rows, err) + } + + result = append(result, key) + } + err := rows.Err() + if err != nil { + return nil, closeRowsOnError(rows, err) + } + + err = rows.Close() + if err != nil { + return nil, err + } + + return result, nil +} + +// ReadInt scans the first of the given rows into a single int (eg. for COUNT() queries) +func (c *Client) ReadInt(rows Rows) (int, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + + if !rows.Next() { + return 0, closeRowsOnError(rows, sql.ErrNoRows) + } + + var result int + err := rows.Scan(&result) + if err != nil { + return 0, closeRowsOnError(rows, err) + } + + err = rows.Err() + if err != nil { + return 0, closeRowsOnError(rows, err) + } + + err = rows.Close() + if err != nil { + return 0, err + } + + return result, nil +} + +// BeginTx attempts to begin a transaction. +// If forWriting is true, this method blocks until all other concurrent forWriting +// transactions have either committed or rolled back. +// If forWriting is false, it is assumed the returned transaction will exclusively +// be used for DQL (e.g. SELECT) queries. +// Not respecting the above rule might result in transactions failing with unexpected +// SQLITE_BUSY (5) errors (aka "Runtime error: database is locked"). 
+// See discussion in https://github.com/rancher/lasso/pull/98 for details
+func (c *Client) BeginTx(ctx context.Context, forWriting bool) (TXClient, error) {
+	c.connLock.RLock()
+	defer c.connLock.RUnlock()
+	// note: this assumes _txlock=immediate in the connection string, see NewConnection
+	sqlTx, err := c.conn.BeginTx(ctx, &sql.TxOptions{
+		ReadOnly: !forWriting,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return transaction.NewClient(sqlTx), nil
+}
+
+func (c *Client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) {
+	var data, dataNonce sql.RawBytes
+	var kid uint32
+	err := rows.Scan(&data, &dataNonce, &kid)
+	if err != nil {
+		return nil, err
+	}
+	if c.decryptor != nil && shouldDecrypt {
+		decryptedData, err := c.decryptor.Decrypt(data, dataNonce, kid)
+		if err != nil {
+			return nil, err
+		}
+		return decryptedData, nil
+	}
+	return data, nil
+}
+
+// Upsert used to be called upsertEncrypted in the store package before the move
+func (c *Client) Upsert(tx TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error {
+	objBytes := toBytes(obj)
+	var dataNonce []byte
+	var err error
+	var kid uint32
+	if c.encryptor != nil && shouldEncrypt {
+		objBytes, dataNonce, kid, err = c.encryptor.Encrypt(objBytes)
+		if err != nil {
+			return err
+		}
+	}
+
+	return tx.StmtExec(tx.Stmt(stmt), key, objBytes, dataNonce, kid)
+}
+
+// toBytes encodes an object to a byte slice
+func toBytes(obj any) []byte {
+	var buf bytes.Buffer
+	enc := gob.NewEncoder(&buf)
+	err := enc.Encode(obj)
+	if err != nil {
+		panic(fmt.Errorf("error while gobbing object: %w", err))
+	}
+	bb := buf.Bytes()
+	return bb
+}
+
+// fromBytes decodes an object from a byte slice
+func fromBytes(buf sql.RawBytes, typ reflect.Type) (reflect.Value, error) {
+	dec := gob.NewDecoder(bytes.NewReader(buf))
+	singleResult := reflect.New(typ)
+	err := dec.DecodeValue(singleResult)
+	return singleResult, err
+}
+
+// closeRowsOnError closes the sql.Rows object and wraps errors if needed
+func closeRowsOnError(rows Rows, err error) error {
+	ce := rows.Close()
+	if ce != nil {
+		return fmt.Errorf("error in closing rows while handling %s: %w", err.Error(), ce)
+	}
+
+	return err
+}
+
+// NewConnection checks for an existing connection, closes it if one exists, removes any relevant db files, and opens a
+// new connection, which subsequently creates new files.
+func (c *Client) NewConnection() error {
+	c.connLock.Lock()
+	defer c.connLock.Unlock()
+	if c.conn != nil {
+		err := c.conn.Close()
+		if err != nil {
+			return err
+		}
+	}
+	err := os.RemoveAll(InformerObjectCacheDBPath)
+	if err != nil {
+		return err
+	}
+
+	// Set the permissions in advance, because we can't control them if
+	// the file is created by a sql.Open call instead.
+	if err := touchFile(InformerObjectCacheDBPath, informerObjectCachePerms); err != nil {
+		return err
+	}
+
+	sqlDB, err := sql.Open("sqlite", "file:"+InformerObjectCacheDBPath+"?"+
+		// open SQLite file in read-write mode, creating it if it does not exist
+		"mode=rwc&"+
+		// use the WAL journal mode for consistency and efficiency
+		"_pragma=journal_mode=wal&"+
+		// do not even attempt to attain durability. Database is thrown away at pod restart
+		"_pragma=synchronous=off&"+
+		// do check foreign keys and honor ON DELETE CASCADE
+		"_pragma=foreign_keys=on&"+
+		// if two transactions want to write at the same time, allow 2 minutes for the first to complete
+		// before bailing out
+		"_pragma=busy_timeout=120000&"+
+		// default to IMMEDIATE mode for transactions.
Setting this parameter is the only current way + // to be able to switch between DEFERRED and IMMEDIATE modes in modernc.org/sqlite's implementation + // of BeginTx + "_txlock=immediate") + if err != nil { + return err + } + + c.conn = sqlDB + return nil +} + +// This acts like "touch" for both existing files and non-existing files. +// permissions. +// +// It's created with the correct perms, and if the file already exists, it will +// be chmodded to the correct perms. +func touchFile(filename string, perms fs.FileMode) error { + f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, perms) + if err != nil { + return err + } + if err := f.Close(); err != nil { + return err + } + + return os.Chmod(filename, perms) +} diff --git a/pkg/sqlcache/db/client_test.go b/pkg/sqlcache/db/client_test.go new file mode 100644 index 00000000..8b7951f1 --- /dev/null +++ b/pkg/sqlcache/db/client_test.go @@ -0,0 +1,667 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "io/fs" + "math" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +// Mocks for this test are generated with the following command. +//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient +//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx + +type testStoreObject struct { + Id string + Val string +} + +func TestNewClient(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Query rows with no params, no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + expectedClient := &Client{ + conn: c, + encryptor: e, + decryptor: d, + } + client, err := NewClient(c, e, d) + assert.Nil(t, err) + assert.Equal(t, expectedClient, client) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestQueryForRows(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Query rows with no params, no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + client := SetupClient(t, c, nil, nil) + s := NewMockStmt(gomock.NewController(t)) + ctx := context.TODO() + r := &sql.Rows{} + s.EXPECT().QueryContext(ctx).Return(r, nil) + rows, err := client.QueryForRows(ctx, s) + assert.Nil(t, err) + assert.Equal(t, r, rows) + }, + }) + tests = append(tests, testCase{description: "Query rows with params, QueryContext() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + client := SetupClient(t, c, nil, nil) + s := NewMockStmt(gomock.NewController(t)) + ctx := context.TODO() + s.EXPECT().QueryContext(ctx).Return(nil, fmt.Errorf("error")) + _, err := client.QueryForRows(ctx, s) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestQueryObjects(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + testObject := 
testStoreObject{Id: "something", Val: "a"} + var keyId uint32 = math.MaxUint32 + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Query objects, with one row, and no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + *a[0].(*sql.RawBytes) = toBytes(testObject) + *a[1].(*sql.RawBytes) = toBytes(testObject) + *a[2].(*uint32) = keyId + }) + d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(toBytes(testObject), nil) + r.EXPECT().Err().Return(nil) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + items, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.Nil(t, err) + assert.Equal(t, 1, len(items)) + }, + }) + tests = append(tests, testCase{description: "Query objects, with one row, and a decrypt error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + *a[0].(*sql.RawBytes) = toBytes(testObject) + *a[1].(*sql.RawBytes) = toBytes( + testObject) + *a[2].(*uint32) = keyId + }) + d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(nil, fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Query objects, with one row, and a Scan() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Query objects, with one row, and a Close() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + *a[0].(*sql.RawBytes) = toBytes(testObject) + *a[1].(*sql.RawBytes) = toBytes(testObject) + *a[2].(*uint32) = keyId + }) + d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(toBytes(testObject), nil) + r.EXPECT().Err().Return(nil) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(fmt.Errorf("error")) + client := SetupClient(t, c, e, d) + _, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Query objects, with no rows, and no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(false) + r.EXPECT().Err().Return(nil) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + items, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.Nil(t, err) + assert.Equal(t, 0, len(items)) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { 
test.test(t) }) + } +} + +func TestQueryStrings(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + testObject := testStoreObject{Id: "something", Val: "a"} + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "ReadStrings(), with one row, and no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + for _, v := range a { + vk := v.(*string) + *vk = string(toBytes(testObject.Id)) + } + }) + r.EXPECT().Err().Return(nil) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + items, err := client.ReadStrings(r) + assert.Nil(t, err) + assert.Equal(t, 1, len(items)) + }, + }) + tests = append(tests, testCase{description: "Query objects, with one row, and Scan error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadStrings(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "ReadStrings(), with one row, and Err() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + for _, v := range a { + vk := v.(*string) + *vk = string(toBytes(testObject.Id)) + } + }) + r.EXPECT().Next().Return(false) + r.EXPECT().Err().Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadStrings(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "ReadStrings(), with one row, and Close() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + for _, v := range a { + vk := v.(*string) + *vk = string(toBytes(testObject.Id)) + } + }) + r.EXPECT().Err().Return(nil) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(fmt.Errorf("error")) + client := SetupClient(t, c, e, d) + _, err := client.ReadStrings(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "ReadStrings(), with no rows, and no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(false) + r.EXPECT().Err().Return(nil) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + items, err := client.ReadStrings(r) + assert.Nil(t, err) + assert.Equal(t, 0, len(items)) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestReadInt(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + testResult := 42 + tests = append(tests, testCase{description: "One row, no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := 
SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + p := a[0].(*int) + *p = testResult + }) + r.EXPECT().Err().Return(nil) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + result, err := client.ReadInt(r) + assert.Nil(t, err) + assert.Equal(t, 42, result) + }, + }) + tests = append(tests, testCase{description: "One row, Scan error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadInt(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "One row, Err() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + a[0] = testResult + }) + r.EXPECT().Err().Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadInt(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "One row, Close() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + a[0] = testResult + }) + r.EXPECT().Err().Return(nil) + r.EXPECT().Close().Return(fmt.Errorf("error")) + client := SetupClient(t, c, e, d) + _, err := client.ReadInt(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "No rows error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadInt(r) + assert.ErrorIs(t, err, sql.ErrNoRows) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestBegin(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "BeginTx(), with no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + sqlTx := &sql.Tx{} + c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(sqlTx, nil) + client := SetupClient(t, c, e, d) + txC, err := client.BeginTx(context.Background(), false) + assert.Nil(t, err) + assert.NotNil(t, txC) + }, + }) + tests = append(tests, testCase{description: "BeginTx(), with forWriting option set", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + sqlTx := &sql.Tx{} + c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: false}).Return(sqlTx, nil) + client := SetupClient(t, c, e, d) + txC, err := client.BeginTx(context.Background(), true) + assert.Nil(t, err) + assert.NotNil(t, txC) + }, + }) + tests = append(tests, testCase{description: "BeginTx(), with connection Begin() error", test: func(t *testing.T) { + c := SetupMockConnection(t) 
+ e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(nil, fmt.Errorf("error")) + client := SetupClient(t, c, e, d) + _, err := client.BeginTx(context.Background(), false) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestUpsert(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + testObject := testStoreObject{Id: "something", Val: "a"} + var keyID uint32 = 5 + + // Tests with shouldEncryptSet to true + tests = append(tests, testCase{description: "Upsert() with no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + txC := NewMockTXClient(gomock.NewController(t)) + sqlStmt := &sql.Stmt{} + stmt := NewMockStmt(gomock.NewController(t)) + testObjBytes := toBytes(testObject) + testByteValue := []byte("something") + e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil) + txC.EXPECT().Stmt(sqlStmt).Return(stmt) + txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(nil) + err := client.Upsert(txC, sqlStmt, "somekey", testObject, true) + assert.Nil(t, err) + }, + }) + tests = append(tests, testCase{description: "Upsert() with Encrypt() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + txC := NewMockTXClient(gomock.NewController(t)) + sqlStmt := &sql.Stmt{} + testObjBytes := toBytes(testObject) + e.EXPECT().Encrypt(testObjBytes).Return(nil, nil, uint32(0), fmt.Errorf("error")) + err := client.Upsert(txC, sqlStmt, "somekey", testObject, true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Upsert() with StmtExec() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + txC := NewMockTXClient(gomock.NewController(t)) + sqlStmt := &sql.Stmt{} + stmt := NewMockStmt(gomock.NewController(t)) + testObjBytes := toBytes(testObject) + testByteValue := []byte("something") + e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil) + txC.EXPECT().Stmt(sqlStmt).Return(stmt) + txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(fmt.Errorf("error")) + err := client.Upsert(txC, sqlStmt, "somekey", testObject, true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Upsert() with no errors and shouldEncrypt false", test: func(t *testing.T) { + c := SetupMockConnection(t) + d := SetupMockDecryptor(t) + e := SetupMockEncryptor(t) + + client := SetupClient(t, c, e, d) + txC := NewMockTXClient(gomock.NewController(t)) + sqlStmt := &sql.Stmt{} + stmt := NewMockStmt(gomock.NewController(t)) + var testByteValue []byte + testObjBytes := toBytes(testObject) + txC.EXPECT().Stmt(sqlStmt).Return(stmt) + txC.EXPECT().StmtExec(stmt, "somekey", testObjBytes, testByteValue, uint32(0)).Return(nil) + err := client.Upsert(txC, sqlStmt, "somekey", testObject, false) + assert.Nil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestPrepare(t *testing.T) { + type testCase struct { + 
description string + test func(t *testing.T) + } + + var tests []testCase + tests = append(tests, testCase{description: "Prepare() with no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + sqlStmt := &sql.Stmt{} + c.EXPECT().Prepare("something").Return(sqlStmt, nil) + + stmt := client.Prepare("something") + assert.Equal(t, sqlStmt, stmt) + }, + }) + tests = append(tests, testCase{description: "Prepare() with Connection Prepare() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + c.EXPECT().Prepare("something").Return(nil, fmt.Errorf("error")) + + assert.Panics(t, func() { client.Prepare("something") }) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestNewConnection(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + tests = append(tests, testCase{description: "NewConnection replaces file", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + c.EXPECT().Close().Return(nil) + + err := client.NewConnection() + assert.Nil(t, err) + + // Create a transaction to ensure that the file is written to disk. + txC, err := client.BeginTx(context.Background(), false) + assert.NoError(t, err) + assert.NoError(t, txC.Commit()) + + assert.FileExists(t, InformerObjectCacheDBPath) + assertFileHasPermissions(t, InformerObjectCacheDBPath, 0600) + + err = os.Remove(InformerObjectCacheDBPath) + if err != nil { + assert.Fail(t, "could not remove object cache path after test") + } + }, + }) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestCommit(t *testing.T) { + +} + +func TestRollback(t *testing.T) { + +} + +func SetupMockConnection(t *testing.T) *MockConnection { + mockC := NewMockConnection(gomock.NewController(t)) + return mockC +} + +func SetupMockEncryptor(t *testing.T) *MockEncryptor { + mockE := NewMockEncryptor(gomock.NewController(t)) + return mockE +} + +func SetupMockDecryptor(t *testing.T) *MockDecryptor { + MockD := NewMockDecryptor(gomock.NewController(t)) + return MockD +} + +func SetupMockRows(t *testing.T) *MockRows { + MockR := NewMockRows(gomock.NewController(t)) + return MockR +} + +func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) *Client { + c, _ := NewClient(connection, encryptor, decryptor) + return c +} + +func TestTouchFile(t *testing.T) { + t.Run("File doesn't exist before", func(t *testing.T) { + filename := filepath.Join(t.TempDir(), "test1.txt") + assert.NoError(t, touchFile(filename, 0600)) + assertFileHasPermissions(t, filename, 0600) + }) + + t.Run("File exists with different permissions", func(t *testing.T) { + filename := filepath.Join(t.TempDir(), "test2.txt") + assert.NoError(t, os.WriteFile(filename, []byte("test"), 0644)) + assert.NoError(t, touchFile(filename, 0600)) + assertFileHasPermissions(t, filename, 0600) + }) +} + +func assertFileHasPermissions(t *testing.T, fname string, wantPerms fs.FileMode) bool { + t.Helper() + info, err := os.Lstat(fname) + if err != nil { + if os.IsNotExist(err) { + return assert.Fail(t, fmt.Sprintf("unable to find file %q", fname)) + } + 
return assert.Fail(t, fmt.Sprintf("error when running os.Lstat(%q): %s", fname, err)) + } + + // Stringifying the perms makes it easier to read than a Hex comparison. + assert.Equal(t, wantPerms.String(), info.Mode().Perm().String()) + + return true +} diff --git a/pkg/sqlcache/db/db_mocks_test.go b/pkg/sqlcache/db/db_mocks_test.go new file mode 100644 index 00000000..54199ba4 --- /dev/null +++ b/pkg/sqlcache/db/db_mocks_test.go @@ -0,0 +1,370 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Rows,Connection,Encryptor,Decryptor,TXClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient +// + +// Package db is a generated GoMock package. +package db + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockRows is a mock of Rows interface. +type MockRows struct { + ctrl *gomock.Controller + recorder *MockRowsMockRecorder +} + +// MockRowsMockRecorder is the mock recorder for MockRows. +type MockRowsMockRecorder struct { + mock *MockRows +} + +// NewMockRows creates a new mock instance. +func NewMockRows(ctrl *gomock.Controller) *MockRows { + mock := &MockRows{ctrl: ctrl} + mock.recorder = &MockRowsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRows) EXPECT() *MockRowsMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRows) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRowsMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close)) +} + +// Err mocks base method. +func (m *MockRows) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockRowsMockRecorder) Err() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err)) +} + +// Next mocks base method. +func (m *MockRows) Next() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockRowsMockRecorder) Next() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next)) +} + +// Scan mocks base method. +func (m *MockRows) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) +} + +// MockConnection is a mock of Connection interface. 
+type MockConnection struct { + ctrl *gomock.Controller + recorder *MockConnectionMockRecorder +} + +// MockConnectionMockRecorder is the mock recorder for MockConnection. +type MockConnectionMockRecorder struct { + mock *MockConnection +} + +// NewMockConnection creates a new mock instance. +func NewMockConnection(ctrl *gomock.Controller) *MockConnection { + mock := &MockConnection{ctrl: ctrl} + mock.recorder = &MockConnectionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnection) EXPECT() *MockConnectionMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockConnection) BeginTx(arg0 context.Context, arg1 *sql.TxOptions) (*sql.Tx, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(*sql.Tx) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockConnectionMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockConnection)(nil).BeginTx), arg0, arg1) +} + +// Close mocks base method. +func (m *MockConnection) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockConnectionMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnection)(nil).Close)) +} + +// Exec mocks base method. +func (m *MockConnection) Exec(arg0 string, arg1 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockConnectionMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockConnection)(nil).Exec), varargs...) +} + +// Prepare mocks base method. +func (m *MockConnection) Prepare(arg0 string) (*sql.Stmt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockConnectionMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockConnection)(nil).Prepare), arg0) +} + +// MockEncryptor is a mock of Encryptor interface. +type MockEncryptor struct { + ctrl *gomock.Controller + recorder *MockEncryptorMockRecorder +} + +// MockEncryptorMockRecorder is the mock recorder for MockEncryptor. +type MockEncryptorMockRecorder struct { + mock *MockEncryptor +} + +// NewMockEncryptor creates a new mock instance. +func NewMockEncryptor(ctrl *gomock.Controller) *MockEncryptor { + mock := &MockEncryptor{ctrl: ctrl} + mock.recorder = &MockEncryptorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEncryptor) EXPECT() *MockEncryptorMockRecorder { + return m.recorder +} + +// Encrypt mocks base method. 
+func (m *MockEncryptor) Encrypt(arg0 []byte) ([]byte, []byte, uint32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encrypt", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].([]byte) + ret2, _ := ret[2].(uint32) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// Encrypt indicates an expected call of Encrypt. +func (mr *MockEncryptorMockRecorder) Encrypt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockEncryptor)(nil).Encrypt), arg0) +} + +// MockDecryptor is a mock of Decryptor interface. +type MockDecryptor struct { + ctrl *gomock.Controller + recorder *MockDecryptorMockRecorder +} + +// MockDecryptorMockRecorder is the mock recorder for MockDecryptor. +type MockDecryptorMockRecorder struct { + mock *MockDecryptor +} + +// NewMockDecryptor creates a new mock instance. +func NewMockDecryptor(ctrl *gomock.Controller) *MockDecryptor { + mock := &MockDecryptor{ctrl: ctrl} + mock.recorder = &MockDecryptorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDecryptor) EXPECT() *MockDecryptorMockRecorder { + return m.recorder +} + +// Decrypt mocks base method. +func (m *MockDecryptor) Decrypt(arg0, arg1 []byte, arg2 uint32) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decrypt", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Decrypt indicates an expected call of Decrypt. +func (mr *MockDecryptorMockRecorder) Decrypt(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockDecryptor)(nil).Decrypt), arg0, arg1, arg2) +} + +// MockTXClient is a mock of TXClient interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockTXClient) Cancel() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) +} + +// Commit mocks base method. +func (m *MockTXClient) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) +} + +// Exec mocks base method. 
+func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} + +// StmtExec mocks base method. +func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StmtExec indicates an expected call of StmtExec. +func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) +} diff --git a/pkg/sqlcache/db/transaction/transaction.go b/pkg/sqlcache/db/transaction/transaction.go new file mode 100644 index 00000000..68b476ea --- /dev/null +++ b/pkg/sqlcache/db/transaction/transaction.go @@ -0,0 +1,92 @@ +/* +Package transaction provides a client for a live transaction, and interfaces for some relevant sql types. The transaction client automatically performs rollbacks on failures. +The use of this package simplifies testing for callers by making the underlying transaction mock-able. +*/ +package transaction + +import ( + "context" + "database/sql" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +// Client provides a way to interact with the underlying sql transaction. +type Client struct { + sqlTx SQLTx +} + +// SQLTx represents a sql transaction +type SQLTx interface { + Exec(query string, args ...any) (sql.Result, error) + Stmt(stmt *sql.Stmt) *sql.Stmt + Commit() error + Rollback() error +} + +// Stmt represents a sql stmt. It is used as a return type to offer some testability over returning sql's Stmt type +// because we are able to mock its outputs and do not need an actual connection. +type Stmt interface { + Exec(args ...any) (sql.Result, error) + Query(args ...any) (*sql.Rows, error) + QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) +} + +// NewClient returns a Client with the given transaction assigned. +func NewClient(tx SQLTx) *Client { + return &Client{sqlTx: tx} +} + +// Commit commits the transaction and then unlocks the database. +func (c *Client) Commit() error { + return c.sqlTx.Commit() +} + +// Exec uses the sqlTX Exec() with the given stmt and args. The transaction will be automatically rolled back if Exec() +// returns an error. +func (c *Client) Exec(stmt string, args ...any) error { + _, err := c.sqlTx.Exec(stmt, args...) 
+ if err != nil { + return c.rollback(c.sqlTx, err) + } + return nil +} + +// Stmt adds the given sql.Stmt to the client's transaction and then returns a Stmt. An interface is being returned +// here to aid in testing callers by providing a way to configure the statement's behavior. +func (c *Client) Stmt(stmt *sql.Stmt) Stmt { + s := c.sqlTx.Stmt(stmt) + return s +} + +// StmtExec Execs the given statement with the given args. It assumes the stmt has been added to the transaction. The +// transaction is rolled back if Stmt.Exec() returns an error. +func (c *Client) StmtExec(stmt Stmt, args ...any) error { + _, err := stmt.Exec(args...) + if err != nil { + logrus.Debugf("StmtExec failed: query %s, args: %s, err: %s", stmt, args, err) + return c.rollback(c.sqlTx, err) + } + return nil +} + +// rollback handles rollbacks and wraps errors if needed +func (c *Client) rollback(tx SQLTx, err error) error { + rerr := tx.Rollback() + if rerr != nil { + return errors.Wrapf(err, "Encountered error, then encountered another error while rolling back: %v", rerr) + } + return errors.Wrapf(err, "Encountered error, successfully rolled back") +} + +// Cancel rollbacks the transaction without wrapping an error. This only needs to be called if Client has not returned +// an error yet or has not committed. Otherwise, transaction has already rolled back, or in the case of Commit() it is too +// late. +func (c *Client) Cancel() error { + rerr := c.sqlTx.Rollback() + if rerr != sql.ErrTxDone { + return rerr + } + return nil +} diff --git a/pkg/sqlcache/db/transaction/transaction_mocks_test.go b/pkg/sqlcache/db/transaction/transaction_mocks_test.go new file mode 100644 index 00000000..0d7fdaa7 --- /dev/null +++ b/pkg/sqlcache/db/transaction/transaction_mocks_test.go @@ -0,0 +1,184 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx +// + +// Package transaction is a generated GoMock package. +package transaction + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStmt is a mock of Stmt interface. +type MockStmt struct { + ctrl *gomock.Controller + recorder *MockStmtMockRecorder +} + +// MockStmtMockRecorder is the mock recorder for MockStmt. +type MockStmtMockRecorder struct { + mock *MockStmt +} + +// NewMockStmt creates a new mock instance. +func NewMockStmt(ctrl *gomock.Controller) *MockStmt { + mock := &MockStmt{ctrl: ctrl} + mock.recorder = &MockStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStmt) EXPECT() *MockStmtMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) +} + +// Query mocks base method. 
+func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) +} + +// QueryContext mocks base method. +func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) +} + +// MockSQLTx is a mock of SQLTx interface. +type MockSQLTx struct { + ctrl *gomock.Controller + recorder *MockSQLTxMockRecorder +} + +// MockSQLTxMockRecorder is the mock recorder for MockSQLTx. +type MockSQLTxMockRecorder struct { + mock *MockSQLTx +} + +// NewMockSQLTx creates a new mock instance. +func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx { + mock := &MockSQLTx{ctrl: ctrl} + mock.recorder = &MockSQLTxMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder { + return m.recorder +} + +// Commit mocks base method. +func (m *MockSQLTx) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...) +} + +// Rollback mocks base method. +func (m *MockSQLTx) Rollback() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback") + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback)) +} + +// Stmt mocks base method. 
+func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "Stmt", arg0)
+	ret0, _ := ret[0].(*sql.Stmt)
+	return ret0
+}
+
+// Stmt indicates an expected call of Stmt.
+func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0)
+}
diff --git a/pkg/sqlcache/db/transaction/transaction_test.go b/pkg/sqlcache/db/transaction/transaction_test.go
new file mode 100644
index 00000000..0ede5d2e
--- /dev/null
+++ b/pkg/sqlcache/db/transaction/transaction_test.go
@@ -0,0 +1,182 @@
+package transaction
+
+import (
+	"database/sql"
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"go.uber.org/mock/gomock"
+)
+
+//go:generate mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
+
+func TestNewClient(t *testing.T) {
+	tx := NewMockSQLTx(gomock.NewController(t))
+	c := NewClient(tx)
+	assert.Equal(t, tx, c.sqlTx)
+}
+
+func TestCommit(t *testing.T) {
+	type testCase struct {
+		description string
+		test        func(t *testing.T)
+	}
+
+	var tests []testCase
+
+	tests = append(tests, testCase{description: "Commit() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
+		tx := NewMockSQLTx(gomock.NewController(t))
+		tx.EXPECT().Commit().Return(nil)
+		c := &Client{
+			sqlTx: tx,
+		}
+		err := c.Commit()
+		assert.Nil(t, err)
+	}})
+	tests = append(tests, testCase{description: "Commit() with error from sql TX commit() should return error", test: func(t *testing.T) {
+		tx := NewMockSQLTx(gomock.NewController(t))
+		tx.EXPECT().Commit().Return(fmt.Errorf("error"))
+		c := &Client{
+			sqlTx: tx,
+		}
+		err := c.Commit()
+		assert.NotNil(t, err)
+	}})
+	t.Parallel()
+	for _, test := range tests {
+		t.Run(test.description, func(t *testing.T) { test.test(t) })
+	}
+}
+
+func TestExec(t *testing.T) {
+	type testCase struct {
+		description string
+		test        func(t *testing.T)
+	}
+
+	var tests []testCase
+
+	tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
+		tx := NewMockSQLTx(gomock.NewController(t))
+		stmtStr := "some statement %s"
+		arg := 5
+		// should be passed same statement and arg that was passed to parent function
+		tx.EXPECT().Exec(stmtStr, arg).Return(nil, nil)
+		c := &Client{
+			sqlTx: tx,
+		}
+		err := c.Exec(stmtStr, arg)
+		assert.Nil(t, err)
+	}})
+	tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and no Rollback() error should return an error", test: func(t *testing.T) {
+		tx := NewMockSQLTx(gomock.NewController(t))
+		stmtStr := "some statement %s"
+		arg := 5
+		// should be passed same statement and arg that was passed to parent function
+		tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error"))
+		tx.EXPECT().Rollback().Return(nil)
+		c := &Client{
+			sqlTx: tx,
+		}
+		err := c.Exec(stmtStr, arg)
+		assert.NotNil(t, err)
+	}})
+	tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) {
+		tx := NewMockSQLTx(gomock.NewController(t))
+		stmtStr := "some statement %s"
+		arg := 5
+		// should be passed same statement and arg that was passed to parent function
+		tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error"))
+		tx.EXPECT().Rollback().Return(fmt.Errorf("error"))
+		c := &Client{
+			sqlTx: tx,
+		}
+		err := c.Exec(stmtStr, arg)
+		assert.NotNil(t, err)
+	}})
+	t.Parallel()
+	for _, test := range tests {
+		t.Run(test.description, func(t *testing.T) { test.test(t) })
+	}
+}
+
+func TestStmt(t *testing.T) {
+	type testCase struct {
+		description string
+		test        func(t *testing.T)
+	}
+
+	var tests []testCase
+
+	tests = append(tests, testCase{description: "Stmt() should return the Stmt returned from the sql TX", test: func(t *testing.T) {
+		tx := NewMockSQLTx(gomock.NewController(t))
+		stmt := &sql.Stmt{}
+		var returnedTXStmt *sql.Stmt
+		// should be passed the same statement that was passed to the parent function
+		tx.EXPECT().Stmt(stmt).Return(returnedTXStmt)
+		c := &Client{
+			sqlTx: tx,
+		}
+		returnedStmt := c.Stmt(stmt)
+		// whatever tx returned should be returned here. Nil was used because none of sql.Stmt's fields are exported so it's simpler to test nil as it
+		// won't be equal to an empty struct
+		assert.Equal(t, returnedTXStmt, returnedStmt)
+	}})
+	t.Parallel()
+	for _, test := range tests {
+		t.Run(test.description, func(t *testing.T) { test.test(t) })
+	}
+}
+
+func TestStmtExec(t *testing.T) {
+	type testCase struct {
+		description string
+		test        func(t *testing.T)
+	}
+
+	var tests []testCase
+
+	tests = append(tests, testCase{description: "StmtExec with no errors returned from Stmt should return no error", test: func(t *testing.T) {
+		tx := NewMockSQLTx(gomock.NewController(t))
+		stmt := NewMockStmt(gomock.NewController(t))
+		arg := "something"
+		// should be passed same arg that was passed to parent function
+		stmt.EXPECT().Exec(arg).Return(nil, nil)
+		c := &Client{
+			sqlTx: tx,
+		}
+		err := c.StmtExec(stmt, arg)
+		assert.Nil(t, err)
+	}})
+	tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and no Tx Rollback() error should return error", test: func(t *testing.T) {
+		tx := NewMockSQLTx(gomock.NewController(t))
+		stmt := NewMockStmt(gomock.NewController(t))
+		arg := "something"
+		// should be passed same arg that was passed to parent function
+		stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error"))
+		tx.EXPECT().Rollback().Return(nil)
+		c := &Client{
+			sqlTx: tx,
+		}
+		err := c.StmtExec(stmt, arg)
+		assert.NotNil(t, err)
+	}})
+	tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and Tx Rollback() error should return error", test: func(t *testing.T) {
+		tx := NewMockSQLTx(gomock.NewController(t))
+		stmt := NewMockStmt(gomock.NewController(t))
+		arg := "something"
+		// should be passed same arg that was passed to parent function
+		stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error"))
+		tx.EXPECT().Rollback().Return(fmt.Errorf("error2"))
+		c := &Client{
+			sqlTx: tx,
+		}
+		err := c.StmtExec(stmt, arg)
+		assert.NotNil(t, err)
+	}})
+	t.Parallel()
+	for _, test := range tests {
+		t.Run(test.description, func(t *testing.T) { test.test(t) })
+	}
+}
diff --git a/pkg/sqlcache/db/transaction_mocks_test.go b/pkg/sqlcache/db/transaction_mocks_test.go
new file mode 100644
index 00000000..1cac5caf
--- /dev/null
+++ b/pkg/sqlcache/db/transaction_mocks_test.go
@@ -0,0 +1,184 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx +// + +// Package db is a generated GoMock package. +package db + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStmt is a mock of Stmt interface. +type MockStmt struct { + ctrl *gomock.Controller + recorder *MockStmtMockRecorder +} + +// MockStmtMockRecorder is the mock recorder for MockStmt. +type MockStmtMockRecorder struct { + mock *MockStmt +} + +// NewMockStmt creates a new mock instance. +func NewMockStmt(ctrl *gomock.Controller) *MockStmt { + mock := &MockStmt{ctrl: ctrl} + mock.recorder = &MockStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStmt) EXPECT() *MockStmtMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) +} + +// Query mocks base method. +func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) +} + +// QueryContext mocks base method. +func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) +} + +// MockSQLTx is a mock of SQLTx interface. +type MockSQLTx struct { + ctrl *gomock.Controller + recorder *MockSQLTxMockRecorder +} + +// MockSQLTxMockRecorder is the mock recorder for MockSQLTx. +type MockSQLTxMockRecorder struct { + mock *MockSQLTx +} + +// NewMockSQLTx creates a new mock instance. +func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx { + mock := &MockSQLTx{ctrl: ctrl} + mock.recorder = &MockSQLTxMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. 
+func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder { + return m.recorder +} + +// Commit mocks base method. +func (m *MockSQLTx) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...) +} + +// Rollback mocks base method. +func (m *MockSQLTx) Rollback() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback") + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback)) +} + +// Stmt mocks base method. +func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0) +} diff --git a/pkg/sqlcache/db/utility.go b/pkg/sqlcache/db/utility.go new file mode 100644 index 00000000..a8f84d29 --- /dev/null +++ b/pkg/sqlcache/db/utility.go @@ -0,0 +1,8 @@ +package db + +import "strings" + +// Sanitize returns a string that can be used in SQL as a name +func Sanitize(s string) string { + return strings.ReplaceAll(s, "\"", "") +} diff --git a/pkg/sqlcache/encryption/encrypt.go b/pkg/sqlcache/encryption/encrypt.go new file mode 100644 index 00000000..a7783ac9 --- /dev/null +++ b/pkg/sqlcache/encryption/encrypt.go @@ -0,0 +1,168 @@ +/* +Package encryption provides encryption and decryption functions, while +abstracting away key management concerns. +Uses AES-GCM encryption, with key rotation, keeping keys in memory. +*/ +package encryption + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "sync" + + "github.com/pkg/errors" +) + +var ( + ErrKeyNotFound = errors.New("data key not found") + // maxWriteCount holds the maximum amount of times the active key can be + // used, prior to it being rotated. 2^32 is the currently recommended key + // wear-out params by NIST for AES-GCM using random nonces. + maxWriteCount int64 = 1 << 32 +) + +const ( + keySize = 32 // 32 for AES-256 +) + +// Manager uses AES-GCM encryption and keeps in memory the data encryption +// keys. The active encryption key is automatically rotated once it has been +// used over a certain amount of times - defined by maxWriteCount. 
+type Manager struct {
+	dataKeys         [][]byte
+	activeKeyCounter int64
+
+	// lock works as the mutual exclusion lock for dataKeys.
+	lock sync.RWMutex
+	// counterLock works as the mutual exclusion lock for activeKeyCounter.
+	counterLock sync.Mutex
+}
+
+// NewManager returns a Manager, which satisfies db.Encryptor and db.Decryptor.
+func NewManager() (*Manager, error) {
+	m := &Manager{
+		dataKeys: [][]byte{},
+	}
+	if _, _, err := m.newDataEncryptionKey(); err != nil {
+		return nil, err
+	}
+
+	return m, nil
+}
+
+// Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure, an error is returned instead.
+func (m *Manager) Encrypt(data []byte) ([]byte, []byte, uint32, error) {
+	dek, keyID, err := m.fetchActiveDataKey()
+	if err != nil {
+		return nil, nil, 0, err
+	}
+	aead, err := createGCMCypher(dek)
+	if err != nil {
+		return nil, nil, 0, err
+	}
+	edata, nonce, err := encrypt(aead, data)
+	if err != nil {
+		return nil, nil, 0, err
+	}
+	return edata, nonce, keyID, nil
+}
+
+// Decrypt accepts a chunk of encrypted data, the nonce used to encrypt it, and the ID of the key that was used (as it rotates). It returns the decrypted data or an error.
+func (m *Manager) Decrypt(edata, nonce []byte, keyID uint32) ([]byte, error) {
+	dek, err := m.key(keyID)
+	if err != nil {
+		return nil, err
+	}
+
+	aead, err := createGCMCypher(dek)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to create GCMCypher from DEK")
+	}
+	data, err := aead.Open(nil, nonce, edata, nil)
+	if err != nil {
+		return nil, errors.Wrap(err, fmt.Sprintf("failed to decrypt data using keyid %d", keyID))
+	}
+	return data, nil
+}
+
+func encrypt(aead cipher.AEAD, data []byte) ([]byte, []byte, error) {
+	if aead == nil {
+		return nil, nil, fmt.Errorf("aead is nil, cannot encrypt data")
+	}
+	nonce := make([]byte, aead.NonceSize())
+	_, err := rand.Read(nonce)
+	if err != nil {
+		return nil, nil, err
+	}
+	sealed := aead.Seal(nil, nonce, data, nil)
+	return sealed, nonce, nil
+}
+
+func createGCMCypher(key []byte) (cipher.AEAD, error) {
+	b, err := aes.NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+	aead, err := cipher.NewGCM(b)
+	if err != nil {
+		return nil, err
+	}
+	return aead, nil
+}
+
+// fetchActiveDataKey returns the current data key and its key ID.
+// Each call results in activeKeyCounter being incremented by 1. When
+// the activeKeyCounter exceeds maxWriteCount, the active data key is
+// rotated before being returned.
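Before the key-management internals that follow, a reviewer aid (not part of this patch): the sketch below round-trips a value through the public API defined above. The point to notice is that the nonce and key ID returned by Encrypt must be stored alongside the ciphertext, since the active key rotates; the payload string is illustrative.

```go
package main

import (
	"fmt"
	"log"

	"github.com/rancher/steve/pkg/sqlcache/encryption"
)

func main() {
	m, err := encryption.NewManager()
	if err != nil {
		log.Fatal(err)
	}

	// Encrypt returns the ciphertext plus the nonce and key ID needed to undo it.
	ciphertext, nonce, keyID, err := m.Encrypt([]byte("some object blob"))
	if err != nil {
		log.Fatal(err)
	}

	// Decrypt needs all three pieces back, even after the active key has rotated.
	plaintext, err := m.Decrypt(ciphertext, nonce, keyID)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(plaintext)) // prints "some object blob"
}
```

Because data encryption keys are held only in memory, ciphertext written by one Manager instance cannot be decrypted by another process or after a restart.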
+func (m *Manager) fetchActiveDataKey() ([]byte, uint32, error) { + m.counterLock.Lock() + defer m.counterLock.Unlock() + + m.activeKeyCounter++ + if m.activeKeyCounter >= maxWriteCount { + return m.newDataEncryptionKey() + } + + return m.activeKey() +} + +func (m *Manager) newDataEncryptionKey() ([]byte, uint32, error) { + dek := make([]byte, keySize) + _, err := rand.Read(dek) + if err != nil { + return nil, 0, err + } + + m.lock.Lock() + defer m.lock.Unlock() + + m.activeKeyCounter = 1 + + m.dataKeys = append(m.dataKeys, dek) + keyID := uint32(len(m.dataKeys) - 1) + + return dek, keyID, nil +} + +func (m *Manager) activeKey() ([]byte, uint32, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + nk := len(m.dataKeys) + if nk == 0 { + return nil, 0, ErrKeyNotFound + } + keyID := uint32(nk - 1) + + return m.dataKeys[keyID], keyID, nil +} + +func (m *Manager) key(keyID uint32) ([]byte, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + if len(m.dataKeys) <= int(keyID) { + return nil, fmt.Errorf("%w: %v", ErrKeyNotFound, keyID) + } + return m.dataKeys[keyID], nil +} diff --git a/pkg/sqlcache/encryption/encrypt_test.go b/pkg/sqlcache/encryption/encrypt_test.go new file mode 100644 index 00000000..46f3300a --- /dev/null +++ b/pkg/sqlcache/encryption/encrypt_test.go @@ -0,0 +1,327 @@ +package encryption + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewManager(t *testing.T) { + m, err := NewManager() + if err != nil { + t.FailNow() + } + assert.NotNil(t, m) +} + +func TestEncrypt(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + var tests []testCase + + tests = append(tests, testCase{description: "test encrypt with arbitrary initial key", test: func(t *testing.T) { + testDEK := []byte{83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181, 83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181} + + m, err := NewManager() + require.Nil(t, err) + + m.dataKeys[0] = testDEK + + testData := []byte("something") + cipherText, nonce, keyID, err := m.Encrypt(testData) + require.Nil(t, err) + + dek := m.dataKeys[keyID] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + decryptedData, err := aead.Open(nil, nonce, cipherText, nil) + + require.Nil(t, err) + assert.Equal(t, testData, decryptedData) + }}) + tests = append(tests, testCase{description: "test encrypt without arbitrary initial key", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + cipherText, nonce, keyID, err := m.Encrypt(testData) + require.Nil(t, err) + + dek := m.dataKeys[keyID] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + decryptedData, err := aead.Open(nil, nonce, cipherText, nil) + + require.Nil(t, err) + assert.Equal(t, testData, decryptedData) + }}) + tests = append(tests, testCase{description: "test encrypt: same data yield different cipher/nonce pair", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + cipher1, nonce1, keyID1, err := m.Encrypt(testData) + require.Nil(t, err) + assert.Len(t, cipher1, 25) + assert.Len(t, nonce1, 12) + assert.NotEmpty(t, cipher1) + assert.NotEmpty(t, nonce1) + + cipher2, nonce2, keyID2, err := m.Encrypt(testData) + require.Nil(t, err) + 
+ assert.Equal(t, keyID1, keyID2) + assert.NotEqual(t, cipher1, cipher2, "each encrypt op must return a unique cipher") + assert.NotEqual(t, nonce1, nonce2, "each encrypt op must return a unique nonce") + }}) + tests = append(tests, testCase{description: "test encrypt with key rotation", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + cipher1, nonce1, keyID1, err := m.Encrypt(testData) + require.Nil(t, err) + assert.Len(t, cipher1, 25) + assert.Len(t, nonce1, 12) + assert.NotEmpty(t, cipher1) + assert.NotEmpty(t, nonce1) + + m.activeKeyCounter += maxWriteCount + + cipher2, nonce2, keyID2, err := m.Encrypt(testData) + require.Nil(t, err) + + assert.Equal(t, int64(1), m.activeKeyCounter) + assert.NotEqual(t, keyID1, keyID2) + assert.NotEqual(t, cipher1, cipher2, "each encrypt op must return a unique cipher") + assert.NotEqual(t, nonce1, nonce2, "each encrypt op must return a unique nonce") + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestDecrypt(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + var tests []testCase + + tests = append(tests, testCase{description: "test decrypt with arbitrary key", test: func(t *testing.T) { + testDEK := []byte{83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181, 83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181} + + m, err := NewManager() + require.Nil(t, err) + + m.dataKeys[0] = testDEK + + testData := []byte("something") + + // encrypt data out of band. + b, err := aes.NewCipher(testDEK) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + + nonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + cipherText := aead.Seal(nil, nonce, testData, nil) + + // use manager to decrypt the data. + decryptedData, err := m.Decrypt(cipherText, nonce, 0) + require.Nil(t, err) + + assert.Equal(t, testData, decryptedData) + }, + }) + tests = append(tests, testCase{description: "test decrypt without arbitrary key", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + + // encrypt data out of band. + dek := m.dataKeys[0] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + + nonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + cipherText := aead.Seal(nil, nonce, testData, nil) + + // use manager to decrypt the data. + decryptedData, err := m.Decrypt(cipherText, nonce, 0) + require.Nil(t, err) + + assert.Equal(t, testData, decryptedData) + }, + }) + tests = append(tests, testCase{description: "test decrypt with wrong data nonce should return error", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + + // encrypt data out of band. + dek := m.dataKeys[0] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + + nonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + cipherText := aead.Seal(nil, nonce, testData, nil) + + // generate random nonce. 
+		randomNonce := make([]byte, aead.NonceSize())
+		_, err = rand.Read(randomNonce)
+		require.Nil(t, err)
+
+		// decrypt the encrypted data using the wrong nonce
+		_, err = m.Decrypt(cipherText, randomNonce, 0)
+		assert.NotNil(t, err)
+	},
+	})
+
+	tests = append(tests, testCase{description: "test decrypt with DEK/nonce pair not used to encrypt should return error", test: func(t *testing.T) {
+		m, err := NewManager()
+		require.Nil(t, err)
+
+		testData := []byte("something")
+
+		// encrypt data out of band.
+		dek := m.dataKeys[0]
+		b, err := aes.NewCipher(dek)
+		require.Nil(t, err)
+
+		aead, err := cipher.NewGCM(b)
+		require.Nil(t, err)
+
+		nonce := make([]byte, aead.NonceSize())
+		_, err = rand.Read(nonce)
+		require.Nil(t, err)
+
+		cipherText := aead.Seal(nil, nonce, testData, nil)
+
+		key, id, err := m.newDataEncryptionKey()
+		require.Nil(t, err)
+		m.dataKeys[id] = key
+
+		plainText, err := m.Decrypt(cipherText, nonce, id)
+		assert.NotNil(t, err)
+		assert.Nil(t, plainText)
+	},
+	})
+	tests = append(tests, testCase{description: "test decrypt for non active key", test: func(t *testing.T) {
+		m, err := NewManager()
+		require.Nil(t, err)
+
+		testData := []byte("something")
+
+		cipher, nonce, keyID, err := m.Encrypt(testData)
+		require.Nil(t, err)
+
+		// force key rotation.
+		m.activeKeyCounter += maxWriteCount
+		_, _, newKeyID, err := m.Encrypt(nil)
+		require.Nil(t, err)
+		require.NotEqual(t, keyID, newKeyID)
+
+		// use manager to decrypt the data.
+		decryptedData, err := m.Decrypt(cipher, nonce, keyID)
+		require.Nil(t, err)
+
+		assert.Equal(t, testData, decryptedData)
+	},
+	})
+
+	t.Parallel()
+	for _, test := range tests {
+		t.Run(test.description, func(t *testing.T) { test.test(t) })
+	}
+}
+
+var buf = make([]byte, 8192)
+
+func BenchmarkEncryption(b *testing.B) {
+	benchEncrypt(b, 1024)
+	benchEncrypt(b, 4096)
+	benchEncrypt(b, 8192)
+}
+
+func BenchmarkDecryption(b *testing.B) {
+	benchDecrypt(b, 1024)
+	benchDecrypt(b, 4096)
+	benchDecrypt(b, 8192)
+}
+
+func benchEncrypt(b *testing.B, size int) {
+	m, err := NewManager()
+	if err != nil {
+		b.Fatal("failed to create manager", err)
+	}
+	// disable auto rotation to avoid skewing results.
+	maxWriteCount = math.MaxInt32
+
+	b.Run(fmt.Sprintf("encrypt-%d", size), func(b *testing.B) {
+		b.ReportAllocs()
+		b.SetBytes(int64(size))
+		for i := 0; i < b.N; i++ {
+			_, _, _, err := m.Encrypt(buf[:size])
+			if err != nil {
+				b.Fatal("error encrypting data", err)
+			}
+		}
+	})
+}
+
+func benchDecrypt(b *testing.B, size int) {
+	m, err := NewManager()
+	if err != nil {
+		b.Fatal("failed to create manager", err)
+	}
+
+	edata, enonce, kid, err := m.Encrypt(buf[:size])
+	if err != nil {
+		b.Fatal("failed to encrypt data", err)
+	}
+
+	b.Run(fmt.Sprintf("decrypt-%d", size), func(b *testing.B) {
+		b.ReportAllocs()
+		b.SetBytes(int64(size))
+		for i := 0; i < b.N; i++ {
+			_, err := m.Decrypt(edata, enonce, kid)
+			if err != nil {
+				b.Fatal("error decrypting data", err)
+			}
+		}
+	})
+}
diff --git a/pkg/sqlcache/informer/db_mocks_test.go b/pkg/sqlcache/informer/db_mocks_test.go
new file mode 100644
index 00000000..7d2c81ce
--- /dev/null
+++ b/pkg/sqlcache/informer/db_mocks_test.go
@@ -0,0 +1,204 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient,Rows)
+//
+// Generated by this command:
+//
+//	mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows
+//
+
+// Package informer is a generated GoMock package.
+package informer + +import ( + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockTXClient is a mock of TXClient interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockTXClient) Cancel() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) +} + +// Commit mocks base method. +func (m *MockTXClient) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} + +// StmtExec mocks base method. +func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StmtExec indicates an expected call of StmtExec. +func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) +} + +// MockRows is a mock of Rows interface. 
+type MockRows struct { + ctrl *gomock.Controller + recorder *MockRowsMockRecorder +} + +// MockRowsMockRecorder is the mock recorder for MockRows. +type MockRowsMockRecorder struct { + mock *MockRows +} + +// NewMockRows creates a new mock instance. +func NewMockRows(ctrl *gomock.Controller) *MockRows { + mock := &MockRows{ctrl: ctrl} + mock.recorder = &MockRowsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRows) EXPECT() *MockRowsMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRows) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRowsMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close)) +} + +// Err mocks base method. +func (m *MockRows) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockRowsMockRecorder) Err() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err)) +} + +// Next mocks base method. +func (m *MockRows) Next() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockRowsMockRecorder) Next() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next)) +} + +// Scan mocks base method. +func (m *MockRows) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) +} diff --git a/pkg/sqlcache/informer/dynamic_mocks_test.go b/pkg/sqlcache/informer/dynamic_mocks_test.go new file mode 100644 index 00000000..07e169c1 --- /dev/null +++ b/pkg/sqlcache/informer/dynamic_mocks_test.go @@ -0,0 +1,237 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: k8s.io/client-go/dynamic (interfaces: ResourceInterface) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" +) + +// MockResourceInterface is a mock of ResourceInterface interface. +type MockResourceInterface struct { + ctrl *gomock.Controller + recorder *MockResourceInterfaceMockRecorder +} + +// MockResourceInterfaceMockRecorder is the mock recorder for MockResourceInterface. +type MockResourceInterfaceMockRecorder struct { + mock *MockResourceInterface +} + +// NewMockResourceInterface creates a new mock instance. 
+func NewMockResourceInterface(ctrl *gomock.Controller) *MockResourceInterface { + mock := &MockResourceInterface{ctrl: ctrl} + mock.recorder = &MockResourceInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResourceInterface) EXPECT() *MockResourceInterfaceMockRecorder { + return m.recorder +} + +// Apply mocks base method. +func (m *MockResourceInterface) Apply(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions, arg4 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3} + for _, a := range arg4 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Apply", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Apply indicates an expected call of Apply. +func (mr *MockResourceInterfaceMockRecorder) Apply(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockResourceInterface)(nil).Apply), varargs...) +} + +// ApplyStatus mocks base method. +func (m *MockResourceInterface) ApplyStatus(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplyStatus", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ApplyStatus indicates an expected call of ApplyStatus. +func (mr *MockResourceInterfaceMockRecorder) ApplyStatus(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyStatus", reflect.TypeOf((*MockResourceInterface)(nil).ApplyStatus), arg0, arg1, arg2, arg3) +} + +// Create mocks base method. +func (m *MockResourceInterface) Create(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.CreateOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Create", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create. +func (mr *MockResourceInterfaceMockRecorder) Create(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockResourceInterface)(nil).Create), varargs...) +} + +// Delete mocks base method. +func (m *MockResourceInterface) Delete(arg0 context.Context, arg1 string, arg2 v1.DeleteOptions, arg3 ...string) error { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Delete", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockResourceInterfaceMockRecorder) Delete(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockResourceInterface)(nil).Delete), varargs...) 
+} + +// DeleteCollection mocks base method. +func (m *MockResourceInterface) DeleteCollection(arg0 context.Context, arg1 v1.DeleteOptions, arg2 v1.ListOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteCollection", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteCollection indicates an expected call of DeleteCollection. +func (mr *MockResourceInterfaceMockRecorder) DeleteCollection(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCollection", reflect.TypeOf((*MockResourceInterface)(nil).DeleteCollection), arg0, arg1, arg2) +} + +// Get mocks base method. +func (m *MockResourceInterface) Get(arg0 context.Context, arg1 string, arg2 v1.GetOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Get", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockResourceInterfaceMockRecorder) Get(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockResourceInterface)(nil).Get), varargs...) +} + +// List mocks base method. +func (m *MockResourceInterface) List(arg0 context.Context, arg1 v1.ListOptions) (*unstructured.UnstructuredList, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", arg0, arg1) + ret0, _ := ret[0].(*unstructured.UnstructuredList) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockResourceInterfaceMockRecorder) List(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockResourceInterface)(nil).List), arg0, arg1) +} + +// Patch mocks base method. +func (m *MockResourceInterface) Patch(arg0 context.Context, arg1 string, arg2 types.PatchType, arg3 []byte, arg4 v1.PatchOptions, arg5 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3, arg4} + for _, a := range arg5 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Patch", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Patch indicates an expected call of Patch. +func (mr *MockResourceInterfaceMockRecorder) Patch(arg0, arg1, arg2, arg3, arg4 any, arg5 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3, arg4}, arg5...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Patch", reflect.TypeOf((*MockResourceInterface)(nil).Patch), varargs...) +} + +// Update mocks base method. +func (m *MockResourceInterface) Update(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Update", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update. 
+func (mr *MockResourceInterfaceMockRecorder) Update(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockResourceInterface)(nil).Update), varargs...) +} + +// UpdateStatus mocks base method. +func (m *MockResourceInterface) UpdateStatus(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateStatus", arg0, arg1, arg2) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateStatus indicates an expected call of UpdateStatus. +func (mr *MockResourceInterfaceMockRecorder) UpdateStatus(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStatus", reflect.TypeOf((*MockResourceInterface)(nil).UpdateStatus), arg0, arg1, arg2) +} + +// Watch mocks base method. +func (m *MockResourceInterface) Watch(arg0 context.Context, arg1 v1.ListOptions) (watch.Interface, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Watch", arg0, arg1) + ret0, _ := ret[0].(watch.Interface) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Watch indicates an expected call of Watch. +func (mr *MockResourceInterfaceMockRecorder) Watch(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockResourceInterface)(nil).Watch), arg0, arg1) +} diff --git a/pkg/sqlcache/informer/factory/db_mocks_test.go b/pkg/sqlcache/informer/factory/db_mocks_test.go new file mode 100644 index 00000000..9ac55bb3 --- /dev/null +++ b/pkg/sqlcache/informer/factory/db_mocks_test.go @@ -0,0 +1,121 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient +// + +// Package factory is a generated GoMock package. +package factory + +import ( + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockTXClient is a mock of TXClient interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockTXClient) Cancel() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) +} + +// Commit mocks base method. 
+func (m *MockTXClient) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} + +// StmtExec mocks base method. +func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StmtExec indicates an expected call of StmtExec. +func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) +} diff --git a/pkg/sqlcache/informer/factory/dynamic_mocks_test.go b/pkg/sqlcache/informer/factory/dynamic_mocks_test.go new file mode 100644 index 00000000..29e2c0fd --- /dev/null +++ b/pkg/sqlcache/informer/factory/dynamic_mocks_test.go @@ -0,0 +1,237 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: k8s.io/client-go/dynamic (interfaces: ResourceInterface) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package factory -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface +// + +// Package factory is a generated GoMock package. +package factory + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" +) + +// MockResourceInterface is a mock of ResourceInterface interface. +type MockResourceInterface struct { + ctrl *gomock.Controller + recorder *MockResourceInterfaceMockRecorder +} + +// MockResourceInterfaceMockRecorder is the mock recorder for MockResourceInterface. +type MockResourceInterfaceMockRecorder struct { + mock *MockResourceInterface +} + +// NewMockResourceInterface creates a new mock instance. 
+func NewMockResourceInterface(ctrl *gomock.Controller) *MockResourceInterface { + mock := &MockResourceInterface{ctrl: ctrl} + mock.recorder = &MockResourceInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResourceInterface) EXPECT() *MockResourceInterfaceMockRecorder { + return m.recorder +} + +// Apply mocks base method. +func (m *MockResourceInterface) Apply(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions, arg4 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3} + for _, a := range arg4 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Apply", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Apply indicates an expected call of Apply. +func (mr *MockResourceInterfaceMockRecorder) Apply(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockResourceInterface)(nil).Apply), varargs...) +} + +// ApplyStatus mocks base method. +func (m *MockResourceInterface) ApplyStatus(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplyStatus", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ApplyStatus indicates an expected call of ApplyStatus. +func (mr *MockResourceInterfaceMockRecorder) ApplyStatus(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyStatus", reflect.TypeOf((*MockResourceInterface)(nil).ApplyStatus), arg0, arg1, arg2, arg3) +} + +// Create mocks base method. +func (m *MockResourceInterface) Create(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.CreateOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Create", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create. +func (mr *MockResourceInterfaceMockRecorder) Create(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockResourceInterface)(nil).Create), varargs...) +} + +// Delete mocks base method. +func (m *MockResourceInterface) Delete(arg0 context.Context, arg1 string, arg2 v1.DeleteOptions, arg3 ...string) error { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Delete", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockResourceInterfaceMockRecorder) Delete(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockResourceInterface)(nil).Delete), varargs...) 
+} + +// DeleteCollection mocks base method. +func (m *MockResourceInterface) DeleteCollection(arg0 context.Context, arg1 v1.DeleteOptions, arg2 v1.ListOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteCollection", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteCollection indicates an expected call of DeleteCollection. +func (mr *MockResourceInterfaceMockRecorder) DeleteCollection(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCollection", reflect.TypeOf((*MockResourceInterface)(nil).DeleteCollection), arg0, arg1, arg2) +} + +// Get mocks base method. +func (m *MockResourceInterface) Get(arg0 context.Context, arg1 string, arg2 v1.GetOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Get", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockResourceInterfaceMockRecorder) Get(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockResourceInterface)(nil).Get), varargs...) +} + +// List mocks base method. +func (m *MockResourceInterface) List(arg0 context.Context, arg1 v1.ListOptions) (*unstructured.UnstructuredList, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", arg0, arg1) + ret0, _ := ret[0].(*unstructured.UnstructuredList) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockResourceInterfaceMockRecorder) List(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockResourceInterface)(nil).List), arg0, arg1) +} + +// Patch mocks base method. +func (m *MockResourceInterface) Patch(arg0 context.Context, arg1 string, arg2 types.PatchType, arg3 []byte, arg4 v1.PatchOptions, arg5 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3, arg4} + for _, a := range arg5 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Patch", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Patch indicates an expected call of Patch. +func (mr *MockResourceInterfaceMockRecorder) Patch(arg0, arg1, arg2, arg3, arg4 any, arg5 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3, arg4}, arg5...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Patch", reflect.TypeOf((*MockResourceInterface)(nil).Patch), varargs...) +} + +// Update mocks base method. +func (m *MockResourceInterface) Update(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Update", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update. 
+func (mr *MockResourceInterfaceMockRecorder) Update(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockResourceInterface)(nil).Update), varargs...) +} + +// UpdateStatus mocks base method. +func (m *MockResourceInterface) UpdateStatus(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateStatus", arg0, arg1, arg2) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateStatus indicates an expected call of UpdateStatus. +func (mr *MockResourceInterfaceMockRecorder) UpdateStatus(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStatus", reflect.TypeOf((*MockResourceInterface)(nil).UpdateStatus), arg0, arg1, arg2) +} + +// Watch mocks base method. +func (m *MockResourceInterface) Watch(arg0 context.Context, arg1 v1.ListOptions) (watch.Interface, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Watch", arg0, arg1) + ret0, _ := ret[0].(watch.Interface) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Watch indicates an expected call of Watch. +func (mr *MockResourceInterfaceMockRecorder) Watch(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockResourceInterface)(nil).Watch), arg0, arg1) +} diff --git a/pkg/sqlcache/informer/factory/factory_mocks_test.go b/pkg/sqlcache/informer/factory/factory_mocks_test.go new file mode 100644 index 00000000..a7adab6a --- /dev/null +++ b/pkg/sqlcache/informer/factory/factory_mocks_test.go @@ -0,0 +1,179 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/informer/factory (interfaces: DBClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer/factory DBClient +// + +// Package factory is a generated GoMock package. +package factory + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + db "github.com/rancher/steve/pkg/sqlcache/db" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockDBClient is a mock of DBClient interface. +type MockDBClient struct { + ctrl *gomock.Controller + recorder *MockDBClientMockRecorder +} + +// MockDBClientMockRecorder is the mock recorder for MockDBClient. +type MockDBClientMockRecorder struct { + mock *MockDBClient +} + +// NewMockDBClient creates a new mock instance. +func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient { + mock := &MockDBClient{ctrl: ctrl} + mock.recorder = &MockDBClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(db.TXClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. 
+func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1) +} + +// CloseStmt mocks base method. +func (m *MockDBClient) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0) +} + +// NewConnection mocks base method. +func (m *MockDBClient) NewConnection() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewConnection") + ret0, _ := ret[0].(error) + return ret0 +} + +// NewConnection indicates an expected call of NewConnection. +func (mr *MockDBClientMockRecorder) NewConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockDBClient)(nil).NewConnection)) +} + +// Prepare mocks base method. +func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. 
+func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0) +} + +// Upsert mocks base method. +func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) +} diff --git a/pkg/sqlcache/informer/factory/informer_factory.go b/pkg/sqlcache/informer/factory/informer_factory.go new file mode 100644 index 00000000..a8f3d266 --- /dev/null +++ b/pkg/sqlcache/informer/factory/informer_factory.go @@ -0,0 +1,187 @@ +/* +Package factory provides a cache factory for the sql-based cache. +*/ +package factory + +import ( + "fmt" + "os" + "sync" + "time" + + "github.com/rancher/lasso/pkg/log" + "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/encryption" + "github.com/rancher/steve/pkg/sqlcache/informer" + sqlStore "github.com/rancher/steve/pkg/sqlcache/store" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/tools/cache" +) + +// EncryptAllEnvVar is set to "true" if users want all types' data blobs to be encrypted in SQLite +// otherwise only variables in defaultEncryptedResourceTypes will have their blobs encrypted +const EncryptAllEnvVar = "CATTLE_ENCRYPT_CACHE_ALL" + +// CacheFactory builds Informer instances and keeps a cache of instances it created +type CacheFactory struct { + wg wait.Group + dbClient DBClient + stopCh chan struct{} + mutex sync.RWMutex + encryptAll bool + + newInformer newInformer + + informers map[schema.GroupVersionKind]*guardedInformer + informersMutex sync.Mutex +} + +type guardedInformer struct { + informer *informer.Informer + mutex *sync.Mutex +} + +type newInformer func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespace bool) (*informer.Informer, error) + +type DBClient interface { + informer.DBClient + sqlStore.DBClient + connector +} + +type Cache struct { + informer.ByOptionsLister +} + +type connector interface { + NewConnection() error +} + +var defaultEncryptedResourceTypes = map[schema.GroupVersionKind]struct{}{ + { + Version: "v1", + Kind: "Secret", + }: {}, +} + +// NewCacheFactory returns an informer factory instance +// This is currently called from steve via initial calls to 
`s.cacheFactory.CacheFor(...)`
+func NewCacheFactory() (*CacheFactory, error) {
+	m, err := encryption.NewManager()
+	if err != nil {
+		return nil, err
+	}
+	dbClient, err := db.NewClient(nil, m, m)
+	if err != nil {
+		return nil, err
+	}
+	return &CacheFactory{
+		wg:          wait.Group{},
+		stopCh:      make(chan struct{}),
+		encryptAll:  os.Getenv(EncryptAllEnvVar) == "true",
+		dbClient:    dbClient,
+		newInformer: informer.NewInformer,
+		informers:   map[schema.GroupVersionKind]*guardedInformer{},
+	}, nil
+}
+
+// CacheFor returns an informer for the given GVK, backed by a SQL store indexed on the given fields, using the specified client.
+// Virtual fields must be set by the transform function and included in fields so that they can be indexed.
+func (f *CacheFactory) CacheFor(fields [][]string, transform cache.TransformFunc, client dynamic.ResourceInterface, gvk schema.GroupVersionKind, namespaced bool, watchable bool) (Cache, error) {
+	// First of all block Reset() until we are done
+	f.mutex.RLock()
+	defer f.mutex.RUnlock()
+
+	// Second, check if the informer and its accompanying informer-specific mutex exist already in the informers cache
+	// If not, start by creating such informer-specific mutex. That is used later to ensure no two goroutines create
+	// informers for the same GVK at the same time
+	f.informersMutex.Lock()
+	// Note: the informers cache is protected by informersMutex, which we don't want to hold for very long because
+	// that blocks CacheFor for other GVKs, hence not deferring unlock here
+	gi, ok := f.informers[gvk]
+	if !ok {
+		gi = &guardedInformer{
+			informer: nil,
+			mutex:    &sync.Mutex{},
+		}
+		f.informers[gvk] = gi
+	}
+	f.informersMutex.Unlock()
+
+	// At this point an informer-specific mutex (gi.mutex) is guaranteed to exist. Lock it
+	gi.mutex.Lock()
+	defer gi.mutex.Unlock()
+
+	// Then: if the informer really was not created yet (first time here or previous times have errored out)
+	// actually create the informer
+	if gi.informer == nil {
+		start := time.Now()
+		log.Debugf("CacheFor STARTS creating informer for %v", gvk)
+		defer func() {
+			log.Debugf("CacheFor IS DONE creating informer for %v (took %v)", gvk, time.Since(start))
+		}()
+
+		_, encryptResourceAlways := defaultEncryptedResourceTypes[gvk]
+		shouldEncrypt := f.encryptAll || encryptResourceAlways
+		i, err := f.newInformer(client, fields, transform, gvk, f.dbClient, shouldEncrypt, namespaced)
+		if err != nil {
+			return Cache{}, err
+		}
+
+		err = i.SetWatchErrorHandler(func(r *cache.Reflector, err error) {
+			if !watchable && errors.IsMethodNotSupported(err) {
+				// expected, continue without logging
+				return
+			}
+			cache.DefaultWatchErrorHandler(r, err)
+		})
+		if err != nil {
+			return Cache{}, err
+		}
+
+		f.wg.StartWithChannel(f.stopCh, i.Run)
+
+		gi.informer = i
+	}
+
+	if !cache.WaitForCacheSync(f.stopCh, gi.informer.HasSynced) {
+		return Cache{}, fmt.Errorf("failed to sync SQLite Informer cache for GVK %v", gvk)
+	}
+
+	// At this point the informer is ready, return it
+	return Cache{ByOptionsLister: gi.informer}, nil
+}
+
+// Reset closes the stopCh which stops any running informers, assigns a new stopCh, resets the GVK-informer cache, and resets
+// the database connection which wipes any current sqlite database at the default location.
+func (f *CacheFactory) Reset() error {
+	if f.dbClient == nil {
+		// nothing to reset
+		return nil
+	}
+
+	// first of all wait until all CacheFor() calls that create new informers are finished.
Also block any new ones + f.mutex.Lock() + defer f.mutex.Unlock() + + // now that we are alone, stop all informers created until this point + close(f.stopCh) + f.stopCh = make(chan struct{}) + f.wg.Wait() + + // and get rid of all references to those informers and their mutexes + f.informersMutex.Lock() + defer f.informersMutex.Unlock() + f.informers = make(map[schema.GroupVersionKind]*guardedInformer) + + // finally, reset the DB connection + err := f.dbClient.NewConnection() + if err != nil { + return err + } + + return nil +} diff --git a/pkg/sqlcache/informer/factory/informer_factory_test.go b/pkg/sqlcache/informer/factory/informer_factory_test.go new file mode 100644 index 00000000..e3b96562 --- /dev/null +++ b/pkg/sqlcache/informer/factory/informer_factory_test.go @@ -0,0 +1,287 @@ +package factory + +import ( + "os" + "testing" + "time" + + "github.com/rancher/steve/pkg/sqlcache/informer" + + sqlStore "github.com/rancher/steve/pkg/sqlcache/store" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/tools/cache" +) + +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer/factory DBClient +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./k8s_cache_mocks_test.go k8s.io/client-go/tools/cache SharedIndexInformer + +func TestNewCacheFactory(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "NewCacheFactory() with no errors returned, should return no errors", test: func(t *testing.T) { + f, err := NewCacheFactory() + assert.Nil(t, err) + assert.NotNil(t, f.dbClient) + assert.False(t, f.encryptAll) + }}) + tests = append(tests, testCase{description: "NewCacheFactory() with no errors returned and EncryptAllEnvVar set to true, should return no errors and have encryptAll set to true", test: func(t *testing.T) { + err := os.Setenv(EncryptAllEnvVar, "true") + assert.Nil(t, err) + f, err := NewCacheFactory() + assert.Nil(t, err) + assert.Nil(t, err) + assert.NotNil(t, f.dbClient) + assert.True(t, f.encryptAll) + }}) + // cannot run as parallel because tests involve changing env var + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestCacheFor(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, and stopCh not closed, should return no error and should call Informer.Run(). 
A subsequent call to CacheFor() should return the same informer", test: func(t *testing.T) {
+		dbClient := NewMockDBClient(gomock.NewController(t))
+		dynamicClient := NewMockResourceInterface(gomock.NewController(t))
+		fields := [][]string{{"something"}}
+		expectedGVK := schema.GroupVersionKind{}
+		sii := NewMockSharedIndexInformer(gomock.NewController(t))
+		sii.EXPECT().HasSynced().Return(true).AnyTimes()
+		sii.EXPECT().Run(gomock.Any()).MinTimes(1)
+		sii.EXPECT().SetWatchErrorHandler(gomock.Any())
+		i := &informer.Informer{
+			// need to set this so Run function is not nil
+			SharedIndexInformer: sii,
+		}
+		expectedC := Cache{
+			ByOptionsLister: i,
+		}
+		testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*informer.Informer, error) {
+			assert.Equal(t, client, dynamicClient)
+			assert.Equal(t, fields, fields)
+			assert.Equal(t, expectedGVK, gvk)
+			assert.Equal(t, db, dbClient)
+			assert.Equal(t, false, shouldEncrypt)
+			return i, nil
+		}
+		f := &CacheFactory{
+			dbClient:    dbClient,
+			stopCh:      make(chan struct{}),
+			newInformer: testNewInformer,
+			informers:   map[schema.GroupVersionKind]*guardedInformer{},
+		}
+
+		go func() {
+			// this goroutine keeps stopCh open for the duration of this test; if the test runs as part of a longer process, stopCh is eventually closed
+			time.Sleep(5 * time.Second)
+			close(f.stopCh)
+		}()
+		var c Cache
+		var err error
+		c, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true)
+		assert.Nil(t, err)
+		assert.Equal(t, expectedC, c)
+		// this sleep is critical to the test. It ensures there has been enough time for expected functions like Run to be invoked in their goroutines.
+		time.Sleep(1 * time.Second)
+		c2, err := f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true)
+		assert.Nil(t, err)
+		assert.Equal(t, c, c2)
+	}})
+	tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning false, and stopCh not closed, should call Run() and return an error", test: func(t *testing.T) {
+		dbClient := NewMockDBClient(gomock.NewController(t))
+		dynamicClient := NewMockResourceInterface(gomock.NewController(t))
+		fields := [][]string{{"something"}}
+		expectedGVK := schema.GroupVersionKind{}
+
+		sii := NewMockSharedIndexInformer(gomock.NewController(t))
+		sii.EXPECT().HasSynced().Return(false).AnyTimes()
+		sii.EXPECT().Run(gomock.Any())
+		sii.EXPECT().SetWatchErrorHandler(gomock.Any())
+		expectedI := &informer.Informer{
+			// need to set this so Run function is not nil
+			SharedIndexInformer: sii,
+		}
+		testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) {
+			assert.Equal(t, client, dynamicClient)
+			assert.Equal(t, fields, fields)
+			assert.Equal(t, expectedGVK, gvk)
+			assert.Equal(t, db, dbClient)
+			assert.Equal(t, false, shouldEncrypt)
+			return expectedI, nil
+		}
+		f := &CacheFactory{
+			dbClient:    dbClient,
+			stopCh:      make(chan struct{}),
+			newInformer: testNewInformer,
+			informers:   map[schema.GroupVersionKind]*guardedInformer{},
+		}
+
+		go func() {
+			time.Sleep(1 * time.Second)
+			close(f.stopCh)
+		}()
+		var err error
+		_, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true)
+		assert.NotNil(t, err)
+		time.Sleep(2 * time.Second)
+	}})
+	tests = append(tests, testCase{description: "CacheFor() with no errors
returned, HasSync returning true, and stopCh closed, should not call Run() more than once and not return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + fields := [][]string{{"something"}} + expectedGVK := schema.GroupVersionKind{} + + sii := NewMockSharedIndexInformer(gomock.NewController(t)) + sii.EXPECT().HasSynced().Return(true).AnyTimes() + // may or may not call run initially + sii.EXPECT().Run(gomock.Any()).MaxTimes(1) + sii.EXPECT().SetWatchErrorHandler(gomock.Any()) + i := &informer.Informer{ + // need to set this so Run function is not nil + SharedIndexInformer: sii, + } + expectedC := Cache{ + ByOptionsLister: i, + } + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + assert.Equal(t, client, dynamicClient) + assert.Equal(t, fields, fields) + assert.Equal(t, expectedGVK, gvk) + assert.Equal(t, db, dbClient) + assert.Equal(t, false, shouldEncrypt) + return i, nil + } + f := &CacheFactory{ + dbClient: dbClient, + stopCh: make(chan struct{}), + newInformer: testNewInformer, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + } + + close(f.stopCh) + var c Cache + var err error + c, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true) + assert.Nil(t, err) + assert.Equal(t, expectedC, c) + time.Sleep(1 * time.Second) + }}) + tests = append(tests, testCase{description: "CacheFor() with no errors returned and encryptAll set to true, should return no error and pass shouldEncrypt as true to newInformer func", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + fields := [][]string{{"something"}} + expectedGVK := schema.GroupVersionKind{} + sii := NewMockSharedIndexInformer(gomock.NewController(t)) + sii.EXPECT().HasSynced().Return(true) + sii.EXPECT().Run(gomock.Any()).MinTimes(1).AnyTimes() + sii.EXPECT().SetWatchErrorHandler(gomock.Any()) + i := &informer.Informer{ + // need to set this so Run function is not nil + SharedIndexInformer: sii, + } + expectedC := Cache{ + ByOptionsLister: i, + } + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + assert.Equal(t, client, dynamicClient) + assert.Equal(t, fields, fields) + assert.Equal(t, expectedGVK, gvk) + assert.Equal(t, db, dbClient) + assert.Equal(t, true, shouldEncrypt) + return i, nil + } + f := &CacheFactory{ + dbClient: dbClient, + stopCh: make(chan struct{}), + newInformer: testNewInformer, + encryptAll: true, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + } + + go func() { + time.Sleep(10 * time.Second) + close(f.stopCh) + }() + var c Cache + var err error + c, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true) + assert.Nil(t, err) + assert.Equal(t, expectedC, c) + time.Sleep(1 * time.Second) + }}) + tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, stopCh not closed, and transform func should return no error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + 
fields := [][]string{{"something"}}
+		expectedGVK := schema.GroupVersionKind{}
+		sii := NewMockSharedIndexInformer(gomock.NewController(t))
+		sii.EXPECT().HasSynced().Return(true)
+		sii.EXPECT().Run(gomock.Any()).MinTimes(1)
+		sii.EXPECT().SetWatchErrorHandler(gomock.Any())
+		transformFunc := func(input interface{}) (interface{}, error) {
+			return "someoutput", nil
+		}
+		i := &informer.Informer{
+			// need to set this so Run function is not nil
+			SharedIndexInformer: sii,
+		}
+		expectedC := Cache{
+			ByOptionsLister: i,
+		}
+		testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*informer.Informer, error) {
+			// we can't test func == func, so instead we check if the output was as expected
+			input := "someinput"
+			output, err := transform(input)
+			assert.Nil(t, err)
+			outputStr, ok := output.(string)
+			assert.True(t, ok, "output from transform was expected to be a string")
+			assert.Equal(t, "someoutput", outputStr)
+
+			assert.Equal(t, client, dynamicClient)
+			assert.Equal(t, fields, fields)
+			assert.Equal(t, expectedGVK, gvk)
+			assert.Equal(t, db, dbClient)
+			assert.Equal(t, false, shouldEncrypt)
+			return i, nil
+		}
+		f := &CacheFactory{
+			dbClient:    dbClient,
+			stopCh:      make(chan struct{}),
+			newInformer: testNewInformer,
+			informers:   map[schema.GroupVersionKind]*guardedInformer{},
+		}
+
+		go func() {
+			// this goroutine keeps stopCh open for the duration of this test; if the test runs as part of a longer process, stopCh is eventually closed
+			time.Sleep(5 * time.Second)
+			close(f.stopCh)
+		}()
+		var c Cache
+		var err error
+		c, err = f.CacheFor(fields, transformFunc, dynamicClient, expectedGVK, false, true)
+		assert.Nil(t, err)
+		assert.Equal(t, expectedC, c)
+		time.Sleep(1 * time.Second)
+	}})
+	t.Parallel()
+	for _, test := range tests {
+		t.Run(test.description, func(t *testing.T) { test.test(t) })
+	}
+}
diff --git a/pkg/sqlcache/informer/factory/k8s_cache_mocks_test.go b/pkg/sqlcache/informer/factory/k8s_cache_mocks_test.go
new file mode 100644
index 00000000..b9c4dc35
--- /dev/null
+++ b/pkg/sqlcache/informer/factory/k8s_cache_mocks_test.go
@@ -0,0 +1,223 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: k8s.io/client-go/tools/cache (interfaces: SharedIndexInformer)
+//
+// Generated by this command:
+//
+//	mockgen --build_flags=--mod=mod -package factory -destination ./k8s_cache_mocks_test.go k8s.io/client-go/tools/cache SharedIndexInformer
+//
+
+// Package factory is a generated GoMock package.
+package factory
+
+import (
+	reflect "reflect"
+	time "time"
+
+	gomock "go.uber.org/mock/gomock"
+	cache "k8s.io/client-go/tools/cache"
+)
+
+// MockSharedIndexInformer is a mock of SharedIndexInformer interface.
+type MockSharedIndexInformer struct {
+	ctrl     *gomock.Controller
+	recorder *MockSharedIndexInformerMockRecorder
+}
+
+// MockSharedIndexInformerMockRecorder is the mock recorder for MockSharedIndexInformer.
+type MockSharedIndexInformerMockRecorder struct {
+	mock *MockSharedIndexInformer
+}
+
+// NewMockSharedIndexInformer creates a new mock instance.
+func NewMockSharedIndexInformer(ctrl *gomock.Controller) *MockSharedIndexInformer {
+	mock := &MockSharedIndexInformer{ctrl: ctrl}
+	mock.recorder = &MockSharedIndexInformerMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockSharedIndexInformer) EXPECT() *MockSharedIndexInformerMockRecorder { + return m.recorder +} + +// AddEventHandler mocks base method. +func (m *MockSharedIndexInformer) AddEventHandler(arg0 cache.ResourceEventHandler) (cache.ResourceEventHandlerRegistration, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddEventHandler", arg0) + ret0, _ := ret[0].(cache.ResourceEventHandlerRegistration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddEventHandler indicates an expected call of AddEventHandler. +func (mr *MockSharedIndexInformerMockRecorder) AddEventHandler(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEventHandler", reflect.TypeOf((*MockSharedIndexInformer)(nil).AddEventHandler), arg0) +} + +// AddEventHandlerWithResyncPeriod mocks base method. +func (m *MockSharedIndexInformer) AddEventHandlerWithResyncPeriod(arg0 cache.ResourceEventHandler, arg1 time.Duration) (cache.ResourceEventHandlerRegistration, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddEventHandlerWithResyncPeriod", arg0, arg1) + ret0, _ := ret[0].(cache.ResourceEventHandlerRegistration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddEventHandlerWithResyncPeriod indicates an expected call of AddEventHandlerWithResyncPeriod. +func (mr *MockSharedIndexInformerMockRecorder) AddEventHandlerWithResyncPeriod(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEventHandlerWithResyncPeriod", reflect.TypeOf((*MockSharedIndexInformer)(nil).AddEventHandlerWithResyncPeriod), arg0, arg1) +} + +// AddIndexers mocks base method. +func (m *MockSharedIndexInformer) AddIndexers(arg0 cache.Indexers) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddIndexers", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddIndexers indicates an expected call of AddIndexers. +func (mr *MockSharedIndexInformerMockRecorder) AddIndexers(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddIndexers", reflect.TypeOf((*MockSharedIndexInformer)(nil).AddIndexers), arg0) +} + +// GetController mocks base method. +func (m *MockSharedIndexInformer) GetController() cache.Controller { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetController") + ret0, _ := ret[0].(cache.Controller) + return ret0 +} + +// GetController indicates an expected call of GetController. +func (mr *MockSharedIndexInformerMockRecorder) GetController() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetController", reflect.TypeOf((*MockSharedIndexInformer)(nil).GetController)) +} + +// GetIndexer mocks base method. +func (m *MockSharedIndexInformer) GetIndexer() cache.Indexer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetIndexer") + ret0, _ := ret[0].(cache.Indexer) + return ret0 +} + +// GetIndexer indicates an expected call of GetIndexer. +func (mr *MockSharedIndexInformerMockRecorder) GetIndexer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIndexer", reflect.TypeOf((*MockSharedIndexInformer)(nil).GetIndexer)) +} + +// GetStore mocks base method. +func (m *MockSharedIndexInformer) GetStore() cache.Store { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStore") + ret0, _ := ret[0].(cache.Store) + return ret0 +} + +// GetStore indicates an expected call of GetStore. 
+func (mr *MockSharedIndexInformerMockRecorder) GetStore() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStore", reflect.TypeOf((*MockSharedIndexInformer)(nil).GetStore)) +} + +// HasSynced mocks base method. +func (m *MockSharedIndexInformer) HasSynced() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasSynced") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasSynced indicates an expected call of HasSynced. +func (mr *MockSharedIndexInformerMockRecorder) HasSynced() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasSynced", reflect.TypeOf((*MockSharedIndexInformer)(nil).HasSynced)) +} + +// IsStopped mocks base method. +func (m *MockSharedIndexInformer) IsStopped() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsStopped") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsStopped indicates an expected call of IsStopped. +func (mr *MockSharedIndexInformerMockRecorder) IsStopped() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsStopped", reflect.TypeOf((*MockSharedIndexInformer)(nil).IsStopped)) +} + +// LastSyncResourceVersion mocks base method. +func (m *MockSharedIndexInformer) LastSyncResourceVersion() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LastSyncResourceVersion") + ret0, _ := ret[0].(string) + return ret0 +} + +// LastSyncResourceVersion indicates an expected call of LastSyncResourceVersion. +func (mr *MockSharedIndexInformerMockRecorder) LastSyncResourceVersion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastSyncResourceVersion", reflect.TypeOf((*MockSharedIndexInformer)(nil).LastSyncResourceVersion)) +} + +// RemoveEventHandler mocks base method. +func (m *MockSharedIndexInformer) RemoveEventHandler(arg0 cache.ResourceEventHandlerRegistration) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveEventHandler", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveEventHandler indicates an expected call of RemoveEventHandler. +func (mr *MockSharedIndexInformerMockRecorder) RemoveEventHandler(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveEventHandler", reflect.TypeOf((*MockSharedIndexInformer)(nil).RemoveEventHandler), arg0) +} + +// Run mocks base method. +func (m *MockSharedIndexInformer) Run(arg0 <-chan struct{}) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Run", arg0) +} + +// Run indicates an expected call of Run. +func (mr *MockSharedIndexInformerMockRecorder) Run(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSharedIndexInformer)(nil).Run), arg0) +} + +// SetTransform mocks base method. +func (m *MockSharedIndexInformer) SetTransform(arg0 cache.TransformFunc) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetTransform", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetTransform indicates an expected call of SetTransform. +func (mr *MockSharedIndexInformerMockRecorder) SetTransform(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTransform", reflect.TypeOf((*MockSharedIndexInformer)(nil).SetTransform), arg0) +} + +// SetWatchErrorHandler mocks base method. 
+func (m *MockSharedIndexInformer) SetWatchErrorHandler(arg0 cache.WatchErrorHandler) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWatchErrorHandler", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWatchErrorHandler indicates an expected call of SetWatchErrorHandler. +func (mr *MockSharedIndexInformerMockRecorder) SetWatchErrorHandler(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWatchErrorHandler", reflect.TypeOf((*MockSharedIndexInformer)(nil).SetWatchErrorHandler), arg0) +} diff --git a/pkg/sqlcache/informer/indexer.go b/pkg/sqlcache/informer/indexer.go new file mode 100644 index 00000000..7ed4451b --- /dev/null +++ b/pkg/sqlcache/informer/indexer.go @@ -0,0 +1,264 @@ +package informer + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" + "sync" + + "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" + "k8s.io/client-go/tools/cache" +) + +const ( + selectQueryFmt = ` + SELECT object, objectnonce, dekid FROM "%[1]s" + WHERE key IN ( + SELECT key FROM "%[1]s_indices" + WHERE name = ? AND value IN (?%s) + ) + ` + createTableFmt = `CREATE TABLE IF NOT EXISTS "%[1]s_indices" ( + name TEXT NOT NULL, + value TEXT NOT NULL, + key TEXT NOT NULL REFERENCES "%[1]s"(key) ON DELETE CASCADE, + PRIMARY KEY (name, value, key) + )` + createIndexFmt = `CREATE INDEX IF NOT EXISTS "%[1]s_indices_index" ON "%[1]s_indices"(name, value)` + + deleteIndicesFmt = `DELETE FROM "%s_indices" WHERE key = ?` + addIndexFmt = `INSERT INTO "%s_indices" (name, value, key) VALUES (?, ?, ?) ON CONFLICT DO NOTHING` + listByIndexFmt = `SELECT object, objectnonce, dekid FROM "%[1]s" + WHERE key IN ( + SELECT key FROM "%[1]s_indices" + WHERE name = ? AND value = ? + )` + listKeyByIndexFmt = `SELECT DISTINCT key FROM "%s_indices" WHERE name = ? 
AND value = ?` + listIndexValuesFmt = `SELECT DISTINCT value FROM "%s_indices" WHERE name = ?` +) + +// Indexer is a SQLite-backed cache.Indexer which builds upon Store adding an index table +type Indexer struct { + Store + indexers cache.Indexers + indexersLock sync.RWMutex + + deleteIndicesQuery string + addIndexQuery string + listByIndexQuery string + listKeysByIndexQuery string + listIndexValuesQuery string + + deleteIndicesStmt *sql.Stmt + addIndexStmt *sql.Stmt + listByIndexStmt *sql.Stmt + listKeysByIndexStmt *sql.Stmt + listIndexValuesStmt *sql.Stmt +} + +var _ cache.Indexer = (*Indexer)(nil) + +type Store interface { + DBClient + cache.Store + + GetByKey(key string) (item any, exists bool, err error) + GetName() string + RegisterAfterUpsert(f func(key string, obj any, tx db.TXClient) error) + RegisterAfterDelete(f func(key string, tx db.TXClient) error) + GetShouldEncrypt() bool + GetType() reflect.Type +} + +type DBClient interface { + BeginTx(ctx context.Context, forWriting bool) (db.TXClient, error) + QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) + ReadObjects(rows db.Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) + ReadStrings(rows db.Rows) ([]string, error) + ReadInt(rows db.Rows) (int, error) + Prepare(stmt string) *sql.Stmt + CloseStmt(stmt db.Closable) error +} + +// NewIndexer returns a cache.Indexer backed by SQLite for objects of the given example type +func NewIndexer(indexers cache.Indexers, s Store) (*Indexer, error) { + tx, err := s.BeginTx(context.Background(), true) + if err != nil { + return nil, err + } + dbName := db.Sanitize(s.GetName()) + createTableQuery := fmt.Sprintf(createTableFmt, dbName) + err = tx.Exec(createTableQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createTableQuery, Err: err} + } + createIndexQuery := fmt.Sprintf(createIndexFmt, dbName) + err = tx.Exec(createIndexQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createIndexQuery, Err: err} + } + err = tx.Commit() + if err != nil { + return nil, err + } + + i := &Indexer{ + Store: s, + indexers: indexers, + } + i.RegisterAfterUpsert(i.AfterUpsert) + + i.deleteIndicesQuery = fmt.Sprintf(deleteIndicesFmt, db.Sanitize(s.GetName())) + i.addIndexQuery = fmt.Sprintf(addIndexFmt, db.Sanitize(s.GetName())) + i.listByIndexQuery = fmt.Sprintf(listByIndexFmt, db.Sanitize(s.GetName())) + i.listKeysByIndexQuery = fmt.Sprintf(listKeyByIndexFmt, db.Sanitize(s.GetName())) + i.listIndexValuesQuery = fmt.Sprintf(listIndexValuesFmt, db.Sanitize(s.GetName())) + + i.deleteIndicesStmt = s.Prepare(i.deleteIndicesQuery) + i.addIndexStmt = s.Prepare(i.addIndexQuery) + i.listByIndexStmt = s.Prepare(i.listByIndexQuery) + i.listKeysByIndexStmt = s.Prepare(i.listKeysByIndexQuery) + i.listIndexValuesStmt = s.Prepare(i.listIndexValuesQuery) + + return i, nil +} + +/* Core methods */ + +// AfterUpsert updates indices of an object +func (i *Indexer) AfterUpsert(key string, obj any, tx db.TXClient) error { + // delete all + err := tx.StmtExec(tx.Stmt(i.deleteIndicesStmt), key) + if err != nil { + return &db.QueryError{QueryString: i.deleteIndicesQuery, Err: err} + } + + // re-insert all values + i.indexersLock.RLock() + defer i.indexersLock.RUnlock() + for indexName, indexFunc := range i.indexers { + values, err := indexFunc(obj) + if err != nil { + return err + } + + for _, value := range values { + err = tx.StmtExec(tx.Stmt(i.addIndexStmt), indexName, value, key) + if err != nil { + return &db.QueryError{QueryString: 
i.addIndexQuery, Err: err}
+			}
+		}
+	}
+	return nil
+}
+
+/* Satisfy cache.Indexer */
+
+// Index returns a list of items that match the given object on the index function
+func (i *Indexer) Index(indexName string, obj any) ([]any, error) {
+	i.indexersLock.RLock()
+	defer i.indexersLock.RUnlock()
+	indexFunc := i.indexers[indexName]
+	if indexFunc == nil {
+		return nil, fmt.Errorf("index with name %s does not exist", indexName)
+	}
+
+	values, err := indexFunc(obj)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(values) == 0 {
+		return nil, nil
+	}
+
+	// typical case
+	if len(values) == 1 {
+		return i.ByIndex(indexName, values[0])
+	}
+
+	// atypical case - more than one value to look up
+	// HACK: sql.Stmt.Query does not allow passing slices in as of Go 1.19 - create an ad-hoc statement
+	query := fmt.Sprintf(selectQueryFmt, db.Sanitize(i.GetName()), strings.Repeat(", ?", len(values)-1))
+	stmt := i.Prepare(query)
+	defer i.CloseStmt(stmt)
+	// HACK: Query will accept []any but not []string
+	params := []any{indexName}
+	for _, value := range values {
+		params = append(params, value)
+	}
+
+	rows, err := i.QueryForRows(context.TODO(), stmt, params...)
+	if err != nil {
+		return nil, &db.QueryError{QueryString: query, Err: err}
+	}
+	return i.ReadObjects(rows, i.GetType(), i.GetShouldEncrypt())
+}
+
+// ByIndex returns the stored objects whose set of indexed values
+// for the named index includes the given indexed value
+func (i *Indexer) ByIndex(indexName, indexedValue string) ([]any, error) {
+	rows, err := i.QueryForRows(context.TODO(), i.listByIndexStmt, indexName, indexedValue)
+	if err != nil {
+		return nil, &db.QueryError{QueryString: i.listByIndexQuery, Err: err}
+	}
+	return i.ReadObjects(rows, i.GetType(), i.GetShouldEncrypt())
+}
+
+// IndexKeys returns a list of the Store keys of the objects whose indexed values in the given index include the given indexed value
+func (i *Indexer) IndexKeys(indexName, indexedValue string) ([]string, error) {
+	i.indexersLock.RLock()
+	defer i.indexersLock.RUnlock()
+	indexFunc := i.indexers[indexName]
+	if indexFunc == nil {
+		return nil, fmt.Errorf("index with name %s does not exist", indexName)
+	}
+
+	rows, err := i.QueryForRows(context.TODO(), i.listKeysByIndexStmt, indexName, indexedValue)
+	if err != nil {
+		return nil, &db.QueryError{QueryString: i.listKeysByIndexQuery, Err: err}
+	}
+	return i.ReadStrings(rows)
+}
+
+// ListIndexFuncValues wraps safeListIndexFuncValues and panics in case of I/O errors
+func (i *Indexer) ListIndexFuncValues(name string) []string {
+	result, err := i.safeListIndexFuncValues(name)
+	if err != nil {
+		panic(fmt.Errorf("unexpected error in safeListIndexFuncValues: %w", err))
+	}
+	return result
+}
+
+// safeListIndexFuncValues returns all the indexed values of the given index
+func (i *Indexer) safeListIndexFuncValues(indexName string) ([]string, error) {
+	rows, err := i.QueryForRows(context.TODO(), i.listIndexValuesStmt, indexName)
+	if err != nil {
+		return nil, &db.QueryError{QueryString: i.listIndexValuesQuery, Err: err}
+	}
+	return i.ReadStrings(rows)
+}
+
+// GetIndexers returns the indexers
+func (i *Indexer) GetIndexers() cache.Indexers {
+	i.indexersLock.RLock()
+	defer i.indexersLock.RUnlock()
+	return i.indexers
+}
+
+// AddIndexers adds more indexers to this Store. If you call this after you already have data
+// in the Store, the results are undefined.
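+// For illustration only (a hypothetical index name and index function, not part of this change):
+// with an Indexer i built via NewIndexer and no data stored yet, registering an extra index
+// could look like
+//
+//	_ = i.AddIndexers(cache.Indexers{
+//		"byValue": func(obj any) ([]string, error) {
+//			return []string{fmt.Sprint(obj)}, nil
+//		},
+//	})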
+func (i *Indexer) AddIndexers(newIndexers cache.Indexers) error { + i.indexersLock.Lock() + defer i.indexersLock.Unlock() + if i.indexers == nil { + i.indexers = make(map[string]cache.IndexFunc) + } + for k, v := range newIndexers { + i.indexers[k] = v + } + return nil +} diff --git a/pkg/sqlcache/informer/indexer_test.go b/pkg/sqlcache/informer/indexer_test.go new file mode 100644 index 00000000..4118118c --- /dev/null +++ b/pkg/sqlcache/informer/indexer_test.go @@ -0,0 +1,614 @@ +/* +Copyright 2023 SUSE LLC + +Adapted from client-go, Copyright 2014 The Kubernetes Authors. +*/ + +package informer + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "k8s.io/client-go/tools/cache" +) + +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer Store +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows + +type testStoreObject struct { + Id string + Val string +} + +func TestNewIndexer(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "NewIndexer() with no errors returned from Store or TXClient, should return no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + storeName := "someStoreName" + store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil) + client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil) + client.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(fmt.Sprintf(deleteIndicesFmt, storeName)) + store.EXPECT().Prepare(fmt.Sprintf(addIndexFmt, storeName)) + store.EXPECT().Prepare(fmt.Sprintf(listByIndexFmt, storeName, storeName)) + store.EXPECT().Prepare(fmt.Sprintf(listKeyByIndexFmt, storeName)) + store.EXPECT().Prepare(fmt.Sprintf(listIndexValuesFmt, storeName)) + indexer, err := NewIndexer(indexers, store) + assert.Nil(t, err) + assert.Equal(t, cache.Indexers(indexers), indexer.indexers) + }}) + tests = append(tests, testCase{description: "NewIndexer() with Store Begin() error, should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + store.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error")) + _, err := NewIndexer(indexers, store) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewIndexer() with TXClient Exec() error on first call to Exec(), should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + storeName := "someStoreName" + 
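		// the first Exec() expected below is the CREATE TABLE statement for the indices table; returning an error from it should make NewIndexer fail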
store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(fmt.Errorf("error")) + _, err := NewIndexer(indexers, store) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewIndexer() with TXClient Exec() error on second call to Exec(), should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + storeName := "someStoreName" + store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil) + client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(fmt.Errorf("error")) + _, err := NewIndexer(indexers, store) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewIndexer() with TXClient Commit() error, should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + storeName := "someStoreName" + store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil) + client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil) + client.EXPECT().Commit().Return(fmt.Errorf("error")) + _, err := NewIndexer(indexers, store) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestAfterUpsert(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "AfterUpsert() with no errors returned from TXClient should return no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + deleteStmt := &sql.Stmt{} + addStmt := &sql.Stmt{} + objKey := "key" + indexer := &Indexer{ + Store: store, + deleteIndicesStmt: deleteStmt, + addIndexStmt: addStmt, + indexers: map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + key := "somekey" + client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt) + client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(nil) + client.EXPECT().Stmt(indexer.addIndexStmt).Return(indexer.addIndexStmt) + client.EXPECT().StmtExec(indexer.addIndexStmt, "a", objKey, key).Return(nil) + testObject := testStoreObject{Id: "something", Val: "a"} + err := indexer.AfterUpsert(key, testObject, client) + assert.Nil(t, err) + }}) + tests = append(tests, testCase{description: "AfterUpsert() with error returned from TXClient StmtExec() should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + deleteStmt := &sql.Stmt{} + addStmt := 
&sql.Stmt{} + objKey := "key" + indexer := &Indexer{ + Store: store, + deleteIndicesStmt: deleteStmt, + addIndexStmt: addStmt, + indexers: map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + key := "somekey" + client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt) + client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(fmt.Errorf("error")) + testObject := testStoreObject{Id: "something", Val: "a"} + err := indexer.AfterUpsert(key, testObject, client) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "AfterUpsert() with error returned from TXClient second StmtExec() call should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + deleteStmt := &sql.Stmt{} + addStmt := &sql.Stmt{} + objKey := "key" + indexer := &Indexer{ + Store: store, + deleteIndicesStmt: deleteStmt, + addIndexStmt: addStmt, + indexers: map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + key := "somekey" + client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt) + client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(nil) + client.EXPECT().Stmt(indexer.addIndexStmt).Return(indexer.addIndexStmt) + client.EXPECT().StmtExec(indexer.addIndexStmt, "a", objKey, key).Return(fmt.Errorf("error")) + testObject := testStoreObject{Id: "something", Val: "a"} + err := indexer.AfterUpsert(key, testObject, client) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestIndex(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "Index() with no errors returned from store and 1 object returned by ReadObjects(), should return one obj and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil) + objs, err := indexer.Index(indexName, testObject) + assert.Nil(t, err) + assert.Equal(t, []any{testObject}, objs) + }}) + tests = append(tests, testCase{description: "Index() with no errors returned from store and multiple objects returned by ReadObjects(), should return multiple objects and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := 
testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject, testObject}, nil) + objs, err := indexer.Index(indexName, testObject) + assert.Nil(t, err) + assert.Equal(t, []any{testObject, testObject}, objs) + }}) + tests = append(tests, testCase{description: "Index() with no errors returned from store and no objects returned by ReadObjects(), should return no objects and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{}, nil) + objs, err := indexer.Index(indexName, testObject) + assert.Nil(t, err) + assert.Equal(t, []any{}, objs) + }}) + tests = append(tests, testCase{description: "Index() where index name is not in indexers, should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + _, err := indexer.Index("someotherindexname", testObject) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "Index() with an error returned from store QueryForRows, should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(nil, fmt.Errorf("error")) + _, err := indexer.Index(indexName, testObject) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "Index() with an errors returned from store ReadObjects(), should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), 
indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, fmt.Errorf("error")) + _, err := indexer.Index(indexName, testObject) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "Index() with no errors returned from store and multiple keys returned from index func, should return one obj and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey, objKey + "2"}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().GetName().Return("name") + stmt := &sql.Stmt{} + store.EXPECT().Prepare(fmt.Sprintf(selectQueryFmt, "name", ", ?")).Return(stmt) + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey, objKey+"2").Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil) + store.EXPECT().CloseStmt(stmt).Return(nil) + objs, err := indexer.Index(indexName, testObject) + assert.Nil(t, err) + assert.Equal(t, []any{testObject}, objs) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestByIndex(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "IndexBy() with no errors returned from store and 1 object returned by ReadObjects(), should return one obj and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil) + objs, err := indexer.ByIndex(indexName, objKey) + assert.Nil(t, err) + assert.Equal(t, []any{testObject}, objs) + }}) + tests = append(tests, testCase{description: "IndexBy() with no errors returned from store and multiple objects returned by ReadObjects(), should return multiple objects and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + 
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject, testObject}, nil) + objs, err := indexer.ByIndex(indexName, objKey) + assert.Nil(t, err) + assert.Equal(t, []any{testObject, testObject}, objs) + }}) + tests = append(tests, testCase{description: "IndexBy() with no errors returned from store and no objects returned by ReadObjects(), should return no objects and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{}, nil) + objs, err := indexer.ByIndex(indexName, objKey) + assert.Nil(t, err) + assert.Equal(t, []any{}, objs) + }}) + tests = append(tests, testCase{description: "IndexBy() with an error returned from store QueryForRows, should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(nil, fmt.Errorf("error")) + _, err := indexer.ByIndex(indexName, objKey) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "IndexBy() with an errors returned from store ReadObjects(), should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, fmt.Errorf("error")) + _, err := indexer.ByIndex(indexName, objKey) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestListIndexFuncValues(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "ListIndexFuncvalues() with no errors returned from store and 1 object returned by ReadObjects(), should return one obj and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(rows, nil) + store.EXPECT().ReadStrings(rows).Return([]string{"somestrings"}, nil) + vals := indexer.ListIndexFuncValues(indexName) + assert.Equal(t, []string{"somestrings"}, vals) + }}) + tests = append(tests, testCase{description: 
"ListIndexFuncvalues() with QueryForRows() error returned from store, should panic", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + listStmt := &sql.Stmt{} + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(nil, fmt.Errorf("error")) + assert.Panics(t, func() { indexer.ListIndexFuncValues(indexName) }) + }}) + tests = append(tests, testCase{description: "ListIndexFuncvalues() with ReadStrings() error returned from store, should panic", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(rows, nil) + store.EXPECT().ReadStrings(rows).Return([]string{"somestrings"}, fmt.Errorf("error")) + assert.Panics(t, func() { indexer.ListIndexFuncValues(indexName) }) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestGetIndexers(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "GetIndexers() should return indexers fron indexers field", test: func(t *testing.T) { + objKey := "key" + expectedIndexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + indexer := &Indexer{ + indexers: expectedIndexers, + } + indexers := indexer.GetIndexers() + assert.Equal(t, cache.Indexers(expectedIndexers), indexers) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestAddIndexers(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "GetIndexers() should return indexers fron indexers field", test: func(t *testing.T) { + objKey := "key" + expectedIndexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + indexer := &Indexer{} + err := indexer.AddIndexers(expectedIndexers) + assert.Nil(t, err) + assert.ObjectsAreEqual(cache.Indexers(expectedIndexers), indexer.indexers) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} diff --git a/pkg/sqlcache/informer/informer.go b/pkg/sqlcache/informer/informer.go new file mode 100644 index 00000000..a74c7029 --- /dev/null +++ b/pkg/sqlcache/informer/informer.go @@ -0,0 +1,94 @@ +/* +package sql provides an Informer and Indexer that uses SQLite as a store, instead of an in-memory store like a map. 
+*/ + +package informer + +import ( + "context" + "time" + + "github.com/rancher/steve/pkg/sqlcache/partition" + sqlStore "github.com/rancher/steve/pkg/sqlcache/store" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/tools/cache" +) + +// Informer is a SQLite-backed cache.SharedIndexInformer that can execute queries on listprocessor structs +type Informer struct { + cache.SharedIndexInformer + ByOptionsLister +} + +type ByOptionsLister interface { + ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) +} + +// this is set to a var so that it can be overridden by test code for mocking purposes +var newInformer = cache.NewSharedIndexInformer + +// NewInformer returns a new SQLite-backed Informer for the type specified by schema in unstructured.Unstructured form +// using the specified client +func NewInformer(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*Informer, error) { + listWatcher := &cache.ListWatch{ + ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { + a, err := client.List(context.Background(), options) + return a, err + }, + WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { + return client.Watch(context.Background(), options) + }, + } + + example := &unstructured.Unstructured{} + example.SetGroupVersionKind(gvk) + + // avoids the informer to periodically resync (re-list) its resources + // currently it is a work hypothesis that, when interacting with the UI, this should not be needed + resyncPeriod := time.Duration(0) + + sii := newInformer(listWatcher, example, resyncPeriod, cache.Indexers{}) + if transform != nil { + if err := sii.SetTransform(transform); err != nil { + return nil, err + } + } + + name := informerNameFromGVK(gvk) + + s, err := sqlStore.NewStore(example, cache.DeletionHandlingMetaNamespaceKeyFunc, db, shouldEncrypt, name) + if err != nil { + return nil, err + } + loi, err := NewListOptionIndexer(fields, s, namespaced) + if err != nil { + return nil, err + } + + // HACK: replace the default informer's indexer with the SQL based one + UnsafeSet(sii, "indexer", loi) + + return &Informer{ + SharedIndexInformer: sii, + ByOptionsLister: loi, + }, nil +} + +// ListByOptions returns objects according to the specified list options and partitions. 
+// Specifically: +// - an unstructured list of resources belonging to any of the specified partitions +// - the total number of resources (returned list might be a subset depending on pagination options in lo) +// - a continue token, if there are more pages after the returned one +// - an error instead of all of the above if anything went wrong +func (i *Informer) ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) { + return i.ByOptionsLister.ListByOptions(ctx, lo, partitions, namespace) +} + +func informerNameFromGVK(gvk schema.GroupVersionKind) string { + return gvk.Group + "_" + gvk.Version + "_" + gvk.Kind +} diff --git a/pkg/sqlcache/informer/informer_mocks_test.go b/pkg/sqlcache/informer/informer_mocks_test.go new file mode 100644 index 00000000..9eff0612 --- /dev/null +++ b/pkg/sqlcache/informer/informer_mocks_test.go @@ -0,0 +1,59 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/informer (interfaces: ByOptionsLister) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + reflect "reflect" + + partition "github.com/rancher/steve/pkg/sqlcache/partition" + gomock "go.uber.org/mock/gomock" + unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" +) + +// MockByOptionsLister is a mock of ByOptionsLister interface. +type MockByOptionsLister struct { + ctrl *gomock.Controller + recorder *MockByOptionsListerMockRecorder +} + +// MockByOptionsListerMockRecorder is the mock recorder for MockByOptionsLister. +type MockByOptionsListerMockRecorder struct { + mock *MockByOptionsLister +} + +// NewMockByOptionsLister creates a new mock instance. +func NewMockByOptionsLister(ctrl *gomock.Controller) *MockByOptionsLister { + mock := &MockByOptionsLister{ctrl: ctrl} + mock.recorder = &MockByOptionsListerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockByOptionsLister) EXPECT() *MockByOptionsListerMockRecorder { + return m.recorder +} + +// ListByOptions mocks base method. +func (m *MockByOptionsLister) ListByOptions(arg0 context.Context, arg1 ListOptions, arg2 []partition.Partition, arg3 string) (*unstructured.UnstructuredList, int, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListByOptions", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*unstructured.UnstructuredList) + ret1, _ := ret[1].(int) + ret2, _ := ret[2].(string) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// ListByOptions indicates an expected call of ListByOptions. 
+func (mr *MockByOptionsListerMockRecorder) ListByOptions(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListByOptions", reflect.TypeOf((*MockByOptionsLister)(nil).ListByOptions), arg0, arg1, arg2, arg3) +} diff --git a/pkg/sqlcache/informer/informer_test.go b/pkg/sqlcache/informer/informer_test.go new file mode 100644 index 00000000..5337ee8e --- /dev/null +++ b/pkg/sqlcache/informer/informer_test.go @@ -0,0 +1,351 @@ +package informer + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/rancher/steve/pkg/sqlcache/partition" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/tools/cache" +) + +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient + +func TestNewInformer(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "NewInformer() with no errors returned, should return no error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() + + // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + // NewListOptionIndexer() logic. 
This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(context.Background(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + informer, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) + assert.Nil(t, err) + assert.NotNil(t, informer.ByOptionsLister) + assert.NotNil(t, informer.SharedIndexInformer) + }}) + tests = append(tests, testCase{description: "NewInformer() with errors returned from NewStore(), should return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewInformer() with errors returned from NewIndexer(), should return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() + + // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewInformer() with errors returned from NewListOptionIndexer(), should return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. 
+ dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() + + // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + // NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewInformer() with transform func", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + mockInformer := mockInformer{} + testNewInformer := func(lw cache.ListerWatcher, + exampleObject runtime.Object, + defaultEventHandlerResyncPeriod time.Duration, + indexers cache.Indexers) cache.SharedIndexInformer { + return &mockInformer + } + newInformer = testNewInformer + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() + + // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + // NewListOptionIndexer() logic. 
This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + transformFunc := func(input interface{}) (interface{}, error) { + return "someoutput", nil + } + informer, err := NewInformer(dynamicClient, fields, transformFunc, gvk, dbClient, false, true) + assert.Nil(t, err) + assert.NotNil(t, informer.ByOptionsLister) + assert.NotNil(t, informer.SharedIndexInformer) + assert.NotNil(t, mockInformer.transformFunc) + + // we can't test func == func, so instead we check if the output was as expected + input := "someinput" + ouput, err := mockInformer.transformFunc(input) + assert.Nil(t, err) + outputStr, ok := ouput.(string) + assert.True(t, ok, "ouput from transform was expected to be a string") + assert.Equal(t, "someoutput", outputStr) + + newInformer = cache.NewSharedIndexInformer + }}) + tests = append(tests, testCase{description: "NewInformer() unable to set transform func", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + mockInformer := mockInformer{ + setTranformErr: fmt.Errorf("some error"), + } + testNewInformer := func(lw cache.ListerWatcher, + exampleObject runtime.Object, + defaultEventHandlerResyncPeriod time.Duration, + indexers cache.Indexers) cache.SharedIndexInformer { + return &mockInformer + } + newInformer = testNewInformer + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + transformFunc := func(input interface{}) (interface{}, error) { + return "someoutput", nil + } + _, err := NewInformer(dynamicClient, fields, transformFunc, gvk, dbClient, false, true) + assert.Error(t, err) + newInformer = cache.NewSharedIndexInformer + }}) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestInformerListByOptions(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "ListByOptions() with no errors returned, should return no error and return value from indexer's ListByOptions()", test: func(t *testing.T) { + indexer := NewMockByOptionsLister(gomock.NewController(t)) + informer := &Informer{ + ByOptionsLister: indexer, + } + lo := ListOptions{} + var partitions []partition.Partition + ns := "somens" + expectedList := &unstructured.UnstructuredList{ + Object: map[string]interface{}{"s": 2}, + Items: []unstructured.Unstructured{{ + Object: map[string]interface{}{"s": 2}, + }}, + } + expectedTotal := len(expectedList.Items) + expectedContinueToken := "123" + indexer.EXPECT().ListByOptions(context.TODO(), lo, partitions, ns).Return(expectedList, expectedTotal, expectedContinueToken, nil) + list, total, continueToken, err := informer.ListByOptions(context.TODO(), lo, partitions, ns) + assert.Nil(t, err) + assert.Equal(t, expectedList, list) + assert.Equal(t, len(expectedList.Items), total) + assert.Equal(t, expectedContinueToken, continueToken) + }}) + tests = append(tests, testCase{description: 
"ListByOptions() with indexer ListByOptions error, should return error", test: func(t *testing.T) { + indexer := NewMockByOptionsLister(gomock.NewController(t)) + informer := &Informer{ + ByOptionsLister: indexer, + } + lo := ListOptions{} + var partitions []partition.Partition + ns := "somens" + indexer.EXPECT().ListByOptions(context.TODO(), lo, partitions, ns).Return(nil, 0, "", fmt.Errorf("error")) + _, _, _, err := informer.ListByOptions(context.TODO(), lo, partitions, ns) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +// Note: SQLite based caching uses an Informer that unsafely sets the Indexer as the ability to set it is not present +// in client-go at the moment. Long term, we look forward contribute a patch to client-go to make that configurable. +// Until then, we are adding this canary test that will panic in case the indexer cannot be set. +func TestUnsafeSet(t *testing.T) { + listWatcher := &cache.ListWatch{ + ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { + return &unstructured.UnstructuredList{}, nil + }, + WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { + return dummyWatch{}, nil + }, + } + + sii := cache.NewSharedIndexInformer(listWatcher, &unstructured.Unstructured{}, 0, cache.Indexers{}) + + // will panic if SharedIndexInformer stops having a *Indexer field called "indexer" + UnsafeSet(sii, "indexer", &Indexer{}) +} + +type dummyWatch struct{} + +func (dummyWatch) Stop() { +} + +func (dummyWatch) ResultChan() <-chan watch.Event { + result := make(chan watch.Event) + defer close(result) + return result +} + +// mockInformer is a mock of cache.SharedIndexInformer. Unlike other types, we can't generate this using mockgen because we use a unsafeSet to replace the +// indexer field, which is a struct field. This won't exist on the mock, producing an error. So we need to implement our own mock which actually has this field. 
+type mockInformer struct { + transformFunc cache.TransformFunc + setTranformErr error + indexer cache.Indexer +} + +func (m *mockInformer) AddEventHandler(handler cache.ResourceEventHandler) (cache.ResourceEventHandlerRegistration, error) { + return nil, nil +} +func (m *mockInformer) AddEventHandlerWithResyncPeriod(handler cache.ResourceEventHandler, resyncPeriod time.Duration) (cache.ResourceEventHandlerRegistration, error) { + return nil, nil +} +func (m *mockInformer) RemoveEventHandler(handle cache.ResourceEventHandlerRegistration) error { + return nil +} +func (m *mockInformer) GetStore() cache.Store { return nil } +func (m *mockInformer) GetController() cache.Controller { return nil } +func (m *mockInformer) Run(stopCh <-chan struct{}) {} +func (m *mockInformer) HasSynced() bool { return false } +func (m *mockInformer) LastSyncResourceVersion() string { return "" } +func (m *mockInformer) SetWatchErrorHandler(handler cache.WatchErrorHandler) error { return nil } +func (m *mockInformer) IsStopped() bool { return false } +func (m *mockInformer) AddIndexers(indexers cache.Indexers) error { return nil } +func (m *mockInformer) GetIndexer() cache.Indexer { return nil } +func (m *mockInformer) SetTransform(handler cache.TransformFunc) error { + m.transformFunc = handler + return m.setTranformErr +} diff --git a/pkg/sqlcache/informer/listoption_indexer.go b/pkg/sqlcache/informer/listoption_indexer.go new file mode 100644 index 00000000..93ce5f08 --- /dev/null +++ b/pkg/sqlcache/informer/listoption_indexer.go @@ -0,0 +1,853 @@ +package informer + +import ( + "context" + "database/sql" + "encoding/gob" + "errors" + "fmt" + "regexp" + "sort" + "strconv" + "strings" + + "github.com/sirupsen/logrus" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/client-go/tools/cache" + + "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/partition" +) + +// ListOptionIndexer extends Indexer by allowing queries based on ListOption +type ListOptionIndexer struct { + *Indexer + + namespaced bool + indexedFields []string + + addFieldQuery string + deleteFieldQuery string + upsertLabelsQuery string + deleteLabelsQuery string + + addFieldStmt *sql.Stmt + deleteFieldStmt *sql.Stmt + upsertLabelsStmt *sql.Stmt + deleteLabelsStmt *sql.Stmt +} + +var ( + defaultIndexedFields = []string{"metadata.name", "metadata.creationTimestamp"} + defaultIndexNamespaced = "metadata.namespace" + subfieldRegex = regexp.MustCompile(`([a-zA-Z]+)|(\[[a-zA-Z./]+])|(\[[0-9]+])`) + + ErrInvalidColumn = errors.New("supplied column is invalid") +) + +const ( + matchFmt = `%%%s%%` + strictMatchFmt = `%s` + escapeBackslashDirective = ` ESCAPE '\'` // The leading space is crucial for unit tests only ' + createFieldsTableFmt = `CREATE TABLE "%s_fields" ( + key TEXT NOT NULL PRIMARY KEY, + %s + )` + createFieldsIndexFmt = `CREATE INDEX "%s_%s_index" ON "%s_fields"("%s")` + + failedToGetFromSliceFmt = "[listoption indexer] failed to get subfield [%s] from slice items: %w" + + createLabelsTableFmt = `CREATE TABLE IF NOT EXISTS "%s_labels" ( + key TEXT NOT NULL REFERENCES "%s"(key) ON DELETE CASCADE, + label TEXT NOT NULL, + value TEXT NOT NULL, + PRIMARY KEY (key, label) + )` + createLabelsTableIndexFmt = `CREATE INDEX IF NOT EXISTS "%s_labels_index" ON "%s_labels"(label, value)` + + upsertLabelsStmtFmt = `REPLACE INTO "%s_labels"(key, label, value) VALUES (?, ?, ?)` + deleteLabelsStmtFmt = `DELETE FROM "%s_labels" WHERE KEY = ?` +) + +// NewListOptionIndexer returns a SQLite-backed cache.Indexer of 
unstructured.Unstructured Kubernetes resources of a certain GVK +// ListOptionIndexer is also able to satisfy ListOption queries on indexed (sub)fields. +// Fields are specified as slices (e.g. "metadata.resourceVersion" is ["metadata", "resourceVersion"]) +func NewListOptionIndexer(fields [][]string, s Store, namespaced bool) (*ListOptionIndexer, error) { + // necessary in order to gob/ungob unstructured.Unstructured objects + gob.Register(map[string]interface{}{}) + gob.Register([]interface{}{}) + + i, err := NewIndexer(cache.Indexers{}, s) + if err != nil { + return nil, err + } + + var indexedFields []string + for _, f := range defaultIndexedFields { + indexedFields = append(indexedFields, f) + } + if namespaced { + indexedFields = append(indexedFields, defaultIndexNamespaced) + } + for _, f := range fields { + indexedFields = append(indexedFields, toColumnName(f)) + } + + l := &ListOptionIndexer{ + Indexer: i, + namespaced: namespaced, + indexedFields: indexedFields, + } + l.RegisterAfterUpsert(l.addIndexFields) + l.RegisterAfterUpsert(l.addLabels) + l.RegisterAfterDelete(l.deleteIndexFields) + l.RegisterAfterDelete(l.deleteLabels) + columnDefs := make([]string, len(indexedFields)) + for index, field := range indexedFields { + column := fmt.Sprintf(`"%s" TEXT`, field) + columnDefs[index] = column + } + + tx, err := l.BeginTx(context.Background(), true) + if err != nil { + return nil, err + } + dbName := db.Sanitize(i.GetName()) + err = tx.Exec(fmt.Sprintf(createFieldsTableFmt, dbName, strings.Join(columnDefs, ", "))) + if err != nil { + return nil, err + } + + columns := make([]string, len(indexedFields)) + qmarks := make([]string, len(indexedFields)) + setStatements := make([]string, len(indexedFields)) + + for index, field := range indexedFields { + // create index for field + err = tx.Exec(fmt.Sprintf(createFieldsIndexFmt, dbName, field, dbName, field)) + if err != nil { + return nil, err + } + + // format field into column for prepared statement + column := fmt.Sprintf(`"%s"`, field) + columns[index] = column + + // add placeholder for column's value in prepared statement + qmarks[index] = "?" 
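+
+		// Illustrative example (hypothetical names, not emitted by this code as-is): for a namespaced
+		// store called "v1_Pods" indexed on the defaults plus ["spec", "nodeName"], the columns, qmarks and
+		// setStatements collected in this loop are assembled after the loop into addFieldQuery, roughly:
+		//
+		//   INSERT INTO "v1_Pods_fields"(key, "metadata.name", "metadata.creationTimestamp", "metadata.namespace", "spec.nodeName")
+		//   VALUES (?, ?, ?, ?, ?)
+		//   ON CONFLICT DO UPDATE SET "metadata.name" = excluded."metadata.name", ...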
+ + // add formatted set statement for prepared statement + setStatement := fmt.Sprintf(`"%s" = excluded."%s"`, field, field) + setStatements[index] = setStatement + } + createLabelsTableQuery := fmt.Sprintf(createLabelsTableFmt, dbName, dbName) + err = tx.Exec(createLabelsTableQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createLabelsTableQuery, Err: err} + } + + createLabelsTableIndexQuery := fmt.Sprintf(createLabelsTableIndexFmt, dbName, dbName) + err = tx.Exec(createLabelsTableIndexQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createLabelsTableIndexQuery, Err: err} + } + + err = tx.Commit() + if err != nil { + return nil, err + } + + l.addFieldQuery = fmt.Sprintf( + `INSERT INTO "%s_fields"(key, %s) VALUES (?, %s) ON CONFLICT DO UPDATE SET %s`, + dbName, + strings.Join(columns, ", "), + strings.Join(qmarks, ", "), + strings.Join(setStatements, ", "), + ) + l.deleteFieldQuery = fmt.Sprintf(`DELETE FROM "%s_fields" WHERE key = ?`, dbName) + + l.addFieldStmt = l.Prepare(l.addFieldQuery) + l.deleteFieldStmt = l.Prepare(l.deleteFieldQuery) + + l.upsertLabelsQuery = fmt.Sprintf(upsertLabelsStmtFmt, dbName) + l.deleteLabelsQuery = fmt.Sprintf(deleteLabelsStmtFmt, dbName) + l.upsertLabelsStmt = l.Prepare(l.upsertLabelsQuery) + l.deleteLabelsStmt = l.Prepare(l.deleteLabelsQuery) + + return l, nil +} + +/* Core methods */ + +// addIndexFields saves sortable/filterable fields into tables +func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx db.TXClient) error { + args := []any{key} + for _, field := range l.indexedFields { + value, err := getField(obj, field) + if err != nil { + logrus.Errorf("cannot index object of type [%s] with key [%s] for indexer [%s]: %v", l.GetType().String(), key, l.GetName(), err) + cErr := tx.Cancel() + if cErr != nil { + return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err) + } + return err + } + switch typedValue := value.(type) { + case nil: + args = append(args, "") + case int, bool, string: + args = append(args, fmt.Sprint(typedValue)) + case []string: + args = append(args, strings.Join(typedValue, "|")) + default: + err2 := fmt.Errorf("field %v has a non-supported type value: %v", field, value) + cErr := tx.Cancel() + if cErr != nil { + return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err2) + } + return err2 + } + } + + err := tx.StmtExec(tx.Stmt(l.addFieldStmt), args...) + if err != nil { + return &db.QueryError{QueryString: l.addFieldQuery, Err: err} + } + return nil +} + +// labels are stored in tables that shadow the underlying object table for each GVK +func (l *ListOptionIndexer) addLabels(key string, obj any, tx db.TXClient) error { + k8sObj, ok := obj.(*unstructured.Unstructured) + if !ok { + return fmt.Errorf("addLabels: unexpected object type, expected unstructured.Unstructured: %v", obj) + } + incomingLabels := k8sObj.GetLabels() + for k, v := range incomingLabels { + err := tx.StmtExec(tx.Stmt(l.upsertLabelsStmt), key, k, v) + if err != nil { + return &db.QueryError{QueryString: l.upsertLabelsQuery, Err: err} + } + } + return nil +} + +func (l *ListOptionIndexer) deleteIndexFields(key string, tx db.TXClient) error { + args := []any{key} + + err := tx.StmtExec(tx.Stmt(l.deleteFieldStmt), args...) 
+ if err != nil { + return &db.QueryError{QueryString: l.deleteFieldQuery, Err: err} + } + return nil +} + +func (l *ListOptionIndexer) deleteLabels(key string, tx db.TXClient) error { + err := tx.StmtExec(tx.Stmt(l.deleteLabelsStmt), key) + if err != nil { + return &db.QueryError{QueryString: l.deleteLabelsQuery, Err: err} + } + return nil +} + +// ListByOptions returns objects according to the specified list options and partitions. +// Specifically: +// - an unstructured list of resources belonging to any of the specified partitions +// - the total number of resources (returned list might be a subset depending on pagination options in lo) +// - a continue token, if there are more pages after the returned one +// - an error instead of all of the above if anything went wrong +func (l *ListOptionIndexer) ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) { + queryInfo, err := l.constructQuery(lo, partitions, namespace, db.Sanitize(l.GetName())) + if err != nil { + return nil, 0, "", err + } + return l.executeQuery(ctx, queryInfo) +} + +// QueryInfo is a helper-struct that is used to represent the core query and parameters when converting +// a filter from the UI into a sql query +type QueryInfo struct { + query string + params []any + countQuery string + countParams []any + limit int + offset int +} + +func (l *ListOptionIndexer) constructQuery(lo ListOptions, partitions []partition.Partition, namespace string, dbName string) (*QueryInfo, error) { + queryInfo := &QueryInfo{} + queryHasLabelFilter := hasLabelFilter(lo.Filters) + + // First, what kind of filtering will we be doing? + // 1- Intro: SELECT and JOIN clauses + // There's a 1:1 correspondence between a base table and its _Fields table + // but it's possible that a key has no associated labels, so if we're doing a + // non-existence test on labels we need to do a LEFT OUTER JOIN + distinctModifier := "" + if queryHasLabelFilter { + distinctModifier = " DISTINCT" + } + query := fmt.Sprintf(`SELECT%s o.object, o.objectnonce, o.dekid FROM "%s" o`, distinctModifier, dbName) + query += "\n " + query += fmt.Sprintf(`JOIN "%s_fields" f ON o.key = f.key`, dbName) + if queryHasLabelFilter { + for i, orFilters := range lo.Filters { + if hasLabelFilter([]OrFilter{orFilters}) { + query += "\n " + // Make the lt index 1-based for readability + query += fmt.Sprintf(`LEFT OUTER JOIN "%s_labels" lt%d ON o.key = lt%d.key`, dbName, i+1, i+1) + } + } + } + params := []any{} + + // 2- Filtering: WHERE clauses (from lo.Filters) + whereClauses := []string{} + for i, orFilters := range lo.Filters { + orClause, orParams, err := l.buildORClauseFromFilters(i+1, orFilters, dbName) + if err != nil { + return queryInfo, err + } + if orClause == "" { + continue + } + whereClauses = append(whereClauses, orClause) + params = append(params, orParams...) 
+ } + + // WHERE clauses (from namespace) + if namespace != "" && namespace != "*" { + whereClauses = append(whereClauses, fmt.Sprintf(`f."metadata.namespace" = ?`)) + params = append(params, namespace) + } + + // WHERE clauses (from partitions and their corresponding parameters) + partitionClauses := []string{} + for _, thisPartition := range partitions { + if thisPartition.Passthrough { + // nothing to do, no extra filtering to apply by definition + } else { + singlePartitionClauses := []string{} + + // filter by namespace + if thisPartition.Namespace != "" && thisPartition.Namespace != "*" { + singlePartitionClauses = append(singlePartitionClauses, fmt.Sprintf(`f."metadata.namespace" = ?`)) + params = append(params, thisPartition.Namespace) + } + + // optionally filter by names + if !thisPartition.All { + names := thisPartition.Names + + if len(names) == 0 { + // degenerate case, there will be no results + singlePartitionClauses = append(singlePartitionClauses, "FALSE") + } else { + singlePartitionClauses = append(singlePartitionClauses, fmt.Sprintf(`f."metadata.name" IN (?%s)`, strings.Repeat(", ?", len(thisPartition.Names)-1))) + // sort for reproducibility + sortedNames := thisPartition.Names.UnsortedList() + sort.Strings(sortedNames) + for _, name := range sortedNames { + params = append(params, name) + } + } + } + + if len(singlePartitionClauses) > 0 { + partitionClauses = append(partitionClauses, strings.Join(singlePartitionClauses, " AND ")) + } + } + } + if len(partitions) == 0 { + // degenerate case, there will be no results + whereClauses = append(whereClauses, "FALSE") + } + if len(partitionClauses) == 1 { + whereClauses = append(whereClauses, partitionClauses[0]) + } + if len(partitionClauses) > 1 { + whereClauses = append(whereClauses, "(\n ("+strings.Join(partitionClauses, ") OR\n (")+")\n)") + } + + if len(whereClauses) > 0 { + query += "\n WHERE\n " + for index, clause := range whereClauses { + query += fmt.Sprintf("(%s)", clause) + if index == len(whereClauses)-1 { + break + } + query += " AND\n " + } + } + + // before proceeding, save a copy of the query and params without LIMIT/OFFSET/ORDER info + // for COUNTing all results later + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM (%s)", query) + countParams := params[:] + + // 3- Sorting: ORDER BY clauses (from lo.Sort) + if len(lo.Sort.Fields) != len(lo.Sort.Orders) { + return nil, fmt.Errorf("sort fields length %d != sort orders length %d", len(lo.Sort.Fields), len(lo.Sort.Orders)) + } + if len(lo.Sort.Fields) > 0 { + orderByClauses := []string{} + for i, field := range lo.Sort.Fields { + columnName := toColumnName(field) + if err := l.validateColumn(columnName); err != nil { + return queryInfo, err + } + + direction := "ASC" + if lo.Sort.Orders[i] == DESC { + direction = "DESC" + } + orderByClauses = append(orderByClauses, fmt.Sprintf(`f."%s" %s`, columnName, direction)) + } + query += "\n ORDER BY " + query += strings.Join(orderByClauses, ", ") + } else { + // make sure one default order is always picked + if l.namespaced { + query += "\n ORDER BY f.\"metadata.namespace\" ASC, f.\"metadata.name\" ASC " + } else { + query += "\n ORDER BY f.\"metadata.name\" ASC " + } + } + + // 4- Pagination: LIMIT clause (from lo.Pagination and/or lo.ChunkSize/lo.Resume) + + limitClause := "" + // take the smallest limit between lo.Pagination and lo.ChunkSize + limit := lo.Pagination.PageSize + if limit == 0 || (lo.ChunkSize > 0 && lo.ChunkSize < limit) { + limit = lo.ChunkSize + } + if limit > 0 { + limitClause = "\n LIMIT ?" 
+ params = append(params, limit) + } + + // OFFSET clause (from lo.Pagination and/or lo.Resume) + offsetClause := "" + offset := 0 + if lo.Resume != "" { + offsetInt, err := strconv.Atoi(lo.Resume) + if err != nil { + return queryInfo, err + } + offset = offsetInt + } + if lo.Pagination.Page >= 1 { + offset += lo.Pagination.PageSize * (lo.Pagination.Page - 1) + } + if offset > 0 { + offsetClause = "\n OFFSET ?" + params = append(params, offset) + } + if limit > 0 || offset > 0 { + query += limitClause + query += offsetClause + queryInfo.countQuery = countQuery + queryInfo.countParams = countParams + queryInfo.limit = limit + queryInfo.offset = offset + } + // Otherwise leave these as default values and the executor won't do pagination work + + logrus.Debugf("ListOptionIndexer prepared statement: %v", query) + logrus.Debugf("Params: %v", params) + queryInfo.query = query + queryInfo.params = params + + return queryInfo, nil +} + +func (l *ListOptionIndexer) executeQuery(ctx context.Context, queryInfo *QueryInfo) (*unstructured.UnstructuredList, int, string, error) { + stmt := l.Prepare(queryInfo.query) + defer l.CloseStmt(stmt) + + tx, err := l.BeginTx(ctx, false) + if err != nil { + return nil, 0, "", err + } + + txStmt := tx.Stmt(stmt) + rows, err := txStmt.QueryContext(ctx, queryInfo.params...) + if err != nil { + if cerr := tx.Cancel(); cerr != nil { + return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) + } + return nil, 0, "", &db.QueryError{QueryString: queryInfo.query, Err: err} + } + items, err := l.ReadObjects(rows, l.GetType(), l.GetShouldEncrypt()) + if err != nil { + if cerr := tx.Cancel(); cerr != nil { + return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) + } + return nil, 0, "", err + } + + total := len(items) + if queryInfo.countQuery != "" { + countStmt := l.Prepare(queryInfo.countQuery) + defer l.CloseStmt(countStmt) + txStmt := tx.Stmt(countStmt) + rows, err := txStmt.QueryContext(ctx, queryInfo.countParams...) 
+ if err != nil { + if cerr := tx.Cancel(); cerr != nil { + return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) + } + return nil, 0, "", &db.QueryError{QueryString: queryInfo.countQuery, Err: err} + } + total, err = l.ReadInt(rows) + if err != nil { + if cerr := tx.Cancel(); cerr != nil { + return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) + } + return nil, 0, "", fmt.Errorf("error reading query results: %w", err) + } + } + if err := tx.Commit(); err != nil { + return nil, 0, "", err + } + + continueToken := "" + limit := queryInfo.limit + offset := queryInfo.offset + if limit > 0 && offset+len(items) < total { + continueToken = fmt.Sprintf("%d", offset+limit) + } + + return toUnstructuredList(items), total, continueToken, nil +} + +func (l *ListOptionIndexer) validateColumn(column string) error { + for _, v := range l.indexedFields { + if v == column { + return nil + } + } + return fmt.Errorf("column is invalid [%s]: %w", column, ErrInvalidColumn) +} + +// buildORClause creates an SQLite compatible query that ORs conditions built from passed filters +func (l *ListOptionIndexer) buildORClauseFromFilters(index int, orFilters OrFilter, dbName string) (string, []any, error) { + var params []any + clauses := make([]string, 0, len(orFilters.Filters)) + var newParams []any + var newClause string + var err error + + for _, filter := range orFilters.Filters { + if isLabelFilter(&filter) { + newClause, newParams, err = l.getLabelFilter(index, filter, dbName) + } else { + newClause, newParams, err = l.getFieldFilter(filter) + } + if err != nil { + return "", nil, err + } + clauses = append(clauses, newClause) + params = append(params, newParams...) + } + switch len(clauses) { + case 0: + return "", params, nil + case 1: + return clauses[0], params, nil + } + return fmt.Sprintf("(%s)", strings.Join(clauses, ") OR (")), params, nil +} + +// Possible ops from the k8s parser: +// KEY = and == (same) VALUE +// KEY != VALUE +// KEY exists [] # ,KEY, => this filter +// KEY ! 
[] # ,!KEY, => assert KEY doesn't exist +// KEY in VALUES +// KEY notin VALUES + +func (l *ListOptionIndexer) getFieldFilter(filter Filter) (string, []any, error) { + opString := "" + escapeString := "" + columnName := toColumnName(filter.Field) + if err := l.validateColumn(columnName); err != nil { + return "", nil, err + } + switch filter.Op { + case Eq: + if filter.Partial { + opString = "LIKE" + escapeString = escapeBackslashDirective + } else { + opString = "=" + } + clause := fmt.Sprintf(`f."%s" %s ?%s`, columnName, opString, escapeString) + return clause, []any{formatMatchTarget(filter)}, nil + case NotEq: + if filter.Partial { + opString = "NOT LIKE" + escapeString = escapeBackslashDirective + } else { + opString = "!=" + } + clause := fmt.Sprintf(`f."%s" %s ?%s`, columnName, opString, escapeString) + return clause, []any{formatMatchTarget(filter)}, nil + + case Lt, Gt: + sym, target, err := prepareComparisonParameters(filter.Op, filter.Matches[0]) + if err != nil { + return "", nil, err + } + clause := fmt.Sprintf(`f."%s" %s ?`, columnName, sym) + return clause, []any{target}, nil + + case Exists, NotExists: + return "", nil, errors.New("NULL and NOT NULL tests aren't supported for non-label queries") + + case In: + fallthrough + case NotIn: + target := "()" + if len(filter.Matches) > 0 { + target = fmt.Sprintf("(?%s)", strings.Repeat(", ?", len(filter.Matches)-1)) + } + opString = "IN" + if filter.Op == NotIn { + opString = "NOT IN" + } + clause := fmt.Sprintf(`f."%s" %s %s`, columnName, opString, target) + matches := make([]any, len(filter.Matches)) + for i, match := range filter.Matches { + matches[i] = match + } + return clause, matches, nil + } + + return "", nil, fmt.Errorf("unrecognized operator: %s", opString) +} + +func (l *ListOptionIndexer) getLabelFilter(index int, filter Filter, dbName string) (string, []any, error) { + opString := "" + escapeString := "" + matchFmtToUse := strictMatchFmt + labelName := filter.Field[2] + switch filter.Op { + case Eq: + if filter.Partial { + opString = "LIKE" + escapeString = escapeBackslashDirective + matchFmtToUse = matchFmt + } else { + opString = "=" + } + clause := fmt.Sprintf(`lt%d.label = ? AND lt%d.value %s ?%s`, index, index, opString, escapeString) + return clause, []any{labelName, formatMatchTargetWithFormatter(filter.Matches[0], matchFmtToUse)}, nil + + case NotEq: + if filter.Partial { + opString = "NOT LIKE" + escapeString = escapeBackslashDirective + matchFmtToUse = matchFmt + } else { + opString = "!=" + } + subFilter := Filter{ + Field: filter.Field, + Op: NotExists, + } + existenceClause, subParams, err := l.getLabelFilter(index, subFilter, dbName) + if err != nil { + return "", nil, err + } + clause := fmt.Sprintf(`(%s) OR (lt%d.label = ? AND lt%d.value %s ?%s)`, existenceClause, index, index, opString, escapeString) + params := append(subParams, labelName, formatMatchTargetWithFormatter(filter.Matches[0], matchFmtToUse)) + return clause, params, nil + + case Lt, Gt: + sym, target, err := prepareComparisonParameters(filter.Op, filter.Matches[0]) + if err != nil { + return "", nil, err + } + clause := fmt.Sprintf(`lt%d.label = ? 
AND lt%d.value %s ?`, index, index, sym) + return clause, []any{labelName, target}, nil + + case Exists: + clause := fmt.Sprintf(`lt%d.label = ?`, index) + return clause, []any{labelName}, nil + + case NotExists: + clause := fmt.Sprintf(`o.key NOT IN (SELECT o1.key FROM "%s" o1 + JOIN "%s_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "%s_labels" lt%di1 ON o1.key = lt%di1.key + WHERE lt%di1.label = ?)`, dbName, dbName, dbName, index, index, index) + return clause, []any{labelName}, nil + + case In: + target := "(?" + if len(filter.Matches) > 0 { + target += strings.Repeat(", ?", len(filter.Matches)-1) + } + target += ")" + clause := fmt.Sprintf(`lt%d.label = ? AND lt%d.value IN %s`, index, index, target) + matches := make([]any, len(filter.Matches)+1) + matches[0] = labelName + for i, match := range filter.Matches { + matches[i+1] = match + } + return clause, matches, nil + + case NotIn: + target := "(?" + if len(filter.Matches) > 0 { + target += strings.Repeat(", ?", len(filter.Matches)-1) + } + target += ")" + subFilter := Filter{ + Field: filter.Field, + Op: NotExists, + } + existenceClause, subParams, err := l.getLabelFilter(index, subFilter, dbName) + if err != nil { + return "", nil, err + } + clause := fmt.Sprintf(`(%s) OR (lt%d.label = ? AND lt%d.value NOT IN %s)`, existenceClause, index, index, target) + matches := append(subParams, labelName) + for _, match := range filter.Matches { + matches = append(matches, match) + } + return clause, matches, nil + } + return "", nil, fmt.Errorf("unrecognized operator: %s", opString) +} + +func prepareComparisonParameters(op Op, target string) (string, float64, error) { + num, err := strconv.ParseFloat(target, 32) + if err != nil { + return "", 0, err + } + switch op { + case Lt: + return "<", num, nil + case Gt: + return ">", num, nil + } + return "", 0, fmt.Errorf("unrecognized operator when expecting '<' or '>': '%s'", op) +} + +func formatMatchTarget(filter Filter) string { + format := strictMatchFmt + if filter.Partial { + format = matchFmt + } + return formatMatchTargetWithFormatter(filter.Matches[0], format) +} + +func formatMatchTargetWithFormatter(match string, format string) string { + // To allow matches on the backslash itself, the character needs to be replaced first. + // Otherwise, it will undo the following replacements. 
+ match = strings.ReplaceAll(match, `\`, `\\`) + match = strings.ReplaceAll(match, `_`, `\_`) + match = strings.ReplaceAll(match, `%`, `\%`) + return fmt.Sprintf(format, match) +} + +// toColumnName returns the column name corresponding to a field expressed as string slice +func toColumnName(s []string) string { + return db.Sanitize(strings.Join(s, ".")) +} + +// getField extracts the value of a field expressed as a string path from an unstructured object +func getField(a any, field string) (any, error) { + subFields := extractSubFields(field) + o, ok := a.(*unstructured.Unstructured) + if !ok { + return nil, fmt.Errorf("unexpected object type, expected unstructured.Unstructured: %v", a) + } + + var obj interface{} + var found bool + var err error + obj = o.Object + for i, subField := range subFields { + switch t := obj.(type) { + case map[string]interface{}: + subField = strings.TrimSuffix(strings.TrimPrefix(subField, "["), "]") + obj, found, err = unstructured.NestedFieldNoCopy(t, subField) + if err != nil { + return nil, err + } + if !found { + // particularly with labels/annotation indexes, it is totally possible that some objects won't have these, + // so either we this is not an error state or it could be an error state with a type that callers can check for + return nil, nil + } + case []interface{}: + if strings.HasPrefix(subField, "[") && strings.HasSuffix(subField, "]") { + key, err := strconv.Atoi(strings.TrimSuffix(strings.TrimPrefix(subField, "["), "]")) + if err != nil { + return nil, fmt.Errorf("[listoption indexer] failed to convert subfield [%s] to int in listoption index: %w", subField, err) + } + if key >= len(t) { + return nil, fmt.Errorf("[listoption indexer] given index is too large for slice of len %d", len(t)) + } + obj = fmt.Sprintf("%v", t[key]) + } else if i == len(subFields)-1 { + result := make([]string, len(t)) + for index, v := range t { + itemVal, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf(failedToGetFromSliceFmt, subField, err) + } + itemStr, ok := itemVal[subField].(string) + if !ok { + return nil, fmt.Errorf(failedToGetFromSliceFmt, subField, err) + } + result[index] = itemStr + } + return result, nil + } + default: + return nil, fmt.Errorf("[listoption indexer] failed to parse subfields: %v", subFields) + } + } + return obj, nil +} + +func extractSubFields(fields string) []string { + subfields := make([]string, 0) + for _, subField := range subfieldRegex.FindAllString(fields, -1) { + subfields = append(subfields, strings.TrimSuffix(subField, ".")) + } + return subfields +} + +func isLabelFilter(f *Filter) bool { + return len(f.Field) >= 2 && f.Field[0] == "metadata" && f.Field[1] == "labels" +} + +func hasLabelFilter(filters []OrFilter) bool { + for _, outerFilter := range filters { + for _, filter := range outerFilter.Filters { + if isLabelFilter(&filter) { + return true + } + } + } + return false +} + +// toUnstructuredList turns a slice of unstructured objects into an unstructured.UnstructuredList +func toUnstructuredList(items []any) *unstructured.UnstructuredList { + objectItems := make([]map[string]any, len(items)) + result := &unstructured.UnstructuredList{ + Items: make([]unstructured.Unstructured, len(items)), + Object: map[string]interface{}{"items": objectItems}, + } + for i, item := range items { + result.Items[i] = *item.(*unstructured.Unstructured) + objectItems[i] = item.(*unstructured.Unstructured).Object + } + return result +} diff --git a/pkg/sqlcache/informer/listoption_indexer_test.go 
b/pkg/sqlcache/informer/listoption_indexer_test.go new file mode 100644 index 00000000..5352cebd --- /dev/null +++ b/pkg/sqlcache/informer/listoption_indexer_test.go @@ -0,0 +1,1478 @@ +/* +Copyright 2023 SUSE LLC + +Adapted from client-go, Copyright 2014 The Kubernetes Authors. +*/ + +package informer + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/rancher/steve/pkg/sqlcache/partition" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/sets" +) + +func TestNewListOptionIndexer(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + tests = append(tests, testCase{description: "NewListOptionIndexer() with no errors returned, should return no error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + // create field table + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) + // create field table indexes + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableFmt, id, id)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableIndexFmt, id, id)).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + loi, err := NewListOptionIndexer(fields, store, true) + assert.Nil(t, err) + assert.NotNil(t, loi) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error returned from NewIndexer(), should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := 
NewListOptionIndexer(fields, store, false) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error returned from Begin(), should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, false) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error from Exec() when creating fields table, should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error from create-labels, should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + + 
store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableFmt, id, id)).Return(fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error from Commit(), should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableFmt, id, id)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableIndexFmt, id, id)).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, true) + assert.NotNil(t, err) + }}) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestListByOptions(t *testing.T) { + type testCase struct { + description string + listOptions ListOptions + partitions []partition.Partition + ns string + expectedCountStmt string + expectedCountStmtArgs []any + expectedStmt string + expectedStmtArgs []any + expectedList *unstructured.UnstructuredList + returnList []any + expectedContToken string + expectedErr error + } + + testObject := testStoreObject{Id: "something", Val: "a"} + unstrTestObjectMap, err := runtime.DefaultUnstructuredConverter.ToUnstructured(&testObject) + 
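+ // unstrTestObjectMap is the unstructured (map) form of testObject; the cases below reuse it both as the mocked rows returned from the store and as the expected decoded list items.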
assert.Nil(t, err) + + var tests []testCase + tests = append(tests, testCase{ + description: "ListByOptions() with no errors returned, should not return an error", + listOptions: ListOptions{}, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC `, + returnList: []any{}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions() with an empty filter, should not return an error", + listOptions: ListOptions{ + Filters: []OrFilter{{[]Filter{}}}, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with ChunkSize set should set limit in prepared sql.Stmt", + listOptions: ListOptions{ChunkSize: 2}, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC + LIMIT ?`, + expectedStmtArgs: []interface{}{2}, + expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE))`, + expectedCountStmtArgs: []any{}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Resume set should set offset in prepared sql.Stmt", + listOptions: ListOptions{Resume: "4"}, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC + OFFSET ?`, + expectedStmtArgs: []interface{}{4}, + expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE))`, + expectedCountStmtArgs: []any{}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with 1 OrFilter set with 1 filter should select where that filter is true in 
prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"somevalue"}, + Op: Eq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.somefield" LIKE ? ESCAPE '\') AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"%somevalue%"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with 1 OrFilter set with 1 filter with Op set top NotEq should select where that filter is not true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"somevalue"}, + Op: NotEq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.somefield" NOT LIKE ? ESCAPE '\') AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"%somevalue%"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with 1 OrFilter set with 1 filter with Partial set to true should select where that partial match on that filter's value is true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"somevalue"}, + Op: Eq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.somefield" LIKE ? 
ESCAPE '\') AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"%somevalue%"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with 1 OrFilter set with multiple filters should select where any of those filters are true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"somevalue"}, + Op: Eq, + Partial: true, + }, + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"someothervalue"}, + Op: Eq, + Partial: true, + }, + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"somethirdvalue"}, + Op: NotEq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + ((f."metadata.somefield" LIKE ? ESCAPE '\') OR (f."metadata.somefield" LIKE ? ESCAPE '\') OR (f."metadata.somefield" NOT LIKE ? ESCAPE '\')) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"%somevalue%", "%someothervalue%", "%somethirdvalue%"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with multiple OrFilters set should select where all OrFilters contain one filter that is true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"value1"}, + Op: Eq, + Partial: false, + }, + { + Field: []string{"status", "someotherfield"}, + Matches: []string{"value2"}, + Op: NotEq, + Partial: false, + }, + }, + }, + { + Filters: []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"value3"}, + Op: Eq, + Partial: false, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "test4", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + ((f."metadata.somefield" = ?) OR (f."status.someotherfield" != ?)) AND + (f."metadata.somefield" = ?) AND + (f."metadata.namespace" = ?) 
AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"value1", "value2", "value3", "test4"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with labels filter should select the label in the prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "guard.cattle.io"}, + Matches: []string{"lodgepole"}, + Op: Eq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "test41", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value LIKE ? ESCAPE '\') AND + (f."metadata.namespace" = ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"guard.cattle.io", "%lodgepole%", "test41"}, + returnList: []any{}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ListByOptions with two labels filters should use a self-join", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "cows"}, + Matches: []string{"milk"}, + Op: Eq, + Partial: false, + }, + }, + }, + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "horses"}, + Matches: []string{"saddles"}, + Op: Eq, + Partial: false, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "test42", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + LEFT OUTER JOIN "something_labels" lt2 ON o.key = lt2.key + WHERE + (lt1.label = ? AND lt1.value = ?) AND + (lt2.label = ? AND lt2.value = ?) AND + (f."metadata.namespace" = ?) 
AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"cows", "milk", "horses", "saddles", "test42"}, + returnList: []any{}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ListByOptions with a mix of one label and one non-label query can still self-join", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "cows"}, + Matches: []string{"butter"}, + Op: Eq, + Partial: false, + }, + }, + }, + { + Filters: []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"wheat"}, + Op: Eq, + Partial: false, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "test43", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value = ?) AND + (f."metadata.somefield" = ?) AND + (f."metadata.namespace" = ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"cows", "butter", "wheat", "test43"}, + returnList: []any{}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ListByOptions with only one Sort.Field set should sort on that field only, in ascending order in prepared sql.Stmt", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}}, + Orders: []SortOrder{ASC}, + }, + }, + partitions: []partition.Partition{}, + ns: "test5", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.namespace" = ?) AND + (FALSE) + ORDER BY f."metadata.somefield" ASC`, + expectedStmtArgs: []any{"test5"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "sort one field descending", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}}, + Orders: []SortOrder{DESC}, + }, + }, + partitions: []partition.Partition{}, + ns: "test5a", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.namespace" = ?) 
AND + (FALSE) + ORDER BY f."metadata.somefield" DESC`, + expectedStmtArgs: []any{"test5a"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ListByOptions sorting on two fields should sort on the first field in ascending order first and then sort on the second field in ascending order in prepared sql.Stmt", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}, {"status", "someotherfield"}}, + Orders: []SortOrder{ASC, ASC}, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.somefield" ASC, f."status.someotherfield" ASC`, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ListByOptions sorting on two fields should sort on the first field in descending order first and then sort on the second field in ascending order in prepared sql.Stmt", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}, {"status", "someotherfield"}}, + Orders: []SortOrder{DESC, ASC}, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.somefield" DESC, f."status.someotherfield" ASC`, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ListByOptions sorting when # fields != # sort orders should return an error", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}, {"status", "someotherfield"}}, + Orders: []SortOrder{DESC, ASC, ASC}, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.somefield" DESC, f."status.someotherfield" ASC`, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: 
[]unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: fmt.Errorf("sort fields length 2 != sort orders length 3"), + }) + + tests = append(tests, testCase{ + description: "ListByOptions with Pagination.PageSize set should set limit to PageSize in prepared sql.Stmt", + listOptions: ListOptions{ + Pagination: Pagination{ + PageSize: 10, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC + LIMIT ?`, + expectedStmtArgs: []any{10}, + expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE))`, + expectedCountStmtArgs: []any{}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Pagination.Page and no PageSize set should not add anything to prepared sql.Stmt", + listOptions: ListOptions{ + Pagination: Pagination{ + Page: 2, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC `, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Pagination.Page and PageSize set limit to PageSize and offset to PageSize * (Page - 1) in prepared sql.Stmt", + listOptions: ListOptions{ + Pagination: Pagination{ + PageSize: 10, + Page: 2, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC + LIMIT ? 
+ OFFSET ?`, + expectedStmtArgs: []any{10, 10}, + + expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE))`, + expectedCountStmtArgs: []any{}, + + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with a Namespace Partition should select only items where metadata.namespace is equal to Namespace and all other conditions are met in prepared sql.Stmt", + partitions: []partition.Partition{ + { + Namespace: "somens", + }, + }, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.namespace" = ? AND FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"somens"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with a All Partition should select all items that meet all other conditions in prepared sql.Stmt", + partitions: []partition.Partition{ + { + All: true, + }, + }, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + ORDER BY f."metadata.name" ASC `, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with a Passthrough Partition should select all items that meet all other conditions prepared sql.Stmt", + partitions: []partition.Partition{ + { + Passthrough: true, + }, + }, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + ORDER BY f."metadata.name" ASC `, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with a Names Partition should select only items where metadata.name equals an items in Names and all other conditions are met in prepared sql.Stmt", + 
partitions: []partition.Partition{ + { + Names: sets.New[string]("someid", "someotherid"), + }, + }, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.name" IN (?, ?)) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"someid", "someotherid"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + stmts := NewMockStmt(gomock.NewController(t)) + i := &Indexer{ + Store: store, + } + lii := &ListOptionIndexer{ + Indexer: i, + indexedFields: []string{"metadata.somefield", "status.someotherfield"}, + } + if test.description == "ListByOptions with multiple OrFilters set should select where all OrFilters contain one filter that is true in prepared sql.Stmt" { + fmt.Printf("stop here") + } + queryInfo, err := lii.constructQuery(test.listOptions, test.partitions, test.ns, "something") + if test.expectedErr != nil { + assert.Equal(t, test.expectedErr, err) + return + } + assert.Nil(t, err) + assert.Equal(t, test.expectedStmt, queryInfo.query) + if test.expectedStmtArgs == nil { + test.expectedStmtArgs = []any{} + } + assert.Equal(t, test.expectedStmtArgs, queryInfo.params) + assert.Equal(t, test.expectedCountStmt, queryInfo.countQuery) + assert.Equal(t, test.expectedCountStmtArgs, queryInfo.countParams) + + stmt := &sql.Stmt{} + rows := &sql.Rows{} + objType := reflect.TypeOf(testObject) + store.EXPECT().BeginTx(gomock.Any(), false).Return(txClient, nil) + txClient.EXPECT().Stmt(gomock.Any()).Return(stmts).AnyTimes() + store.EXPECT().Prepare(test.expectedStmt).Do(func(a ...any) { + fmt.Println(a) + }).Return(stmt) + if args := test.expectedStmtArgs; args != nil { + stmts.EXPECT().QueryContext(gomock.Any(), gomock.Any()).Return(rows, nil).AnyTimes() + } else if strings.Contains(test.expectedStmt, "LIMIT") { + stmts.EXPECT().QueryContext(gomock.Any(), args...).Return(rows, nil) + txClient.EXPECT().Stmt(gomock.Any()).Return(stmts) + stmts.EXPECT().QueryContext(gomock.Any()).Return(rows, nil) + } else { + stmts.EXPECT().QueryContext(gomock.Any()).Return(rows, nil) + } + store.EXPECT().GetType().Return(objType) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, objType, false).Return(test.returnList, nil) + store.EXPECT().CloseStmt(stmt).Return(nil) + + if test.expectedCountStmt != "" { + store.EXPECT().Prepare(test.expectedCountStmt).Return(stmt) + //store.EXPECT().QueryForRows(context.TODO(), stmt, test.expectedCountStmtArgs...).Return(rows, nil) + store.EXPECT().ReadInt(rows).Return(len(test.expectedList.Items), nil) + store.EXPECT().CloseStmt(stmt).Return(nil) + } + txClient.EXPECT().Commit() + list, total, contToken, err := lii.executeQuery(context.TODO(), queryInfo) + if test.expectedErr == nil { + assert.Nil(t, err) + } else { + assert.Equal(t, test.expectedErr, err) + } + assert.Equal(t, test.expectedList, list) + assert.Equal(t, len(test.expectedList.Items), total) + 
assert.Equal(t, test.expectedContToken, contToken) + }) + } +} + +func TestConstructQuery(t *testing.T) { + type testCase struct { + description string + listOptions ListOptions + partitions []partition.Partition + ns string + expectedCountStmt string + expectedCountStmtArgs []any + expectedStmt string + expectedStmtArgs []any + expectedErr error + } + + var tests []testCase + tests = append(tests, testCase{ + description: "TestConstructQuery: handles IN statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "queryField1"}, + Matches: []string{"somevalue"}, + Op: In, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.queryField1" IN (?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"somevalue"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles NOT-IN statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "queryField1"}, + Matches: []string{"somevalue"}, + Op: NotIn, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.queryField1" NOT IN (?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"somevalue"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles EXISTS statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "queryField1"}, + Op: Exists, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedErr: errors.New("NULL and NOT NULL tests aren't supported for non-label queries"), + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles NOT-EXISTS statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "queryField1"}, + Op: NotExists, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedErr: errors.New("NULL and NOT NULL tests aren't supported for non-label queries"), + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles == statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelEqualFull"}, + Matches: []string{"somevalue"}, + Op: Eq, + Partial: false, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value = ?) 
AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelEqualFull", "somevalue"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles == statements for label statements, match partial", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelEqualPartial"}, + Matches: []string{"somevalue"}, + Op: Eq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value LIKE ? ESCAPE '\') AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelEqualPartial", "%somevalue%"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles != statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelNotEqualFull"}, + Matches: []string{"somevalue"}, + Op: NotEq, + Partial: false, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + ((o.key NOT IN (SELECT o1.key FROM "something" o1 + JOIN "something_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "something_labels" lt1i1 ON o1.key = lt1i1.key + WHERE lt1i1.label = ?)) OR (lt1.label = ? AND lt1.value != ?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelNotEqualFull", "labelNotEqualFull", "somevalue"}, + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "TestConstructQuery: handles != statements for label statements, match partial", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelNotEqualPartial"}, + Matches: []string{"somevalue"}, + Op: NotEq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + ((o.key NOT IN (SELECT o1.key FROM "something" o1 + JOIN "something_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "something_labels" lt1i1 ON o1.key = lt1i1.key + WHERE lt1i1.label = ?)) OR (lt1.label = ? AND lt1.value NOT LIKE ? 
ESCAPE '\')) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelNotEqualPartial", "labelNotEqualPartial", "%somevalue%"}, + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "TestConstructQuery: handles multiple != statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "notEqual1"}, + Matches: []string{"value1"}, + Op: NotEq, + Partial: false, + }, + }, + }, + { + []Filter{ + { + Field: []string{"metadata", "labels", "notEqual2"}, + Matches: []string{"value2"}, + Op: NotEq, + Partial: false, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + LEFT OUTER JOIN "something_labels" lt2 ON o.key = lt2.key + WHERE + ((o.key NOT IN (SELECT o1.key FROM "something" o1 + JOIN "something_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "something_labels" lt1i1 ON o1.key = lt1i1.key + WHERE lt1i1.label = ?)) OR (lt1.label = ? AND lt1.value != ?)) AND + ((o.key NOT IN (SELECT o1.key FROM "something" o1 + JOIN "something_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "something_labels" lt2i1 ON o1.key = lt2i1.key + WHERE lt2i1.label = ?)) OR (lt2.label = ? AND lt2.value != ?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"notEqual1", "notEqual1", "value1", "notEqual2", "notEqual2", "value2"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles IN statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelIN"}, + Matches: []string{"somevalue1", "someValue2"}, + Op: In, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value IN (?, ?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelIN", "somevalue1", "someValue2"}, + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "TestConstructQuery: handles NOTIN statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelNOTIN"}, + Matches: []string{"somevalue1", "someValue2"}, + Op: NotIn, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + ((o.key NOT IN (SELECT o1.key FROM "something" o1 + JOIN "something_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "something_labels" lt1i1 ON o1.key = lt1i1.key + WHERE lt1i1.label = ?)) OR (lt1.label = ? 
AND lt1.value NOT IN (?, ?))) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelNOTIN", "labelNOTIN", "somevalue1", "someValue2"}, + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "TestConstructQuery: handles EXISTS statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelEXISTS"}, + Matches: []string{}, + Op: Exists, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelEXISTS"}, + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "TestConstructQuery: handles NOTEXISTS statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelNOTEXISTS"}, + Matches: []string{}, + Op: NotExists, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (o.key NOT IN (SELECT o1.key FROM "something" o1 + JOIN "something_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "something_labels" lt1i1 ON o1.key = lt1i1.key + WHERE lt1i1.label = ?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelNOTEXISTS"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles LessThan statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "numericThing"}, + Matches: []string{"5"}, + Op: Lt, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value < ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"numericThing", float64(5)}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles GreaterThan statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "numericThing"}, + Matches: []string{"35"}, + Op: Gt, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value > ?) 
AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"numericThing", float64(35)}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "multiple filters with a positive label test and a negative non-label test still outer-join", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "junta"}, + Matches: []string{"esther"}, + Op: Eq, + Partial: true, + }, + { + Field: []string{"metadata", "queryField1"}, + Matches: []string{"golgi"}, + Op: NotEq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + ((lt1.label = ? AND lt1.value LIKE ? ESCAPE '\') OR (f."metadata.queryField1" NOT LIKE ? ESCAPE '\')) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"junta", "%esther%", "%golgi%"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "multiple filters and or-filters with a positive label test and a negative non-label test still outer-join and have correct AND/ORs", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "nectar"}, + Matches: []string{"stash"}, + Op: Eq, + Partial: true, + }, + { + Field: []string{"metadata", "queryField1"}, + Matches: []string{"landlady"}, + Op: NotEq, + Partial: false, + }, + }, + }, + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "lawn"}, + Matches: []string{"reba", "coil"}, + Op: In, + }, + { + Field: []string{"metadata", "queryField1"}, + Op: Gt, + Matches: []string{"2"}, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + LEFT OUTER JOIN "something_labels" lt2 ON o.key = lt2.key + WHERE + ((lt1.label = ? AND lt1.value LIKE ? ESCAPE '\') OR (f."metadata.queryField1" != ?)) AND + ((lt2.label = ? 
AND lt2.value IN (?, ?)) OR (f."metadata.queryField1" > ?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"nectar", "%stash%", "landlady", "lawn", "reba", "coil", float64(2)}, + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ConstructQuery: sorting when # fields < # sort orders should return an error", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}, {"status", "someotherfield"}}, + Orders: []SortOrder{DESC, ASC, ASC}, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: "", + expectedStmtArgs: []any{}, + expectedErr: fmt.Errorf("sort fields length 2 != sort orders length 3"), + }) + + tests = append(tests, testCase{ + description: "ConstructQuery: sorting when # fields > # sort orders should return an error", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}, {"status", "someotherfield"}, {"metadata", "labels", "a1"}, {"metadata", "labels", "a2"}}, + Orders: []SortOrder{DESC, ASC, ASC}, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: "", + expectedStmtArgs: []any{}, + expectedErr: fmt.Errorf("sort fields length 4 != sort orders length 3"), + }) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + i := &Indexer{ + Store: store, + } + lii := &ListOptionIndexer{ + Indexer: i, + indexedFields: []string{"metadata.queryField1", "status.queryField2"}, + } + queryInfo, err := lii.constructQuery(test.listOptions, test.partitions, test.ns, "something") + if test.expectedErr != nil { + assert.Equal(t, test.expectedErr, err) + return + } + assert.Nil(t, err) + assert.Equal(t, test.expectedStmt, queryInfo.query) + assert.Equal(t, test.expectedStmtArgs, queryInfo.params) + assert.Equal(t, test.expectedCountStmt, queryInfo.countQuery) + assert.Equal(t, test.expectedCountStmtArgs, queryInfo.countParams) + }) + } +} diff --git a/pkg/sqlcache/informer/listoptions.go b/pkg/sqlcache/informer/listoptions.go new file mode 100644 index 00000000..71bf6f6e --- /dev/null +++ b/pkg/sqlcache/informer/listoptions.go @@ -0,0 +1,68 @@ +package informer + +type Op string + +const ( + Eq Op = "=" + NotEq Op = "!=" + Exists Op = "Exists" + NotExists Op = "NotExists" + In Op = "In" + NotIn Op = "NotIn" + Lt Op = "Lt" + Gt Op = "Gt" +) + +// SortOrder represents whether the list should be ascending or descending. +type SortOrder int + +const ( + // ASC stands for ascending order. + ASC SortOrder = iota + // DESC stands for descending (reverse) order. + DESC +) + +// ListOptions represents the query parameters that may be included in a list request. +type ListOptions struct { + ChunkSize int + Resume string + Filters []OrFilter + Sort Sort + Pagination Pagination +} + +// Filter represents a field to filter by. +// A subfield in an object is represented in a request query using . notation, e.g. 'metadata.name'. +// The subfield is internally represented as a slice, e.g. [metadata, name]. 
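+// For example, an exact-match filter on 'metadata.name' with the value "foo" is represented as
+// Filter{Field: []string{"metadata", "name"}, Matches: []string{"foo"}, Op: Eq}.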
+// Complex subfields need to be expressed with square brackets, as in `metadata.labels[zombo.com/moose]`, +// but are mapped to the string slice ["metadata", "labels", "zombo.com/moose"]. +// +// If more than one value is given for the `Matches` field, we do an "IN ()" test. +type Filter struct { + Field []string + Matches []string + Op Op + Partial bool +} + +// OrFilter represents a set of possible fields to filter by, where an item may match any filter in the set to be included in the result. +type OrFilter struct { + Filters []Filter +} + +// Sort represents the criteria to sort on. +// The subfield to sort by is represented in a request query using . notation, e.g. 'metadata.name'. +// The subfield is internally represented as a slice, e.g. [metadata, name]. +// The order is represented by prefixing the sort key by '-', e.g. sort=-metadata.name. +// e.g. to sort internal clusters first, followed by clusters in alphabetical order: sort=-spec.internal,spec.displayName +type Sort struct { + Fields [][]string + Orders []SortOrder +} + +// Pagination represents how to return paginated results. +type Pagination struct { + PageSize int + Page int +} diff --git a/pkg/sqlcache/informer/shared_informer_hack.go b/pkg/sqlcache/informer/shared_informer_hack.go new file mode 100644 index 00000000..c11889c9 --- /dev/null +++ b/pkg/sqlcache/informer/shared_informer_hack.go @@ -0,0 +1,22 @@ +package informer + +import ( + "reflect" + "unsafe" +) + +// UnsafeSet replaces the passed object's field value with the passed value. +func UnsafeSet(object any, field string, value any) { + rs := reflect.ValueOf(object).Elem() + rf := rs.FieldByName(field) + wrf := reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem() + wrf.Set(reflect.ValueOf(value)) +} + +// UnsafeGet returns the value of the passed object's field with the passed name. +func UnsafeGet(object any, field string) any { + rs := reflect.ValueOf(object).Elem() + rf := rs.FieldByName(field) + wrf := reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem() + return wrf.Interface() +} diff --git a/pkg/sqlcache/informer/shared_informer_test.go b/pkg/sqlcache/informer/shared_informer_test.go new file mode 100644 index 00000000..bd143647 --- /dev/null +++ b/pkg/sqlcache/informer/shared_informer_test.go @@ -0,0 +1,325 @@ +/* +Copyright 2023 SUSE LLC + +Adapted from client-go, Copyright 2014 The Kubernetes Authors.
+*/ + +package informer + +import ( + "fmt" + "k8s.io/client-go/tools/cache" + "strings" + "sync" + "testing" + "time" + + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/wait" + fcache "k8s.io/client-go/tools/cache/testing" + testingclock "k8s.io/utils/clock/testing" +) + +type testListener struct { + lock sync.RWMutex + resyncPeriod time.Duration + expectedItemNames sets.Set[string] + receivedItemNames []string + name string +} + +func newTestListener(name string, resyncPeriod time.Duration, expected ...string) *testListener { + l := &testListener{ + resyncPeriod: resyncPeriod, + expectedItemNames: sets.New[string](expected...), + name: name, + } + return l +} + +func (l *testListener) OnAdd(obj interface{}, isInInitialList bool) { + l.handle(obj) +} + +func (l *testListener) OnUpdate(old, new interface{}) { + l.handle(new) +} + +func (l *testListener) OnDelete(obj interface{}) { +} + +func (l *testListener) handle(obj interface{}) { + key, _ := cache.MetaNamespaceKeyFunc(obj) + fmt.Printf("%s: handle: %v\n", l.name, key) + l.lock.Lock() + defer l.lock.Unlock() + + objectMeta, _ := meta.Accessor(obj) + l.receivedItemNames = append(l.receivedItemNames, objectMeta.GetName()) +} + +func (l *testListener) ok() bool { + fmt.Println("polling") + err := wait.PollImmediate(100*time.Millisecond, 2*time.Second, func() (bool, error) { + if l.satisfiedExpectations() { + return true, nil + } + return false, nil + }) + if err != nil { + return false + } + + // wait just a bit to allow any unexpected stragglers to come in + fmt.Println("sleeping") + time.Sleep(1 * time.Second) + fmt.Println("final check") + return l.satisfiedExpectations() +} + +func (l *testListener) satisfiedExpectations() bool { + l.lock.RLock() + defer l.lock.RUnlock() + + return sets.New[string](l.receivedItemNames...).Equal(l.expectedItemNames) +} + +func TestListenerResyncPeriods(t *testing.T) { + // source simulates an apiserver object endpoint. 
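+	// Below, the informer's unexported clock fields are replaced with a fake clock via
+	// UnsafeSet (see shared_informer_hack.go), so resyncs can be triggered deterministically
+	// with clock.Step instead of real waits.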
+ source := fcache.NewFakeControllerSource() + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}) + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2"}}) + + // create the shared informer and resync every 1s + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + + clock := testingclock.NewFakeClock(time.Now()) + UnsafeSet(informer, "clock", clock) + UnsafeSet(UnsafeGet(informer, "processor"), "clock", clock) + + // listener 1, never resync + listener1 := newTestListener("listener1", 0, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listener1, listener1.resyncPeriod) + + // listener 2, resync every 2s + listener2 := newTestListener("listener2", 2*time.Second, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listener2, listener2.resyncPeriod) + + // listener 3, resync every 3s + listener3 := newTestListener("listener3", 3*time.Second, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listener3, listener3.resyncPeriod) + listeners := []*testListener{listener1, listener2, listener3} + + stop := make(chan struct{}) + defer close(stop) + + go informer.Run(stop) + + // ensure all listeners got the initial List + for _, listener := range listeners { + if !listener.ok() { + t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames) + } + } + + // reset + for _, listener := range listeners { + listener.receivedItemNames = []string{} + } + + // advance so listener2 gets a resync + clock.Step(2 * time.Second) + + // make sure listener2 got the resync + if !listener2.ok() { + t.Errorf("%s: expected %v, got %v", listener2.name, listener2.expectedItemNames, listener2.receivedItemNames) + } + + // wait a bit to give errant items a chance to go to 1 and 3 + time.Sleep(1 * time.Second) + + // make sure listeners 1 and 3 got nothing + if len(listener1.receivedItemNames) != 0 { + t.Errorf("listener1: should not have resynced (got %d)", len(listener1.receivedItemNames)) + } + if len(listener3.receivedItemNames) != 0 { + t.Errorf("listener3: should not have resynced (got %d)", len(listener3.receivedItemNames)) + } + + // reset + for _, listener := range listeners { + listener.receivedItemNames = []string{} + } + + // advance so listener3 gets a resync + clock.Step(1 * time.Second) + + // make sure listener3 got the resync + if !listener3.ok() { + t.Errorf("%s: expected %v, got %v", listener3.name, listener3.expectedItemNames, listener3.receivedItemNames) + } + + // wait a bit to give errant items a chance to go to 1 and 2 + time.Sleep(1 * time.Second) + + // make sure listeners 1 and 2 got nothing + if len(listener1.receivedItemNames) != 0 { + t.Errorf("listener1: should not have resynced (got %d)", len(listener1.receivedItemNames)) + } + if len(listener2.receivedItemNames) != 0 { + t.Errorf("listener2: should not have resynced (got %d)", len(listener2.receivedItemNames)) + } +} + +// verify that https://github.com/kubernetes/kubernetes/issues/59822 is fixed +func TestSharedInformerInitializationRace(t *testing.T) { + source := fcache.NewFakeControllerSource() + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + listener := newTestListener("raceListener", 0) + + stop := make(chan struct{}) + go informer.AddEventHandlerWithResyncPeriod(listener, listener.resyncPeriod) + go informer.Run(stop) + close(stop) +} + +// TestSharedInformerWatchDisruption simulates a watch that was closed +// with updates to the store during that time. 
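+// The updates are made with AddDropWatch/ModifyDropWatch and the watch is then reset with
+// ResetWatch, so the changes can only be observed through a relist or a resync.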
We ensure that handlers with +// resync and no resync see the expected state. +func TestSharedInformerWatchDisruption(t *testing.T) { + // source simulates an apiserver object endpoint. + source := fcache.NewFakeControllerSource() + + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1", UID: "pod1", ResourceVersion: "1"}}) + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2", UID: "pod2", ResourceVersion: "2"}}) + + // create the shared informer and resync every 1s + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + + clock := testingclock.NewFakeClock(time.Now()) + UnsafeSet(informer, "clock", clock) + UnsafeSet(UnsafeGet(informer, "processor"), "clock", clock) + + // listener, never resync + listenerNoResync := newTestListener("listenerNoResync", 0, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listenerNoResync, listenerNoResync.resyncPeriod) + + listenerResync := newTestListener("listenerResync", 1*time.Second, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listenerResync, listenerResync.resyncPeriod) + listeners := []*testListener{listenerNoResync, listenerResync} + + stop := make(chan struct{}) + defer close(stop) + + go informer.Run(stop) + + for _, listener := range listeners { + if !listener.ok() { + t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames) + } + } + + // Add pod3, bump pod2 but don't broadcast it, so that the change will be seen only on relist + source.AddDropWatch(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod3", UID: "pod3", ResourceVersion: "3"}}) + source.ModifyDropWatch(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2", UID: "pod2", ResourceVersion: "4"}}) + + // Ensure that nobody saw any changes + for _, listener := range listeners { + if !listener.ok() { + t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames) + } + } + + for _, listener := range listeners { + listener.receivedItemNames = []string{} + } + + listenerNoResync.expectedItemNames = sets.New[string]("pod2", "pod3") + listenerResync.expectedItemNames = sets.New[string]("pod1", "pod2", "pod3") + + // This calls shouldSync, which deletes noResync from the list of syncingListeners + clock.Step(1 * time.Second) + + // Simulate a connection loss (or even just a too-old-watch) + source.ResetWatch() + + // Wait long enough for the reflector to exit and the backoff function to start waiting + // on the fake clock, otherwise advancing the fake clock will have no effect. + // TODO: Make this deterministic by counting the number of waiters on FakeClock + time.Sleep(10 * time.Millisecond) + + // Advance the clock to cause the backoff wait to expire. + clock.Step(1601 * time.Millisecond) + + // Wait long enough for backoff to invoke ListWatch a second time and distribute events + // to listeners. 
+ time.Sleep(10 * time.Millisecond) + + for _, listener := range listeners { + if !listener.ok() { + t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames) + } + } +} + +func TestSharedInformerErrorHandling(t *testing.T) { + source := fcache.NewFakeControllerSource() + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}) + source.ListError = fmt.Errorf("Access Denied") + + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + + errCh := make(chan error) + _ = informer.SetWatchErrorHandler(func(_ *cache.Reflector, err error) { + errCh <- err + }) + + stop := make(chan struct{}) + go informer.Run(stop) + + select { + case err := <-errCh: + if !strings.Contains(err.Error(), "Access Denied") { + t.Errorf("Expected 'Access Denied' error. Actual: %v", err) + } + case <-time.After(time.Second): + t.Errorf("Timeout waiting for error handler call") + } + close(stop) +} + +func TestSharedInformerTransformer(t *testing.T) { + // source simulates an apiserver object endpoint. + source := fcache.NewFakeControllerSource() + + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1", UID: "pod1", ResourceVersion: "1"}}) + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2", UID: "pod2", ResourceVersion: "2"}}) + + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + informer.SetTransform(func(obj interface{}) (interface{}, error) { + if pod, ok := obj.(*v1.Pod); ok { + name := pod.GetName() + + if upper := strings.ToUpper(name); upper != name { + copied := pod.DeepCopyObject().(*v1.Pod) + copied.SetName(upper) + return copied, nil + } + } + return obj, nil + }) + + listenerTransformer := newTestListener("listenerTransformer", 0, "POD1", "POD2") + informer.AddEventHandler(listenerTransformer) + + stop := make(chan struct{}) + go informer.Run(stop) + defer close(stop) + + if !listenerTransformer.ok() { + t.Errorf("%s: expected %v, got %v", listenerTransformer.name, listenerTransformer.expectedItemNames, listenerTransformer.receivedItemNames) + } +} diff --git a/pkg/sqlcache/informer/sql_mocks_test.go b/pkg/sqlcache/informer/sql_mocks_test.go new file mode 100644 index 00000000..c269b01b --- /dev/null +++ b/pkg/sqlcache/informer/sql_mocks_test.go @@ -0,0 +1,347 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/informer (interfaces: Store) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer Store +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + db "github.com/rancher/steve/pkg/sqlcache/db" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockStore is a mock of Store interface. +type MockStore struct { + ctrl *gomock.Controller + recorder *MockStoreMockRecorder +} + +// MockStoreMockRecorder is the mock recorder for MockStore. +type MockStoreMockRecorder struct { + mock *MockStore +} + +// NewMockStore creates a new mock instance. +func NewMockStore(ctrl *gomock.Controller) *MockStore { + mock := &MockStore{ctrl: ctrl} + mock.recorder = &MockStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStore) EXPECT() *MockStoreMockRecorder { + return m.recorder +} + +// Add mocks base method. 
+func (m *MockStore) Add(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Add", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Add indicates an expected call of Add. +func (mr *MockStoreMockRecorder) Add(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockStore)(nil).Add), arg0) +} + +// BeginTx mocks base method. +func (m *MockStore) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(db.TXClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockStoreMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockStore)(nil).BeginTx), arg0, arg1) +} + +// CloseStmt mocks base method. +func (m *MockStore) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockStoreMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockStore)(nil).CloseStmt), arg0) +} + +// Delete mocks base method. +func (m *MockStore) Delete(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockStoreMockRecorder) Delete(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete), arg0) +} + +// Get mocks base method. +func (m *MockStore) Get(arg0 any) (any, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Get indicates an expected call of Get. +func (mr *MockStoreMockRecorder) Get(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStore)(nil).Get), arg0) +} + +// GetByKey mocks base method. +func (m *MockStore) GetByKey(arg0 string) (any, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetByKey", arg0) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetByKey indicates an expected call of GetByKey. +func (mr *MockStoreMockRecorder) GetByKey(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByKey", reflect.TypeOf((*MockStore)(nil).GetByKey), arg0) +} + +// GetName mocks base method. +func (m *MockStore) GetName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetName") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetName indicates an expected call of GetName. +func (mr *MockStoreMockRecorder) GetName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetName", reflect.TypeOf((*MockStore)(nil).GetName)) +} + +// GetShouldEncrypt mocks base method. 
+func (m *MockStore) GetShouldEncrypt() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetShouldEncrypt") + ret0, _ := ret[0].(bool) + return ret0 +} + +// GetShouldEncrypt indicates an expected call of GetShouldEncrypt. +func (mr *MockStoreMockRecorder) GetShouldEncrypt() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetShouldEncrypt", reflect.TypeOf((*MockStore)(nil).GetShouldEncrypt)) +} + +// GetType mocks base method. +func (m *MockStore) GetType() reflect.Type { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetType") + ret0, _ := ret[0].(reflect.Type) + return ret0 +} + +// GetType indicates an expected call of GetType. +func (mr *MockStoreMockRecorder) GetType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetType", reflect.TypeOf((*MockStore)(nil).GetType)) +} + +// List mocks base method. +func (m *MockStore) List() []any { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List") + ret0, _ := ret[0].([]any) + return ret0 +} + +// List indicates an expected call of List. +func (mr *MockStoreMockRecorder) List() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockStore)(nil).List)) +} + +// ListKeys mocks base method. +func (m *MockStore) ListKeys() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListKeys") + ret0, _ := ret[0].([]string) + return ret0 +} + +// ListKeys indicates an expected call of ListKeys. +func (mr *MockStoreMockRecorder) ListKeys() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListKeys", reflect.TypeOf((*MockStore)(nil).ListKeys)) +} + +// Prepare mocks base method. +func (m *MockStore) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockStoreMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockStore)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockStore) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockStoreMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockStore)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockStore) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockStoreMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockStore)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. 
+func (m *MockStore) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockStoreMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockStore)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockStore) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockStoreMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockStore)(nil).ReadStrings), arg0) +} + +// RegisterAfterDelete mocks base method. +func (m *MockStore) RegisterAfterDelete(arg0 func(string, db.TXClient) error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterAfterDelete", arg0) +} + +// RegisterAfterDelete indicates an expected call of RegisterAfterDelete. +func (mr *MockStoreMockRecorder) RegisterAfterDelete(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterDelete", reflect.TypeOf((*MockStore)(nil).RegisterAfterDelete), arg0) +} + +// RegisterAfterUpsert mocks base method. +func (m *MockStore) RegisterAfterUpsert(arg0 func(string, any, db.TXClient) error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterAfterUpsert", arg0) +} + +// RegisterAfterUpsert indicates an expected call of RegisterAfterUpsert. +func (mr *MockStoreMockRecorder) RegisterAfterUpsert(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterUpsert", reflect.TypeOf((*MockStore)(nil).RegisterAfterUpsert), arg0) +} + +// Replace mocks base method. +func (m *MockStore) Replace(arg0 []any, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Replace", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Replace indicates an expected call of Replace. +func (mr *MockStoreMockRecorder) Replace(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Replace", reflect.TypeOf((*MockStore)(nil).Replace), arg0, arg1) +} + +// Resync mocks base method. +func (m *MockStore) Resync() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Resync") + ret0, _ := ret[0].(error) + return ret0 +} + +// Resync indicates an expected call of Resync. +func (mr *MockStoreMockRecorder) Resync() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resync", reflect.TypeOf((*MockStore)(nil).Resync)) +} + +// Update mocks base method. +func (m *MockStore) Update(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Update indicates an expected call of Update. 
+func (mr *MockStoreMockRecorder) Update(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockStore)(nil).Update), arg0) +} diff --git a/pkg/sqlcache/informer/store_mocks_test.go b/pkg/sqlcache/informer/store_mocks_test.go new file mode 100644 index 00000000..c1c7d426 --- /dev/null +++ b/pkg/sqlcache/informer/store_mocks_test.go @@ -0,0 +1,165 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/store (interfaces: DBClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + db "github.com/rancher/steve/pkg/sqlcache/db" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockDBClient is a mock of DBClient interface. +type MockDBClient struct { + ctrl *gomock.Controller + recorder *MockDBClientMockRecorder +} + +// MockDBClientMockRecorder is the mock recorder for MockDBClient. +type MockDBClientMockRecorder struct { + mock *MockDBClient +} + +// NewMockDBClient creates a new mock instance. +func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient { + mock := &MockDBClient{ctrl: ctrl} + mock.recorder = &MockDBClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(db.TXClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1) +} + +// CloseStmt mocks base method. +func (m *MockDBClient) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0) +} + +// Prepare mocks base method. +func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) 
+ ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0) +} + +// Upsert mocks base method. +func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) +} diff --git a/pkg/sqlcache/informer/tx_mocks_test.go b/pkg/sqlcache/informer/tx_mocks_test.go new file mode 100644 index 00000000..9383411d --- /dev/null +++ b/pkg/sqlcache/informer/tx_mocks_test.go @@ -0,0 +1,99 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./pkg/cache/sql/informer/tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStmt is a mock of Stmt interface. 
+type MockStmt struct { + ctrl *gomock.Controller + recorder *MockStmtMockRecorder +} + +// MockStmtMockRecorder is the mock recorder for MockStmt. +type MockStmtMockRecorder struct { + mock *MockStmt +} + +// NewMockStmt creates a new mock instance. +func NewMockStmt(ctrl *gomock.Controller) *MockStmt { + mock := &MockStmt{ctrl: ctrl} + mock.recorder = &MockStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStmt) EXPECT() *MockStmtMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) +} + +// Query mocks base method. +func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) +} + +// QueryContext mocks base method. +func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) 
+} diff --git a/pkg/sqlcache/integration_test.go b/pkg/sqlcache/integration_test.go new file mode 100644 index 00000000..58715918 --- /dev/null +++ b/pkg/sqlcache/integration_test.go @@ -0,0 +1,365 @@ +package sql + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/suite" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/cache" + "sigs.k8s.io/controller-runtime/pkg/envtest" + + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/informer/factory" + "github.com/rancher/steve/pkg/sqlcache/partition" +) + +const testNamespace = "sql-test" + +var defaultPartition = partition.Partition{ + All: true, +} + +type IntegrationSuite struct { + suite.Suite + testEnv envtest.Environment + clientset kubernetes.Clientset + restCfg rest.Config +} + +func (i *IntegrationSuite) SetupSuite() { + i.testEnv = envtest.Environment{} + restCfg, err := i.testEnv.Start() + i.Require().NoError(err, "error when starting env test - this is likely because setup-envtest wasn't done. Check the README for more information") + i.restCfg = *restCfg + clientset, err := kubernetes.NewForConfig(restCfg) + i.Require().NoError(err) + i.clientset = *clientset + testNs := v1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: testNamespace, + }, + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _, err = i.clientset.CoreV1().Namespaces().Create(ctx, &testNs, metav1.CreateOptions{}) + i.Require().NoError(err) +} + +func (i *IntegrationSuite) TearDownSuite() { + err := i.testEnv.Stop() + i.Require().NoError(err) +} + +func (i *IntegrationSuite) TestSQLCacheFilters() { + fields := [][]string{{`metadata`, `annotations[somekey]`}} + require := i.Require() + configMapWithAnnotations := func(name string, annotations map[string]string) v1.ConfigMap { + return v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: testNamespace, + Annotations: annotations, + }, + } + } + createConfigMaps := func(configMaps ...v1.ConfigMap) { + for _, configMap := range configMaps { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + configMapClient := i.clientset.CoreV1().ConfigMaps(testNamespace) + _, err := configMapClient.Create(ctx, &configMap, metav1.CreateOptions{}) + require.NoError(err) + // avoiding defer in a for loop + cancel() + } + } + + // we create some configmaps before the cache starts and some after so that we can test the initial list and + // subsequent watches to make sure both work + // matches the filter for somekey == somevalue + matches := configMapWithAnnotations("matches-filter", map[string]string{"somekey": "somevalue"}) + // partial match for somekey == somevalue (different suffix) + partialMatches := configMapWithAnnotations("partial-matches", map[string]string{"somekey": "somevaluehere"}) + specialCharacterMatch := configMapWithAnnotations("special-character-matches", map[string]string{"somekey": "c%%l_value"}) + backSlashCharacterMatch := configMapWithAnnotations("backslash-character-matches", map[string]string{"somekey": `my\windows\path`}) + createConfigMaps(matches, partialMatches, specialCharacterMatch, backSlashCharacterMatch) + + cache, cacheFactory, err := i.createCacheAndFactory(fields, nil) + 
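+	// createCacheAndFactory indexes the fields declared above (metadata.annotations[somekey]);
+	// only indexed fields can be filtered on, which is what the annotation filters below rely on.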
require.NoError(err) + defer cacheFactory.Reset() + + // doesn't match the filter for somekey == somevalue + notMatches := configMapWithAnnotations("not-matches-filter", map[string]string{"somekey": "notequal"}) + // has no annotations, shouldn't match any filter + missing := configMapWithAnnotations("missing", nil) + createConfigMaps(notMatches, missing) + + configMapNames := []string{matches.Name, partialMatches.Name, notMatches.Name, missing.Name, specialCharacterMatch.Name, backSlashCharacterMatch.Name} + err = i.waitForCacheReady(configMapNames, testNamespace, cache) + require.NoError(err) + + orFiltersForFilters := func(filters ...informer.Filter) []informer.OrFilter { + return []informer.OrFilter{ + { + Filters: filters, + }, + } + } + tests := []struct { + name string + filters []informer.OrFilter + wantNames []string + }{ + { + name: "matches filter", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.Eq, + Partial: false, + }), + wantNames: []string{"matches-filter"}, + }, + { + name: "partial matches filter", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.Eq, + Partial: true, + }), + wantNames: []string{"matches-filter", "partial-matches"}, + }, + { + name: "no matches for filter with underscore as it is interpreted literally", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalu_"}, + Op: informer.Eq, + Partial: true, + }), + wantNames: nil, + }, + { + name: "no matches for filter with percent sign as it is interpreted literally", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalu%"}, + Op: informer.Eq, + Partial: true, + }), + wantNames: nil, + }, + { + name: "match with special characters", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"c%%l_value"}, + Op: informer.Eq, + Partial: true, + }), + wantNames: []string{"special-character-matches"}, + }, + { + name: "match with literal backslash character", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{`my\windows\path`}, + Op: informer.Eq, + Partial: true, + }), + wantNames: []string{"backslash-character-matches"}, + }, + { + name: "not eq filter", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.NotEq, + Partial: false, + }), + wantNames: []string{"partial-matches", "not-matches-filter", "missing", "special-character-matches", "backslash-character-matches"}, + }, + { + name: "partial not eq filter", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.NotEq, + Partial: true, + }), + wantNames: []string{"not-matches-filter", "missing", "special-character-matches", "backslash-character-matches"}, + }, + { + name: "multiple or filters match", + filters: orFiltersForFilters( + informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.Eq, + Partial: true, + }, + informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: 
[]string{"notequal"}, + Op: informer.Eq, + Partial: false, + }, + ), + wantNames: []string{"matches-filter", "partial-matches", "not-matches-filter"}, + }, + { + name: "or filters on different fields", + filters: orFiltersForFilters( + informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.Eq, + Partial: true, + }, + informer.Filter{ + Field: []string{`metadata`, `name`}, + Matches: []string{"missing"}, + Op: informer.Eq, + Partial: false, + }, + ), + wantNames: []string{"matches-filter", "partial-matches", "missing"}, + }, + { + name: "and filters, both must match", + filters: []informer.OrFilter{ + { + Filters: []informer.Filter{ + { + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.Eq, + Partial: true, + }, + }, + }, + { + Filters: []informer.Filter{ + { + Field: []string{`metadata`, `name`}, + Matches: []string{"matches-filter"}, + Op: informer.Eq, + Partial: false, + }, + }, + }, + }, + wantNames: []string{"matches-filter"}, + }, + { + name: "no matches", + filters: orFiltersForFilters( + informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"valueNotRepresented"}, + Op: informer.Eq, + Partial: false, + }, + ), + wantNames: []string{}, + }, + } + + for _, test := range tests { + test := test + i.Run(test.name, func() { + options := informer.ListOptions{ + Filters: test.filters, + } + partitions := []partition.Partition{defaultPartition} + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + cfgMaps, total, continueToken, err := cache.ListByOptions(ctx, options, partitions, testNamespace) + i.Require().NoError(err) + // since there's no additional pages, the continue token should be empty + i.Require().Equal("", continueToken) + i.Require().NotNil(cfgMaps) + // assert instead of require so that we can see the full evaluation of # of resources returned + i.Assert().Equal(len(test.wantNames), total) + i.Assert().Len(cfgMaps.Items, len(test.wantNames)) + requireNames := sets.Set[string]{} + requireNames.Insert(test.wantNames...) 
+ gotNames := sets.Set[string]{} + for _, configMap := range cfgMaps.Items { + gotNames.Insert(configMap.GetName()) + } + i.Require().True(requireNames.Equal(gotNames), "wanted %v, got %v", requireNames, gotNames) + }) + } +} + +func (i *IntegrationSuite) createCacheAndFactory(fields [][]string, transformFunc cache.TransformFunc) (*factory.Cache, *factory.CacheFactory, error) { + cacheFactory, err := factory.NewCacheFactory() + if err != nil { + return nil, nil, fmt.Errorf("unable to make factory: %w", err) + } + dynamicClient, err := dynamic.NewForConfig(&i.restCfg) + if err != nil { + return nil, nil, fmt.Errorf("unable to make dynamicClient: %w", err) + } + configMapGVK := schema.GroupVersionKind{ + Group: "", + Version: "v1", + Kind: "ConfigMap", + } + configMapGVR := schema.GroupVersionResource{ + Group: "", + Version: "v1", + Resource: "configmaps", + } + dynamicResource := dynamicClient.Resource(configMapGVR).Namespace(testNamespace) + cache, err := cacheFactory.CacheFor(fields, transformFunc, dynamicResource, configMapGVK, true, true) + if err != nil { + return nil, nil, fmt.Errorf("unable to make cache: %w", err) + } + return &cache, cacheFactory, nil +} + +func (i *IntegrationSuite) waitForCacheReady(readyResourceNames []string, namespace string, cache *factory.Cache) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + return wait.PollUntilContextCancel(ctx, time.Millisecond*100, true, func(ctx context.Context) (done bool, err error) { + var options informer.ListOptions + partitions := []partition.Partition{defaultPartition} + cacheCtx, cacheCancel := context.WithTimeout(ctx, time.Second*5) + defer cacheCancel() + currentResources, total, _, err := cache.ListByOptions(cacheCtx, options, partitions, namespace) + if err != nil { + // note that we don't return the error since that would stop the polling + return false, nil + } + if total != len(readyResourceNames) { + return false, nil + } + wantNames := sets.Set[string]{} + wantNames.Insert(readyResourceNames...) + gotNames := sets.Set[string]{} + for _, current := range currentResources.Items { + name := current.GetName() + if !wantNames.Has(name) { + return true, fmt.Errorf("got resource %s which wasn't expected", name) + } + gotNames.Insert(name) + } + return wantNames.Equal(gotNames), nil + }) +} + +func TestIntegrationSuite(t *testing.T) { + suite.Run(t, new(IntegrationSuite)) +} diff --git a/pkg/sqlcache/partition/partition.go b/pkg/sqlcache/partition/partition.go new file mode 100644 index 00000000..dd21a60e --- /dev/null +++ b/pkg/sqlcache/partition/partition.go @@ -0,0 +1,24 @@ +/* +Package partition represents listing parameters. They can be used to specify which namespaces a caller would like included +in a response, or which specific objects they are looking for. +*/ +package partition + +import ( + "k8s.io/apimachinery/pkg/util/sets" +) + +// Partition represents filtering of a request's results +type Partition struct { + // if true, do not apply any filtering, return all results. Overrides all other fields + Passthrough bool + + // if non-empty, only resources in the specified namespaces will be returned + Namespace string + + // if true, return all results, while still honoring Namespace. 
Overrides Names + All bool + + // if non-empty, only resources with matching names will be returned + Names sets.Set[string] +} diff --git a/pkg/sqlcache/store/db_mocks_test.go b/pkg/sqlcache/store/db_mocks_test.go new file mode 100644 index 00000000..75f70b6e --- /dev/null +++ b/pkg/sqlcache/store/db_mocks_test.go @@ -0,0 +1,204 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient,Rows) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows +// + +// Package store is a generated GoMock package. +package store + +import ( + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockTXClient is a mock of TXClient interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockTXClient) Cancel() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) +} + +// Commit mocks base method. +func (m *MockTXClient) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} + +// StmtExec mocks base method. 
+func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StmtExec indicates an expected call of StmtExec. +func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) +} + +// MockRows is a mock of Rows interface. +type MockRows struct { + ctrl *gomock.Controller + recorder *MockRowsMockRecorder +} + +// MockRowsMockRecorder is the mock recorder for MockRows. +type MockRowsMockRecorder struct { + mock *MockRows +} + +// NewMockRows creates a new mock instance. +func NewMockRows(ctrl *gomock.Controller) *MockRows { + mock := &MockRows{ctrl: ctrl} + mock.recorder = &MockRowsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRows) EXPECT() *MockRowsMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRows) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRowsMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close)) +} + +// Err mocks base method. +func (m *MockRows) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockRowsMockRecorder) Err() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err)) +} + +// Next mocks base method. +func (m *MockRows) Next() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockRowsMockRecorder) Next() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next)) +} + +// Scan mocks base method. +func (m *MockRows) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) +} diff --git a/pkg/sqlcache/store/store.go b/pkg/sqlcache/store/store.go new file mode 100644 index 00000000..a228ee86 --- /dev/null +++ b/pkg/sqlcache/store/store.go @@ -0,0 +1,362 @@ +/* +Package store contains the sql backed store. It persists objects to a sqlite database. 
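+
+A minimal usage sketch (dbClient, the example object, and the "v1_Pod" name are placeholders):
+
+	s, err := store.NewStore(v1.Pod{}, cache.MetaNamespaceKeyFunc, dbClient, false, "v1_Pod")
+	if err != nil {
+		// handle err
+	}
+	err = s.Add(pod)                          // upsert, keyed by the store's key function
+	obj, exists, err := s.GetByKey("ns/name") // read back a single object
+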
+*/ +package store + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "github.com/rancher/lasso/pkg/log" + "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" + "k8s.io/client-go/tools/cache" + + // needed for drivers + _ "modernc.org/sqlite" +) + +const ( + upsertStmtFmt = `REPLACE INTO "%s"(key, object, objectnonce, dekid) VALUES (?, ?, ?, ?)` + deleteStmtFmt = `DELETE FROM "%s" WHERE key = ?` + getStmtFmt = `SELECT object, objectnonce, dekid FROM "%s" WHERE key = ?` + listStmtFmt = `SELECT object, objectnonce, dekid FROM "%s"` + listKeysStmtFmt = `SELECT key FROM "%s"` + createTableFmt = `CREATE TABLE IF NOT EXISTS "%s" ( + key TEXT UNIQUE NOT NULL PRIMARY KEY, + object BLOB, + objectnonce BLOB, + dekid INTEGER + )` +) + +// Store is a SQLite-backed cache.Store +type Store struct { + DBClient + + name string + typ reflect.Type + keyFunc cache.KeyFunc + shouldEncrypt bool + + upsertQuery string + deleteQuery string + getQuery string + listQuery string + listKeysQuery string + + upsertStmt *sql.Stmt + deleteStmt *sql.Stmt + getStmt *sql.Stmt + listStmt *sql.Stmt + listKeysStmt *sql.Stmt + + afterUpsert []func(key string, obj any, tx db.TXClient) error + afterDelete []func(key string, tx db.TXClient) error +} + +// Test that Store implements cache.Indexer +var _ cache.Store = (*Store)(nil) + +type DBClient interface { + BeginTx(ctx context.Context, forWriting bool) (db.TXClient, error) + Prepare(stmt string) *sql.Stmt + QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) + ReadObjects(rows db.Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) + ReadStrings(rows db.Rows) ([]string, error) + ReadInt(rows db.Rows) (int, error) + Upsert(tx db.TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error + CloseStmt(closable db.Closable) error +} + +// NewStore creates a SQLite-backed cache.Store for objects of the given example type +func NewStore(example any, keyFunc cache.KeyFunc, c DBClient, shouldEncrypt bool, name string) (*Store, error) { + s := &Store{ + name: name, + typ: reflect.TypeOf(example), + DBClient: c, + keyFunc: keyFunc, + shouldEncrypt: shouldEncrypt, + afterUpsert: []func(key string, obj any, tx db.TXClient) error{}, + afterDelete: []func(key string, tx db.TXClient) error{}, + } + + // once multiple informerfactories are needed, this can accept the case where table already exists error is received + txC, err := s.BeginTx(context.Background(), true) + if err != nil { + return nil, err + } + dbName := db.Sanitize(s.name) + createTableQuery := fmt.Sprintf(createTableFmt, dbName) + err = txC.Exec(createTableQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createTableQuery, Err: err} + } + + err = txC.Commit() + if err != nil { + return nil, err + } + + s.upsertQuery = fmt.Sprintf(upsertStmtFmt, dbName) + s.deleteQuery = fmt.Sprintf(deleteStmtFmt, dbName) + s.getQuery = fmt.Sprintf(getStmtFmt, dbName) + s.listQuery = fmt.Sprintf(listStmtFmt, dbName) + s.listKeysQuery = fmt.Sprintf(listKeysStmtFmt, dbName) + + s.upsertStmt = s.Prepare(s.upsertQuery) + s.deleteStmt = s.Prepare(s.deleteQuery) + s.getStmt = s.Prepare(s.getQuery) + s.listStmt = s.Prepare(s.listQuery) + s.listKeysStmt = s.Prepare(s.listKeysQuery) + + return s, nil +} + +/* Core methods */ +// upsert saves an obj with its key, or updates key with obj if it exists in this Store +func (s *Store) upsert(key string, obj any) error { + tx, err := s.BeginTx(context.Background(), true) + if 
err != nil { + return err + } + + err = s.Upsert(tx, s.upsertStmt, key, obj, s.shouldEncrypt) + if err != nil { + return &db.QueryError{QueryString: s.upsertQuery, Err: err} + } + + err = s.runAfterUpsert(key, obj, tx) + if err != nil { + return err + } + + return tx.Commit() +} + +// deleteByKey deletes the object associated with key, if it exists in this Store +func (s *Store) deleteByKey(key string) error { + tx, err := s.BeginTx(context.Background(), true) + if err != nil { + return err + } + + err = tx.StmtExec(tx.Stmt(s.deleteStmt), key) + if err != nil { + return &db.QueryError{QueryString: s.deleteQuery, Err: err} + } + + err = s.runAfterDelete(key, tx) + if err != nil { + return err + } + + return tx.Commit() +} + +// GetByKey returns the object associated with the given object's key +func (s *Store) GetByKey(key string) (item any, exists bool, err error) { + rows, err := s.QueryForRows(context.TODO(), s.getStmt, key) + if err != nil { + return nil, false, &db.QueryError{QueryString: s.getQuery, Err: err} + } + result, err := s.ReadObjects(rows, s.typ, s.shouldEncrypt) + if err != nil { + return nil, false, err + } + + if len(result) == 0 { + return nil, false, nil + } + + return result[0], true, nil +} + +/* Satisfy cache.Store */ + +// Add saves an obj, or updates it if it exists in this Store +func (s *Store) Add(obj any) error { + key, err := s.keyFunc(obj) + if err != nil { + return err + } + + err = s.upsert(key, obj) + if err != nil { + log.Errorf("Error in Store.Add for type %v: %v", s.name, err) + return err + } + return nil +} + +// Update saves an obj, or updates it if it exists in this Store +func (s *Store) Update(obj any) error { + return s.Add(obj) +} + +// Delete deletes the given object, if it exists in this Store +func (s *Store) Delete(obj any) error { + key, err := s.keyFunc(obj) + if err != nil { + return err + } + err = s.deleteByKey(key) + if err != nil { + log.Errorf("Error in Store.Delete for type %v: %v", s.name, err) + return err + } + return nil +} + +// List returns a list of all the currently known objects +// Note: I/O errors will panic this function, as the interface signature does not allow returning errors +func (s *Store) List() []any { + rows, err := s.QueryForRows(context.TODO(), s.listStmt) + if err != nil { + panic(&db.QueryError{QueryString: s.listQuery, Err: err}) + } + result, err := s.ReadObjects(rows, s.typ, s.shouldEncrypt) + if err != nil { + panic(fmt.Errorf("error in Store.List: %w", err)) + } + return result +} + +// ListKeys returns a list of all the keys currently in this Store +// Note: Atm it doesn't appear returning nil in the case of an error has any detrimental effects. An error is not +// uncommon enough nor does it appear to necessitate a panic. 
+func (s *Store) ListKeys() []string { + rows, err := s.QueryForRows(context.TODO(), s.listKeysStmt) + if err != nil { + fmt.Printf("Unexpected error in store.ListKeys while executing query %q: %v\n", s.listKeysQuery, err) + return []string{} + } + result, err := s.ReadStrings(rows) + if err != nil { + fmt.Printf("Unexpected error in store.ListKeys: %v\n", err) + return []string{} + } + return result +} + +// Get returns the object with the same key as obj +func (s *Store) Get(obj any) (item any, exists bool, err error) { + key, err := s.keyFunc(obj) + if err != nil { + return nil, false, err + } + + return s.GetByKey(key) +} + +// Replace will delete the contents of the Store, using instead the given list +func (s *Store) Replace(objects []any, _ string) error { + objectMap := map[string]any{} + + for _, object := range objects { + key, err := s.keyFunc(object) + if err != nil { + return err + } + objectMap[key] = object + } + return s.replaceByKey(objectMap) +} + +// replaceByKey will delete the contents of the Store, using instead the given key-to-object map +func (s *Store) replaceByKey(objects map[string]any) error { + txC, err := s.BeginTx(context.Background(), true) + if err != nil { + return err + } + + txCListKeys := txC.Stmt(s.listKeysStmt) + + rows, err := s.QueryForRows(context.TODO(), txCListKeys) + if err != nil { + return err + } + keys, err := s.ReadStrings(rows) + if err != nil { + return err + } + + for _, key := range keys { + err = txC.StmtExec(txC.Stmt(s.deleteStmt), key) + if err != nil { + return err + } + err = s.runAfterDelete(key, txC) + if err != nil { + return err + } + } + + for key, obj := range objects { + err = s.Upsert(txC, s.upsertStmt, key, obj, s.shouldEncrypt) + if err != nil { + return err + } + err = s.runAfterUpsert(key, obj, txC) + if err != nil { + return err + } + } + + return txC.Commit() +} + +// Resync is a no-op and is deprecated +func (s *Store) Resync() error { + return nil +} + +/* Utilities */ + +// RegisterAfterUpsert registers a func to be called after each upsert +func (s *Store) RegisterAfterUpsert(f func(key string, obj any, txC db.TXClient) error) { + s.afterUpsert = append(s.afterUpsert, f) +} + +func (s *Store) GetName() string { + return s.name +} + +func (s *Store) GetShouldEncrypt() bool { + return s.shouldEncrypt +} + +func (s *Store) GetType() reflect.Type { + return s.typ +} + +// runAfterUpsert executes functions registered to run after upsert +func (s *Store) runAfterUpsert(key string, obj any, txC db.TXClient) error { + for _, f := range s.afterUpsert { + err := f(key, obj, txC) + if err != nil { + return err + } + } + return nil +} + +// RegisterAfterDelete registers a func to be called after each deletion +func (s *Store) RegisterAfterDelete(f func(key string, txC db.TXClient) error) { + s.afterDelete = append(s.afterDelete, f) +} + +// runAfterDelete executes functions registered to run after delete +func (s *Store) runAfterDelete(key string, txC db.TXClient) error { + for _, f := range s.afterDelete { + err := f(key, txC) + if err != nil { + return err + } + } + return nil +} diff --git a/pkg/sqlcache/store/store_mocks_test.go b/pkg/sqlcache/store/store_mocks_test.go new file mode 100644 index 00000000..d30df82b --- /dev/null +++ b/pkg/sqlcache/store/store_mocks_test.go @@ -0,0 +1,165 @@ +// Code generated by MockGen. DO NOT EDIT. 
+// Source: github.com/rancher/steve/pkg/sqlcache/store (interfaces: DBClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient +// + +// Package store is a generated GoMock package. +package store + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + db "github.com/rancher/steve/pkg/sqlcache/db" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockDBClient is a mock of DBClient interface. +type MockDBClient struct { + ctrl *gomock.Controller + recorder *MockDBClientMockRecorder +} + +// MockDBClientMockRecorder is the mock recorder for MockDBClient. +type MockDBClientMockRecorder struct { + mock *MockDBClient +} + +// NewMockDBClient creates a new mock instance. +func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient { + mock := &MockDBClient{ctrl: ctrl} + mock.recorder = &MockDBClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(db.TXClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1) +} + +// CloseStmt mocks base method. +func (m *MockDBClient) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0) +} + +// Prepare mocks base method. +func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. 
+func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0) +} + +// Upsert mocks base method. +func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) +} diff --git a/pkg/sqlcache/store/store_test.go b/pkg/sqlcache/store/store_test.go new file mode 100644 index 00000000..1d4e2613 --- /dev/null +++ b/pkg/sqlcache/store/store_test.go @@ -0,0 +1,646 @@ +/* +Copyright 2023 SUSE LLC + +Adapted from client-go, Copyright 2014 The Kubernetes Authors. +*/ + +package store + +// Mocks for this test are generated with the following command. 
+//go:generate mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + + "github.com/rancher/steve/pkg/sqlcache/db" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func testStoreKeyFunc(obj interface{}) (string, error) { + return obj.(testStoreObject).Id, nil +} + +type testStoreObject struct { + Id string + Val string +} + +func TestAdd(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + testObject := testStoreObject{Id: "something", Val: "a"} + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Add with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(nil) + err := store.Add(testObject) + assert.Nil(t, err) + // dbclient beginerr + }, + }) + + tests = append(tests, testCase{description: "Add with no DB Client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Commit().Return(nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + + var count int + store.afterUpsert = append(store.afterUpsert, func(key string, object any, tx db.TXClient) error { + count++ + return nil + }) + err := store.Add(testObject) + assert.Nil(t, err) + assert.Equal(t, count, 1) + }, + }) + + tests = append(tests, testCase{description: "Add with no DB Client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error { + return fmt.Errorf("error") + }) + err := store.Add(testObject) + assert.NotNil(t, err) + // dbclient beginerr + }, + }) + + tests = append(tests, testCase{description: "Add with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("failed")) + + store := SetupStore(t, c, shouldEncrypt) + err := store.Add(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Add with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, 
store.shouldEncrypt).Return(fmt.Errorf("failed")) + err := store.Add(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Add with DB Client Upsert() error with following Rollback() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) + err := store.Add(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Add with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(fmt.Errorf("failed")) + + err := store.Add(testObject) + assert.NotNil(t, err) + }}) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Update updates the given object in the accumulator associated with the given object's key +func TestUpdate(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + testObject := testStoreObject{Id: "something", Val: "a"} + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Update with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(nil) + err := store.Update(testObject) + assert.Nil(t, err) + // dbclient beginerr + }, + }) + + tests = append(tests, testCase{description: "Update with no DB Client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(nil) + + var count int + store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error { + count++ + return nil + }) + err := store.Update(testObject) + assert.Nil(t, err) + assert.Equal(t, count, 1) + }, + }) + + tests = append(tests, testCase{description: "Update with no DB Client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + + store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error { + return fmt.Errorf("error") + }) + err := store.Update(testObject) + assert.NotNil(t, err) + }, + }) + + tests = append(tests, testCase{description: "Update with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { 
+ c, _ := SetupMockDB(t) + c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("failed")) + + store := SetupStore(t, c, shouldEncrypt) + err := store.Update(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Update with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) + err := store.Update(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Update with DB Client Upsert() error with following Rollback() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) + err := store.Update(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Update with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(fmt.Errorf("failed")) + + err := store.Update(testObject) + assert.NotNil(t, err) + }}) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Delete deletes the given object from the accumulator associated with the given object's key +func TestDelete(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + testObject := testStoreObject{Id: "something", Val: "a"} + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Delete with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil) + txC.EXPECT().Commit().Return(nil) + err := store.Delete(testObject) + assert.Nil(t, err) + }, + }) + tests = append(tests, testCase{description: "Delete with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error")) + // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + err := store.Delete(testObject) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Delete with TX Client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + 
txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(fmt.Errorf("error")) + // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + err := store.Delete(testObject) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Delete with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + // tx.EXPECT(). + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil) + txC.EXPECT().Commit().Return(fmt.Errorf("error")) + err := store.Delete(testObject) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// List returns a list of all the currently non-empty accumulators +func TestList(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + testObject := testStoreObject{Id: "something", Val: "a"} + + var tests []testCase + + tests = append(tests, testCase{description: "List with no DB Client errors and no items", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listStmt).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{}, nil) + items := store.List() + assert.Len(t, items, 0) + }, + }) + tests = append(tests, testCase{description: "List with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + fakeItemsToReturn := []any{"something1", 2, false} + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listStmt).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(fakeItemsToReturn, nil) + items := store.List() + assert.Equal(t, fakeItemsToReturn, items) + }, + }) + tests = append(tests, testCase{description: "List with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listStmt).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(nil, fmt.Errorf("error")) + defer func() { + recover() + }() + _ = store.List() + assert.Fail(t, "Store list should panic when ReadObjects returns an error") + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// ListKeys returns a list of all the keys currently associated with non-empty accumulators +func TestListKeys(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + + tests = append(tests, testCase{description: "ListKeys with no DB Client errors and some 
items", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{"a", "b", "c"}, nil) + keys := store.ListKeys() + assert.Len(t, keys, 3) + }, + }) + + tests = append(tests, testCase{description: "ListKeys with DB Client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error")) + keys := store.ListKeys() + assert.Len(t, keys, 0) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Get returns the accumulator associated with the given object's key +func TestGet(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + testObject := testStoreObject{Id: "something", Val: "a"} + tests = append(tests, testCase{description: "Get with no DB Client errors and object exists", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{testObject}, nil) + item, exists, err := store.Get(testObject) + assert.Nil(t, err) + assert.Equal(t, item, testObject) + assert.True(t, exists) + }, + }) + tests = append(tests, testCase{description: "Get with no DB Client errors and object does not exist", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{}, nil) + item, exists, err := store.Get(testObject) + assert.Nil(t, err) + assert.Equal(t, item, nil) + assert.False(t, exists) + }, + }) + tests = append(tests, testCase{description: "Get with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(nil, fmt.Errorf("error")) + _, _, err := store.Get(testObject) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// GetByKey returns the accumulator associated with the given key +func TestGetByKey(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + testObject := testStoreObject{Id: "something", Val: "a"} + tests = append(tests, testCase{description: "GetByKey with no DB Client errors and item exists", test: func(t *testing.T, shouldEncrypt bool) { + c, _ 
:= SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{testObject}, nil) + item, exists, err := store.GetByKey(testObject.Id) + assert.Nil(t, err) + assert.Equal(t, item, testObject) + assert.True(t, exists) + }, + }) + tests = append(tests, testCase{description: "GetByKey with no DB Client errors and item does not exist", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{}, nil) + item, exists, err := store.GetByKey(testObject.Id) + assert.Nil(t, err) + assert.Equal(t, nil, item) + assert.False(t, exists) + }, + }) + tests = append(tests, testCase{description: "GetByKey with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(nil, fmt.Errorf("error")) + _, _, err := store.GetByKey(testObject.Id) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Replace will delete the contents of the store, using instead the +// given list. Store takes ownership of the list, you should not reference +// it after calling this function. 
+func TestReplace(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + testObject := testStoreObject{Id: "something", Val: "a"} + tests = append(tests, testCase{description: "Replace with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id) + c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt) + txC.EXPECT().Commit() + err := store.Replace([]any{testObject}, testObject.Id) + assert.Nil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with no DB Client errors and no items", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{}, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt) + txC.EXPECT().Commit() + err := store.Replace([]any{testObject}, testObject.Id) + assert.Nil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with no DB Client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with TX Client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + 
txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil) + c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt).Return(fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Resync is meaningless in the terms appearing here but has +// meaning in some implementations that have non-trivial +// additional behavior (e.g., DeltaFIFO). +func TestResync(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + tests = append(tests, testCase{description: "Resync shouldn't call the client, panic, or do anything else", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + err := store.Resync() + assert.Nil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +func SetupMockDB(t *testing.T) (*MockDBClient, *MockTXClient) { + dbC := NewMockDBClient(gomock.NewController(t)) // add functionality once store expectation are known + txC := NewMockTXClient(gomock.NewController(t)) + // stmt := NewMockStmt(gomock.NewController()) + txC.EXPECT().Exec(fmt.Sprintf(createTableFmt, "testStoreObject")).Return(nil) + txC.EXPECT().Commit().Return(nil) + dbC.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + + // use stmt mock here + dbC.EXPECT().Prepare(fmt.Sprintf(upsertStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(deleteStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(getStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(listStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(listKeysStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + + return dbC, txC +} +func SetupStore(t *testing.T, client *MockDBClient, shouldEncrypt bool) *Store { + store, err := NewStore(testStoreObject{}, testStoreKeyFunc, client, shouldEncrypt, "testStoreObject") + if err != nil { + t.Error(err) + } + return store +} diff --git 
a/pkg/sqlcache/store/tx_mocks_test.go b/pkg/sqlcache/store/tx_mocks_test.go new file mode 100644 index 00000000..0c05ab7f --- /dev/null +++ b/pkg/sqlcache/store/tx_mocks_test.go @@ -0,0 +1,99 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt +// + +// Package store is a generated GoMock package. +package store + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStmt is a mock of Stmt interface. +type MockStmt struct { + ctrl *gomock.Controller + recorder *MockStmtMockRecorder +} + +// MockStmtMockRecorder is the mock recorder for MockStmt. +type MockStmtMockRecorder struct { + mock *MockStmt +} + +// NewMockStmt creates a new mock instance. +func NewMockStmt(ctrl *gomock.Controller) *MockStmt { + mock := &MockStmt{ctrl: ctrl} + mock.recorder = &MockStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStmt) EXPECT() *MockStmtMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) +} + +// Query mocks base method. +func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) +} + +// QueryContext mocks base method. +func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) 
+} diff --git a/pkg/stores/sqlpartition/listprocessor/processor.go b/pkg/stores/sqlpartition/listprocessor/processor.go index 0ea15788..ee2f49b3 100644 --- a/pkg/stores/sqlpartition/listprocessor/processor.go +++ b/pkg/stores/sqlpartition/listprocessor/processor.go @@ -10,8 +10,8 @@ import ( "github.com/rancher/apiserver/pkg/apierror" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/wrangler/v3/pkg/schemas/validation" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) @@ -79,7 +79,7 @@ func ParseQuery(apiOp *types.APIRequest, namespaceCache Cache) (informer.ListOpt } usePartialMatch := !(strings.HasPrefix(filter[1], `'`) && strings.HasSuffix(filter[1], `'`)) value := strings.TrimSuffix(strings.TrimPrefix(filter[1], "'"), "'") - orFilter.Filters = append(orFilter.Filters, informer.Filter{Field: strings.Split(filter[0], "."), Match: value, Op: op, Partial: usePartialMatch}) + orFilter.Filters = append(orFilter.Filters, informer.Filter{Field: strings.Split(filter[0], "."), Matches: []string{value}, Op: op, Partial: usePartialMatch}) } filterOpts = append(filterOpts, orFilter) } @@ -91,20 +91,20 @@ func ParseQuery(apiOp *types.APIRequest, namespaceCache Cache) (informer.ListOpt sortParts := strings.SplitN(sortKeys, ",", 2) primaryField := sortParts[0] if primaryField != "" && primaryField[0] == '-' { - sortOpts.PrimaryOrder = informer.DESC + sortOpts.Orders = append(sortOpts.Orders, informer.DESC) primaryField = primaryField[1:] } if primaryField != "" { - sortOpts.PrimaryField = strings.Split(primaryField, ".") + sortOpts.Fields = append(sortOpts.Fields, strings.Split(primaryField, ".")) } if len(sortParts) > 1 { secondaryField := sortParts[1] if secondaryField != "" && secondaryField[0] == '-' { - sortOpts.SecondaryOrder = informer.DESC + sortOpts.Orders = append(sortOpts.Orders, informer.DESC) secondaryField = secondaryField[1:] } if secondaryField != "" { - sortOpts.SecondaryField = strings.Split(secondaryField, ".") + sortOpts.Fields = append(sortOpts.Fields, strings.Split(secondaryField, ".")) } } } @@ -170,14 +170,14 @@ func parseNamespaceOrProjectFilters(ctx context.Context, projOrNS string, op inf { Filters: []informer.Filter{ { - Field: []string{"metadata", "name"}, - Match: pn, - Op: informer.Eq, + Field: []string{"metadata", "name"}, + Matches: []string{pn}, + Op: informer.Eq, }, { - Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, - Match: pn, - Op: informer.Eq, + Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, + Matches: []string{pn}, + Op: informer.Eq, }, }, }, @@ -189,7 +189,7 @@ func parseNamespaceOrProjectFilters(ctx context.Context, projOrNS string, op inf for _, item := range uList.Items { filters = append(filters, informer.Filter{ Field: []string{"metadata", "namespace"}, - Match: item.GetName(), + Matches: []string{item.GetName()}, Op: op, Partial: false, }) diff --git a/pkg/stores/sqlpartition/listprocessor/processor_test.go b/pkg/stores/sqlpartition/listprocessor/processor_test.go index 08b6ea6a..0aa3207d 100644 --- a/pkg/stores/sqlpartition/listprocessor/processor_test.go +++ b/pkg/stores/sqlpartition/listprocessor/processor_test.go @@ -8,8 +8,8 @@ import ( "testing" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/partition" + 
"github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -60,7 +60,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"metadata", "namespace"}, - Match: "ns1", + Matches: []string{"ns1"}, Op: "", Partial: false, }, @@ -89,14 +89,14 @@ func TestParseQuery(t *testing.T) { { Filters: []informer.Filter{ { - Field: []string{"metadata", "name"}, - Match: "somethin", - Op: informer.Eq, + Field: []string{"metadata", "name"}, + Matches: []string{"somethin"}, + Op: informer.Eq, }, { - Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, - Match: "somethin", - Op: informer.Eq, + Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, + Matches: []string{"somethin"}, + Op: informer.Eq, }, }, }, @@ -120,7 +120,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"metadata", "namespace"}, - Match: "ns1", + Matches: []string{"ns1"}, Op: "", Partial: false, }, @@ -139,14 +139,14 @@ func TestParseQuery(t *testing.T) { { Filters: []informer.Filter{ { - Field: []string{"metadata", "name"}, - Match: "somethin", - Op: informer.Eq, + Field: []string{"metadata", "name"}, + Matches: []string{"somethin"}, + Op: informer.Eq, }, { - Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, - Match: "somethin", - Op: informer.Eq, + Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, + Matches: []string{"somethin"}, + Op: informer.Eq, }, }, }, @@ -170,7 +170,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"metadata", "namespace"}, - Match: "ns1", + Matches: []string{"ns1"}, Op: "", Partial: false, }, @@ -192,14 +192,14 @@ func TestParseQuery(t *testing.T) { { Filters: []informer.Filter{ { - Field: []string{"metadata", "name"}, - Match: "somethin", - Op: informer.Eq, + Field: []string{"metadata", "name"}, + Matches: []string{"somethin"}, + Op: informer.Eq, }, { - Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, - Match: "somethin", - Op: informer.Eq, + Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, + Matches: []string{"somethin"}, + Op: informer.Eq, }, }, }, @@ -222,7 +222,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"a"}, - Match: "c", + Matches: []string{"c"}, Op: "", Partial: true, }, @@ -251,7 +251,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"a"}, - Match: "c", + Matches: []string{"c"}, Op: "", Partial: false, }, @@ -280,7 +280,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"a"}, - Match: "c", + Matches: []string{"c"}, Op: "", Partial: true, }, @@ -290,7 +290,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"b"}, - Match: "d", + Matches: []string{"d"}, Op: "", Partial: true, }, @@ -320,13 +320,13 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"a"}, - Match: "c", + Matches: []string{"c"}, Op: "", Partial: true, }, { Field: []string{"b"}, - Match: "d", + Matches: []string{"d"}, Op: "", Partial: true, }, @@ -352,7 +352,9 @@ func TestParseQuery(t *testing.T) { expectedLO: informer.ListOptions{ ChunkSize: defaultLimit, Sort: informer.Sort{ - PrimaryField: []string{"metadata", "name"}, + Fields: [][]string{ + {"metadata", "name"}, + }, }, Filters: make([]informer.OrFilter, 0), 
Pagination: informer.Pagination{ @@ -374,8 +376,8 @@ func TestParseQuery(t *testing.T) { expectedLO: informer.ListOptions{ ChunkSize: defaultLimit, Sort: informer.Sort{ - PrimaryField: []string{"metadata", "name"}, - PrimaryOrder: informer.DESC, + Fields: [][]string{{"metadata", "name"}}, + Orders: []informer.SortOrder{informer.DESC}, }, Filters: make([]informer.OrFilter, 0), Pagination: informer.Pagination{ @@ -397,10 +399,13 @@ func TestParseQuery(t *testing.T) { expectedLO: informer.ListOptions{ ChunkSize: defaultLimit, Sort: informer.Sort{ - PrimaryField: []string{"metadata", "name"}, - PrimaryOrder: informer.DESC, - SecondaryField: []string{"spec", "something"}, - SecondaryOrder: informer.ASC, + Fields: [][]string{ + {"metadata", "name"}, + {"spec", "something"}, + }, + Orders: []informer.SortOrder{ + informer.DESC, + }, }, Filters: make([]informer.OrFilter, 0), Pagination: informer.Pagination{ diff --git a/pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go b/pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go index 693261a7..06598043 100644 --- a/pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go +++ b/pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go @@ -13,8 +13,8 @@ import ( context "context" reflect "reflect" - informer "github.com/rancher/lasso/pkg/cache/sql/informer" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + informer "github.com/rancher/steve/pkg/sqlcache/informer" + partition "github.com/rancher/steve/pkg/sqlcache/partition" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) diff --git a/pkg/stores/sqlpartition/partition_mocks_test.go b/pkg/stores/sqlpartition/partition_mocks_test.go index ea72b8c1..687b9006 100644 --- a/pkg/stores/sqlpartition/partition_mocks_test.go +++ b/pkg/stores/sqlpartition/partition_mocks_test.go @@ -13,7 +13,7 @@ import ( reflect "reflect" types "github.com/rancher/apiserver/pkg/types" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + partition "github.com/rancher/steve/pkg/sqlcache/partition" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" watch "k8s.io/apimachinery/pkg/watch" diff --git a/pkg/stores/sqlpartition/partitioner.go b/pkg/stores/sqlpartition/partitioner.go index ffee38df..b3b74f9b 100644 --- a/pkg/stores/sqlpartition/partitioner.go +++ b/pkg/stores/sqlpartition/partitioner.go @@ -5,9 +5,9 @@ import ( "sort" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/partition" "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/steve/pkg/attributes" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/wrangler/v3/pkg/kv" "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" diff --git a/pkg/stores/sqlpartition/partitioner_test.go b/pkg/stores/sqlpartition/partitioner_test.go index caba1897..cdc93ff1 100644 --- a/pkg/stores/sqlpartition/partitioner_test.go +++ b/pkg/stores/sqlpartition/partitioner_test.go @@ -6,7 +6,7 @@ import ( "go.uber.org/mock/gomock" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/wrangler/v3/pkg/schemas" "github.com/stretchr/testify/assert" diff --git a/pkg/stores/sqlpartition/store.go b/pkg/stores/sqlpartition/store.go index 145d5e72..f4ebb325 100644 --- a/pkg/stores/sqlpartition/store.go +++ 
b/pkg/stores/sqlpartition/store.go @@ -7,14 +7,14 @@ import ( "context" "github.com/rancher/apiserver/pkg/types" - lassopartition "github.com/rancher/lasso/pkg/cache/sql/partition" "github.com/rancher/steve/pkg/accesscontrol" + cachepartition "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/stores/partition" ) // Partitioner is an interface for interacting with partitions. type Partitioner interface { - All(apiOp *types.APIRequest, schema *types.APISchema, verb, id string) ([]lassopartition.Partition, error) + All(apiOp *types.APIRequest, schema *types.APISchema, verb, id string) ([]cachepartition.Partition, error) Store() UnstructuredStore } diff --git a/pkg/stores/sqlpartition/store_test.go b/pkg/stores/sqlpartition/store_test.go index d1f11be7..d98a8188 100644 --- a/pkg/stores/sqlpartition/store_test.go +++ b/pkg/stores/sqlpartition/store_test.go @@ -15,7 +15,7 @@ import ( "go.uber.org/mock/gomock" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/steve/pkg/stores/sqlproxy" "github.com/rancher/wrangler/v3/pkg/generic" diff --git a/pkg/stores/sqlproxy/proxy_mocks_test.go b/pkg/stores/sqlproxy/proxy_mocks_test.go index a5559d47..45ccd418 100644 --- a/pkg/stores/sqlproxy/proxy_mocks_test.go +++ b/pkg/stores/sqlproxy/proxy_mocks_test.go @@ -14,9 +14,9 @@ import ( reflect "reflect" types "github.com/rancher/apiserver/pkg/types" - informer "github.com/rancher/lasso/pkg/cache/sql/informer" - factory "github.com/rancher/lasso/pkg/cache/sql/informer/factory" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + informer "github.com/rancher/steve/pkg/sqlcache/informer" + factory "github.com/rancher/steve/pkg/sqlcache/informer/factory" + partition "github.com/rancher/steve/pkg/sqlcache/partition" summary "github.com/rancher/wrangler/v3/pkg/summary" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" diff --git a/pkg/stores/sqlproxy/proxy_store.go b/pkg/stores/sqlproxy/proxy_store.go index f6d1e34b..878405a0 100644 --- a/pkg/stores/sqlproxy/proxy_store.go +++ b/pkg/stores/sqlproxy/proxy_store.go @@ -32,9 +32,9 @@ import ( "github.com/rancher/apiserver/pkg/apierror" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/informer/factory" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/informer/factory" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/wrangler/v3/pkg/data" "github.com/rancher/wrangler/v3/pkg/schemas" "github.com/rancher/wrangler/v3/pkg/schemas/validation" @@ -333,7 +333,7 @@ func gvkKey(group, version, kind string) string { return group + "_" + version + "_" + kind } -// getFieldsFromSchema converts object field names from types.APISchema's format into lasso's +// getFieldsFromSchema converts object field names from types.APISchema's format into steve's // cache.sql.informer's slice format (e.g. 
"metadata.resourceVersion" is ["metadata", "resourceVersion"]) func getFieldsFromSchema(schema *types.APISchema) [][]string { var fields [][]string @@ -757,7 +757,7 @@ func (s *Store) ListByPartitions(apiOp *types.APIRequest, schema *types.APISchem list, total, continueToken, err := inf.ListByOptions(apiOp.Context(), opts, partitions, apiOp.Namespace) if err != nil { - if errors.Is(err, informer.InvalidColumnErr) { + if errors.Is(err, informer.ErrInvalidColumn) { return nil, 0, "", apierror.NewAPIError(validation.InvalidBodyContent, err.Error()) } return nil, 0, "", err diff --git a/pkg/stores/sqlproxy/proxy_store_test.go b/pkg/stores/sqlproxy/proxy_store_test.go index bdbbbaab..8690eea3 100644 --- a/pkg/stores/sqlproxy/proxy_store_test.go +++ b/pkg/stores/sqlproxy/proxy_store_test.go @@ -12,9 +12,9 @@ import ( "github.com/rancher/wrangler/v3/pkg/schemas/validation" apierrors "k8s.io/apimachinery/pkg/api/errors" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/informer/factory" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/informer/factory" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/attributes" "github.com/rancher/steve/pkg/resources/common" "github.com/rancher/steve/pkg/stores/sqlpartition/listprocessor" @@ -42,7 +42,7 @@ import ( ) //go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./proxy_mocks_test.go github.com/rancher/steve/pkg/stores/sqlproxy Cache,ClientGetter,CacheFactory,SchemaColumnSetter,RelationshipNotifier,TransformBuilder -//go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer ByOptionsLister +//go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister //go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface var c *watch.FakeWatcher diff --git a/pkg/stores/sqlproxy/sql_informer_mocks_test.go b/pkg/stores/sqlproxy/sql_informer_mocks_test.go index e8f5358c..125f2192 100644 --- a/pkg/stores/sqlproxy/sql_informer_mocks_test.go +++ b/pkg/stores/sqlproxy/sql_informer_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/informer (interfaces: ByOptionsLister) +// Source: github.com/rancher/steve/pkg/sqlcache/informer (interfaces: ByOptionsLister) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer ByOptionsLister +// mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister // // Package sqlproxy is a generated GoMock package. @@ -13,8 +13,8 @@ import ( context "context" reflect "reflect" - informer "github.com/rancher/lasso/pkg/cache/sql/informer" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + informer "github.com/rancher/steve/pkg/sqlcache/informer" + partition "github.com/rancher/steve/pkg/sqlcache/partition" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" )