From 84b022b326698d5719bc4d503e20d598bd3b11bf Mon Sep 17 00:00:00 2001 From: Tom Lebreux Date: Wed, 15 Jan 2025 23:10:05 -0500 Subject: [PATCH 01/10] Copy pkg/cache/sql from lasso to pkg/sqlcache --- pkg/sqlcache/Readme.md | 157 ++ pkg/sqlcache/db/client.go | 372 +++++ pkg/sqlcache/db/client_test.go | 667 ++++++++ pkg/sqlcache/db/db_mocks_test.go | 370 +++++ pkg/sqlcache/db/transaction/transaction.go | 92 + .../db/transaction/transaction_mocks_test.go | 184 ++ .../db/transaction/transaction_test.go | 182 ++ pkg/sqlcache/db/transaction_mocks_test.go | 184 ++ pkg/sqlcache/db/utility.go | 8 + pkg/sqlcache/encryption/encrypt.go | 168 ++ pkg/sqlcache/encryption/encrypt_test.go | 327 ++++ pkg/sqlcache/informer/db_mocks_test.go | 204 +++ pkg/sqlcache/informer/dynamic_mocks_test.go | 237 +++ .../informer/factory/db_mocks_test.go | 121 ++ .../informer/factory/dynamic_mocks_test.go | 237 +++ .../informer/factory/factory_mocks_test.go | 179 ++ .../informer/factory/informer_factory.go | 187 +++ .../informer/factory/informer_factory_test.go | 287 ++++ .../informer/factory/k8s_cache_mocks_test.go | 223 +++ pkg/sqlcache/informer/indexer.go | 264 +++ pkg/sqlcache/informer/indexer_test.go | 614 +++++++ pkg/sqlcache/informer/informer.go | 94 ++ pkg/sqlcache/informer/informer_mocks_test.go | 59 + pkg/sqlcache/informer/informer_test.go | 351 ++++ pkg/sqlcache/informer/listoption_indexer.go | 853 ++++++++++ .../informer/listoption_indexer_test.go | 1478 +++++++++++++++++ pkg/sqlcache/informer/listoptions.go | 68 + pkg/sqlcache/informer/shared_informer_hack.go | 22 + pkg/sqlcache/informer/shared_informer_test.go | 325 ++++ pkg/sqlcache/informer/sql_mocks_test.go | 347 ++++ pkg/sqlcache/informer/store_mocks_test.go | 165 ++ pkg/sqlcache/informer/tx_mocks_test.go | 99 ++ pkg/sqlcache/integration_test.go | 365 ++++ pkg/sqlcache/partition/partition.go | 24 + pkg/sqlcache/store/db_mocks_test.go | 204 +++ pkg/sqlcache/store/store.go | 360 ++++ pkg/sqlcache/store/store_mocks_test.go | 165 ++ pkg/sqlcache/store/store_test.go | 646 +++++++ pkg/sqlcache/store/tx_mocks_test.go | 99 ++ 39 files changed, 10988 insertions(+) create mode 100644 pkg/sqlcache/Readme.md create mode 100644 pkg/sqlcache/db/client.go create mode 100644 pkg/sqlcache/db/client_test.go create mode 100644 pkg/sqlcache/db/db_mocks_test.go create mode 100644 pkg/sqlcache/db/transaction/transaction.go create mode 100644 pkg/sqlcache/db/transaction/transaction_mocks_test.go create mode 100644 pkg/sqlcache/db/transaction/transaction_test.go create mode 100644 pkg/sqlcache/db/transaction_mocks_test.go create mode 100644 pkg/sqlcache/db/utility.go create mode 100644 pkg/sqlcache/encryption/encrypt.go create mode 100644 pkg/sqlcache/encryption/encrypt_test.go create mode 100644 pkg/sqlcache/informer/db_mocks_test.go create mode 100644 pkg/sqlcache/informer/dynamic_mocks_test.go create mode 100644 pkg/sqlcache/informer/factory/db_mocks_test.go create mode 100644 pkg/sqlcache/informer/factory/dynamic_mocks_test.go create mode 100644 pkg/sqlcache/informer/factory/factory_mocks_test.go create mode 100644 pkg/sqlcache/informer/factory/informer_factory.go create mode 100644 pkg/sqlcache/informer/factory/informer_factory_test.go create mode 100644 pkg/sqlcache/informer/factory/k8s_cache_mocks_test.go create mode 100644 pkg/sqlcache/informer/indexer.go create mode 100644 pkg/sqlcache/informer/indexer_test.go create mode 100644 pkg/sqlcache/informer/informer.go create mode 100644 pkg/sqlcache/informer/informer_mocks_test.go create mode 100644 pkg/sqlcache/informer/informer_test.go create mode 100644 pkg/sqlcache/informer/listoption_indexer.go create mode 100644 pkg/sqlcache/informer/listoption_indexer_test.go create mode 100644 pkg/sqlcache/informer/listoptions.go create mode 100644 pkg/sqlcache/informer/shared_informer_hack.go create mode 100644 pkg/sqlcache/informer/shared_informer_test.go create mode 100644 pkg/sqlcache/informer/sql_mocks_test.go create mode 100644 pkg/sqlcache/informer/store_mocks_test.go create mode 100644 pkg/sqlcache/informer/tx_mocks_test.go create mode 100644 pkg/sqlcache/integration_test.go create mode 100644 pkg/sqlcache/partition/partition.go create mode 100644 pkg/sqlcache/store/db_mocks_test.go create mode 100644 pkg/sqlcache/store/store.go create mode 100644 pkg/sqlcache/store/store_mocks_test.go create mode 100644 pkg/sqlcache/store/store_test.go create mode 100644 pkg/sqlcache/store/tx_mocks_test.go diff --git a/pkg/sqlcache/Readme.md b/pkg/sqlcache/Readme.md new file mode 100644 index 00000000..a0c4ce66 --- /dev/null +++ b/pkg/sqlcache/Readme.md @@ -0,0 +1,157 @@ +# SQL Cache + +## Sections +- [ListOptions Informer](#listoptions-informer) + - [List Options](#list-options) + - [ListOption Indexer](#listoptions-indexer) + - [SQL Store](#sql-store) + - [Partitions](#partitions) +- [How to Use](#how-to-use) +- [Technical Information](#technical-information) + - [SQL Tables](#sql-tables) + - [SQLite Driver](#sqlite-driver) + - [Connection Pooling](#connection-pooling) + - [Encryption Defaults](#encryption-defaults) + - [Indexed Fields](#indexed-fields) + - [ListOptions Behavior](#listoptions-behavior) + - [Troubleshooting Sqlite](#troubleshooting-sqlite) + + + +## ListOptions Informer +The main usable feature from the SQL cache is the ListOptions Informer. The ListOptionsInformer provides listing functionality, +like any other informer, but with a wider array of options. The options are configured by informer.ListOptions. + +### List Options +ListOptions includes the following: +* Match filters for indexed fields. Filters are for specifying the value a given field in an object should be in order to +be included in the list. Filters can be set to equals or not equals. Filters can be set to look for partial matches or +exact (strict) matches. Filters can be OR'd and AND'd with one another. Filters only work on fields that have been indexed. +* Primary field and secondary field sorting order. Can choose up to two fields to sort on. Sort order can be ascending +or descending. Default sorting is to sort on metadata.namespace in ascending first and then sort on metadata.name. +* Page size to specify how many items to include in a response. +* Page number to specify offset. For example, a page size of 50 and a page number of 2, will return items starting at +index 50. Index will be dependent on sort. Page numbers start at 1. + +### ListOptions Factory +The ListOptions Factory helps manage multiple ListOption Informers. A user can call Factory.InformerFor(), to create new +ListOptions informers if they do not exist and retrieve existing ones. + +### ListOptions Indexer +Like all other informers, the ListOptions informer uses an indexer to cache objects of the informer's type. A few features +set the ListOptions Indexer apart from others indexers: +* an on-disk store instead of an in-memory store. +* accepts list options backed by SQL queries for extended search/filter/sorting capability. +* AES GCM encryption using key hierarchy. + +### SQL Store +The SQL store is the main interface for interacting with the database. This store backs the indexer, and provides all +functionality required by the cache.Store interface. + +### Partitions +Partitions are constraints for ListOptionsInform ListByOptions() method that are separate from ListOptions. Partitions +are strict conditions that dictate which namespaces or names can be searched from. These overrule ListOptions and are +intended to be used as a way of enforcing RBAC. + +## How to Use +```go + package main + import( + "k8s.io/client-go/dynamic" + "github.com/rancher/lasso/pkg/cache/sql/informer" + "github.com/rancher/lasso/pkg/cache/sql/informer/factory" + ) + + func main() { + cacheFactory, err := factory.NewCacheFactory() + if err != nil { + panic(err) + } + // config should be some rest config created from kubeconfig + // there are other ways to create a config and any client that conforms to k8s.io/client-go/dynamic.ResourceInterface + // will work. + client, err := dynamic.NewForConfig(config) + if err != nil { + panic(err) + } + + fields := [][]string{{"metadata", "name"}, {"metadata", "namespace"}} + opts := &informer.ListOptions{} + // gvk should be of type k8s.io/apimachinery/pkg/runtime/schema.GroupVersionKind + c, err := cacheFactory.CacheFor(fields, client, gvk) + if err != nil { + panic(err) + } + + // continueToken will just be an offset that can be used in Resume on a subsequent request to continue + // to next page + list, continueToken, err := c.ListByOptions(apiOp.Context(), opts, partitions, namespace) + if err != nil { + panic(err) + } + } +``` + +## Technical Information + +### SQL Tables +There are three tables that are created for the ListOption informer: +* object table - this contains objects, including all their fields, as blobs. These blobs may be encrypted. +* fields table - this contains specific fields of value for objects. These are specified on informer create and are fields +that it is desired to filter or order on. +* indices table - the indices table stores indexes created and objects' values for each index. This backs the generic indexer +that contains the functionality needed to conform to cache.Indexer. + +### SQLite Driver +There are multiple SQLite drivers that this package could have used. One of the most, if not the most, popular SQLite golang +drivers is [mattn/go-sqlite3](https://github.com/mattn/go-sqlite3). This driver is not being used because it requires enabling +the cgo option when compiling and at the moment lasso's main consumer, rancher, does not compile with cgo. We did not want +the SQL informer to be the sole driver in switching to using cgo. Instead, modernc's driver which is in pure golang. Side-by-side +comparisons can be found indicating the cgo version is, as expected, more performant. If in the future it is deemed worthwhile +then the driver can be easily switched by replacing the empty import in `pkg/cache/sql/store` from `_ "modernc.org/sqlite"` to `_ "github.com/mattn/go-sqlite3"`. + +### Connection Pooling +While working with the `database/sql` package for go, it is important to understand how sql.Open() and other methods manage +connections. Open starts a connection pool; that is to say after calling open once, there may be anywhere from zero to many +connections attached to a sql.Connection. `database/sql` manages this connection pool under the hood. In most cases, an +application only need one sql.Connection, although sometimes application use two: one for writes, the other for reads. To +read more about the `sql` package's connection pooling read [Managing connections](https://go.dev/doc/database/manage-connections). + +The use of connection pooling and the fact that lasso potentially has many go routines accessing the same connection pool, +means we have to be careful with writes. Exclusively using sql transaction to write helps ensure safety. To read more about +sql transactions read SQLite's [Transaction docs](https://www.sqlite.org/lang_transaction.html). + +### Encryption Defaults +By default only specified types are encrypted. These types are hard-coded and defined by defaultEncryptedResourceTypes +in `pkg/cache/sql/informer/factory/informer_factory.go`. To enabled encryption for all types, set the ENV variable +`CATTLE_ENCRYPT_CACHE_ALL` to "true". + +The key size used is 256 bits. Data-encryption-keys are stored in the object table and are rotated every 150,000 writes. + +### Indexed Fields +Filtering and sorting only work on indexed fields. These fields are defined when using `CacheFor`. Objects will +have the following indexes by default: +* Fields in informer.defaultIndexedFields +* Fields passed to InformerFor() + +### ListOptions Behavior +Defaults: +* Sort.PrimaryField: `metadata.namespace` +* Sort.SecondaryField: `metadata.name` +* Sort.PrimaryOrder: `ASC` (ascending) +* Sort.SecondaryOrder: `ASC` (ascending) +* All filters have partial matching set to false by default + +There are some uncommon ways someone could use ListOptions where it would be difficult to predict what the result would be. +Below is a non-exhaustive list of some of these cases and what the behavior is: +* Setting Pagination.Page but not Pagination.PageSize will cause Page to be ignored +* Setting Sort.SecondaryField only will sort as though it was Sort.PrimaryField. Sort.SecondaryOrder will still be applied +and Sort.PrimaryOrder will be ignored + +### Writing Secure Queries +Values should be supplied to SQL queries using placeholders, read [Avoiding SQL Injection Risk](https://go.dev/doc/database/sql-injection). Any other portions +of a query that may be user supplied, such as columns, should be carefully validated against a fixed set of acceptable values. + +### Troubleshooting SQLite +A useful tool for troubleshooting the database files is the sqlite command line tool. Another useful tool is the goland +sqlite plugin. Both of these tools can be used with the database files. diff --git a/pkg/sqlcache/db/client.go b/pkg/sqlcache/db/client.go new file mode 100644 index 00000000..441e2a2a --- /dev/null +++ b/pkg/sqlcache/db/client.go @@ -0,0 +1,372 @@ +/* +Package db offers client struct and functions to interact with database connection. It provides encrypting, decrypting, +and a way to reset the database. +*/ +package db + +import ( + "bytes" + "context" + "database/sql" + "encoding/gob" + "fmt" + "io/fs" + "os" + "reflect" + "sync" + + "github.com/pkg/errors" + "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + _ "modernc.org/sqlite" +) + +const ( + // InformerObjectCacheDBPath is where SQLite's object database file will be stored relative to process running lasso + InformerObjectCacheDBPath = "informer_object_cache.db" + + informerObjectCachePerms fs.FileMode = 0o600 +) + +// Client is a database client that provides encrypting, decrypting, and database resetting. +type Client struct { + conn Connection + connLock sync.RWMutex + encryptor Encryptor + decryptor Decryptor +} + +// Connection represents a connection pool. +type Connection interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) + Exec(query string, args ...any) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + Close() error +} + +// Closable Closes an underlying connection and returns an error on failure. +type Closable interface { + Close() error +} + +// Rows represents sql rows. It exposes method to navigate the rows, read their outputs, and close them. +type Rows interface { + Next() bool + Err() error + Close() error + Scan(dest ...any) error +} + +// QueryError encapsulates an error while executing a query +type QueryError struct { + QueryString string + Err error +} + +// Error returns a string representation of this QueryError +func (e *QueryError) Error() string { + return "while executing query: " + e.QueryString + " got error: " + e.Err.Error() +} + +// Unwrap returns the underlying error +func (e *QueryError) Unwrap() error { + return e.Err +} + +// TXClient represents a sql transaction. The TXClient must manage rollbacks as rollback functionality is not exposed. +type TXClient interface { + StmtExec(stmt transaction.Stmt, args ...any) error + Exec(stmt string, args ...any) error + Commit() error + Stmt(stmt *sql.Stmt) transaction.Stmt + Cancel() error +} + +// Encryptor encrypts data with a key which is rotated to avoid wear-out. +type Encryptor interface { + // Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure error is returned instead. + Encrypt([]byte) ([]byte, []byte, uint32, error) +} + +// Decryptor decrypts data previously encrypted by Encryptor. +type Decryptor interface { + // Decrypt accepts a chunk of encrypted data, the nonce used to encrypt it and the ID of the used key (as it rotates). It returns the decrypted data or an error. + Decrypt([]byte, []byte, uint32) ([]byte, error) +} + +// NewClient returns a Client. If the given connection is nil then a default one will be created. +func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (*Client, error) { + client := &Client{ + encryptor: encryptor, + decryptor: decryptor, + } + if c != nil { + client.conn = c + return client, nil + } + err := client.NewConnection() + if err != nil { + return nil, err + } + + return client, nil +} + +// Prepare prepares the given string into a sql statement on the client's connection. +func (c *Client) Prepare(stmt string) *sql.Stmt { + c.connLock.RLock() + defer c.connLock.RUnlock() + prepared, err := c.conn.Prepare(stmt) + if err != nil { + panic(errors.Errorf("Error preparing statement: %s\n%v", stmt, err)) + } + return prepared +} + +// QueryForRows queries the given stmt with the given params and returns the resulting rows. The query wil be retried +// given a sqlite busy error. +func (c *Client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + + return stmt.QueryContext(ctx, params...) +} + +// CloseStmt will call close on the given Closable. It is intended to be used with a sql statement. This function is meant +// to replace stmt.Close which can cause panics when callers unit-test since there usually is no real underlying connection. +func (c *Client) CloseStmt(closable Closable) error { + return closable.Close() +} + +// ReadObjects Scans the given rows, performs any necessary decryption, converts the data to objects of the given type, +// and returns a slice of those objects. +func (c *Client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + + var result []any + for rows.Next() { + data, err := c.decryptScan(rows, shouldDecrypt) + if err != nil { + return nil, closeRowsOnError(rows, err) + } + singleResult, err := fromBytes(data, typ) + if err != nil { + return nil, closeRowsOnError(rows, err) + } + result = append(result, singleResult.Elem().Interface()) + } + err := rows.Err() + if err != nil { + return nil, closeRowsOnError(rows, err) + } + + err = rows.Close() + if err != nil { + return nil, err + } + + return result, nil +} + +// ReadStrings scans the given rows into strings, and then returns the strings as a slice. +func (c *Client) ReadStrings(rows Rows) ([]string, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + + var result []string + for rows.Next() { + var key string + err := rows.Scan(&key) + if err != nil { + return nil, closeRowsOnError(rows, err) + } + + result = append(result, key) + } + err := rows.Err() + if err != nil { + return nil, closeRowsOnError(rows, err) + } + + err = rows.Close() + if err != nil { + return nil, err + } + + return result, nil +} + +// ReadInt scans the first of the given rows into a single int (eg. for COUNT() queries) +func (c *Client) ReadInt(rows Rows) (int, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + + if !rows.Next() { + return 0, closeRowsOnError(rows, sql.ErrNoRows) + } + + var result int + err := rows.Scan(&result) + if err != nil { + return 0, closeRowsOnError(rows, err) + } + + err = rows.Err() + if err != nil { + return 0, closeRowsOnError(rows, err) + } + + err = rows.Close() + if err != nil { + return 0, err + } + + return result, nil +} + +// BeginTx attempts to begin a transaction. +// If forWriting is true, this method blocks until all other concurrent forWriting +// transactions have either committed or rolled back. +// If forWriting is false, it is assumed the returned transaction will exclusively +// be used for DQL (e.g. SELECT) queries. +// Not respecting the above rule might result in transactions failing with unexpected +// SQLITE_BUSY (5) errors (aka "Runtime error: database is locked"). +// See discussion in https://github.com/rancher/lasso/pull/98 for details +func (c *Client) BeginTx(ctx context.Context, forWriting bool) (TXClient, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + // note: this assumes _txlock=immediate in the connection string, see NewConnection + sqlTx, err := c.conn.BeginTx(ctx, &sql.TxOptions{ + ReadOnly: !forWriting, + }) + if err != nil { + return nil, err + } + return transaction.NewClient(sqlTx), nil +} + +func (c *Client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) { + var data, dataNonce sql.RawBytes + var kid uint32 + err := rows.Scan(&data, &dataNonce, &kid) + if err != nil { + return nil, err + } + if c.decryptor != nil && shouldDecrypt { + decryptedData, err := c.decryptor.Decrypt(data, dataNonce, kid) + if err != nil { + return nil, err + } + return decryptedData, nil + } + return data, nil +} + +// Upsert used to be called upsertEncrypted in store package before move +func (c *Client) Upsert(tx TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error { + objBytes := toBytes(obj) + var dataNonce []byte + var err error + var kid uint32 + if c.encryptor != nil && shouldEncrypt { + objBytes, dataNonce, kid, err = c.encryptor.Encrypt(objBytes) + if err != nil { + return err + } + } + + return tx.StmtExec(tx.Stmt(stmt), key, objBytes, dataNonce, kid) +} + +// toBytes encodes an object to a byte slice +func toBytes(obj any) []byte { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + err := enc.Encode(obj) + if err != nil { + panic(fmt.Errorf("error while gobbing object: %w", err)) + } + bb := buf.Bytes() + return bb +} + +// fromBytes decodes an object from a byte slice +func fromBytes(buf sql.RawBytes, typ reflect.Type) (reflect.Value, error) { + dec := gob.NewDecoder(bytes.NewReader(buf)) + singleResult := reflect.New(typ) + err := dec.DecodeValue(singleResult) + return singleResult, err +} + +// closeRowsOnError closes the sql.Rows object and wraps errors if needed +func closeRowsOnError(rows Rows, err error) error { + ce := rows.Close() + if ce != nil { + return fmt.Errorf("error in closing rows while handling %s: %w", err.Error(), ce) + } + + return err +} + +// NewConnection checks for currently existing connection, closes one if it exists, removes any relevant db files, and opens a new connection which subsequently +// creates new files. +func (c *Client) NewConnection() error { + c.connLock.Lock() + defer c.connLock.Unlock() + if c.conn != nil { + err := c.conn.Close() + if err != nil { + return err + } + } + err := os.RemoveAll(InformerObjectCacheDBPath) + if err != nil { + return err + } + + // Set the permissions in advance, because we can't control them if + // the file is created by a sql.Open call instead. + if err := touchFile(InformerObjectCacheDBPath, informerObjectCachePerms); err != nil { + return nil + } + + sqlDB, err := sql.Open("sqlite", "file:"+InformerObjectCacheDBPath+"?"+ + // open SQLite file in read-write mode, creating it if it does not exist + "mode=rwc&"+ + // use the WAL journal mode for consistency and efficiency + "_pragma=journal_mode=wal&"+ + // do not even attempt to attain durability. Database is thrown away at pod restart + "_pragma=synchronous=off&"+ + // do check foreign keys and honor ON DELETE CASCADE + "_pragma=foreign_keys=on&"+ + // if two transactions want to write at the same time, allow 2 minutes for the first to complete + // before baling out + "_pragma=busy_timeout=120000&"+ + // default to IMMEDIATE mode for transactions. Setting this parameter is the only current way + // to be able to switch between DEFERRED and IMMEDIATE modes in modernc.org/sqlite's implementation + // of BeginTx + "_txlock=immediate") + if err != nil { + return err + } + + c.conn = sqlDB + return nil +} + +// This acts like "touch" for both existing files and non-existing files. +// permissions. +// +// It's created with the correct perms, and if the file already exists, it will +// be chmodded to the correct perms. +func touchFile(filename string, perms fs.FileMode) error { + f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, perms) + if err != nil { + return err + } + if err := f.Close(); err != nil { + return err + } + + return os.Chmod(filename, perms) +} diff --git a/pkg/sqlcache/db/client_test.go b/pkg/sqlcache/db/client_test.go new file mode 100644 index 00000000..8adf74c3 --- /dev/null +++ b/pkg/sqlcache/db/client_test.go @@ -0,0 +1,667 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "io/fs" + "math" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +// Mocks for this test are generated with the following command. +//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db Rows,Connection,Encryptor,Decryptor,TXClient +//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt,SQLTx + +type testStoreObject struct { + Id string + Val string +} + +func TestNewClient(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Query rows with no params, no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + expectedClient := &Client{ + conn: c, + encryptor: e, + decryptor: d, + } + client, err := NewClient(c, e, d) + assert.Nil(t, err) + assert.Equal(t, expectedClient, client) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestQueryForRows(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Query rows with no params, no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + client := SetupClient(t, c, nil, nil) + s := NewMockStmt(gomock.NewController(t)) + ctx := context.TODO() + r := &sql.Rows{} + s.EXPECT().QueryContext(ctx).Return(r, nil) + rows, err := client.QueryForRows(ctx, s) + assert.Nil(t, err) + assert.Equal(t, r, rows) + }, + }) + tests = append(tests, testCase{description: "Query rows with params, QueryContext() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + client := SetupClient(t, c, nil, nil) + s := NewMockStmt(gomock.NewController(t)) + ctx := context.TODO() + s.EXPECT().QueryContext(ctx).Return(nil, fmt.Errorf("error")) + _, err := client.QueryForRows(ctx, s) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestQueryObjects(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + testObject := testStoreObject{Id: "something", Val: "a"} + var keyId uint32 = math.MaxUint32 + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Query objects, with one row, and no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + *a[0].(*sql.RawBytes) = toBytes(testObject) + *a[1].(*sql.RawBytes) = toBytes(testObject) + *a[2].(*uint32) = keyId + }) + d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(toBytes(testObject), nil) + r.EXPECT().Err().Return(nil) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + items, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.Nil(t, err) + assert.Equal(t, 1, len(items)) + }, + }) + tests = append(tests, testCase{description: "Query objects, with one row, and a decrypt error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + *a[0].(*sql.RawBytes) = toBytes(testObject) + *a[1].(*sql.RawBytes) = toBytes( + testObject) + *a[2].(*uint32) = keyId + }) + d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(nil, fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Query objects, with one row, and a Scan() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Query objects, with one row, and a Close() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + *a[0].(*sql.RawBytes) = toBytes(testObject) + *a[1].(*sql.RawBytes) = toBytes(testObject) + *a[2].(*uint32) = keyId + }) + d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(toBytes(testObject), nil) + r.EXPECT().Err().Return(nil) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(fmt.Errorf("error")) + client := SetupClient(t, c, e, d) + _, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Query objects, with no rows, and no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(false) + r.EXPECT().Err().Return(nil) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + items, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.Nil(t, err) + assert.Equal(t, 0, len(items)) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestQueryStrings(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + testObject := testStoreObject{Id: "something", Val: "a"} + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "ReadStrings(), with one row, and no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + for _, v := range a { + vk := v.(*string) + *vk = string(toBytes(testObject.Id)) + } + }) + r.EXPECT().Err().Return(nil) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + items, err := client.ReadStrings(r) + assert.Nil(t, err) + assert.Equal(t, 1, len(items)) + }, + }) + tests = append(tests, testCase{description: "Query objects, with one row, and Scan error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadStrings(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "ReadStrings(), with one row, and Err() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + for _, v := range a { + vk := v.(*string) + *vk = string(toBytes(testObject.Id)) + } + }) + r.EXPECT().Next().Return(false) + r.EXPECT().Err().Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadStrings(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "ReadStrings(), with one row, and Close() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + for _, v := range a { + vk := v.(*string) + *vk = string(toBytes(testObject.Id)) + } + }) + r.EXPECT().Err().Return(nil) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(fmt.Errorf("error")) + client := SetupClient(t, c, e, d) + _, err := client.ReadStrings(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "ReadStrings(), with no rows, and no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(false) + r.EXPECT().Err().Return(nil) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + items, err := client.ReadStrings(r) + assert.Nil(t, err) + assert.Equal(t, 0, len(items)) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestReadInt(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + testResult := 42 + tests = append(tests, testCase{description: "One row, no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + p := a[0].(*int) + *p = testResult + }) + r.EXPECT().Err().Return(nil) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + result, err := client.ReadInt(r) + assert.Nil(t, err) + assert.Equal(t, 42, result) + }, + }) + tests = append(tests, testCase{description: "One row, Scan error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadInt(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "One row, Err() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + a[0] = testResult + }) + r.EXPECT().Err().Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadInt(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "One row, Close() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + a[0] = testResult + }) + r.EXPECT().Err().Return(nil) + r.EXPECT().Close().Return(fmt.Errorf("error")) + client := SetupClient(t, c, e, d) + _, err := client.ReadInt(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "No rows error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadInt(r) + assert.ErrorIs(t, err, sql.ErrNoRows) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestBegin(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "BeginTx(), with no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + sqlTx := &sql.Tx{} + c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(sqlTx, nil) + client := SetupClient(t, c, e, d) + txC, err := client.BeginTx(context.Background(), false) + assert.Nil(t, err) + assert.NotNil(t, txC) + }, + }) + tests = append(tests, testCase{description: "BeginTx(), with forWriting option set", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + sqlTx := &sql.Tx{} + c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: false}).Return(sqlTx, nil) + client := SetupClient(t, c, e, d) + txC, err := client.BeginTx(context.Background(), true) + assert.Nil(t, err) + assert.NotNil(t, txC) + }, + }) + tests = append(tests, testCase{description: "BeginTx(), with connection Begin() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(nil, fmt.Errorf("error")) + client := SetupClient(t, c, e, d) + _, err := client.BeginTx(context.Background(), false) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestUpsert(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + testObject := testStoreObject{Id: "something", Val: "a"} + var keyID uint32 = 5 + + // Tests with shouldEncryptSet to true + tests = append(tests, testCase{description: "Upsert() with no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + txC := NewMockTXClient(gomock.NewController(t)) + sqlStmt := &sql.Stmt{} + stmt := NewMockStmt(gomock.NewController(t)) + testObjBytes := toBytes(testObject) + testByteValue := []byte("something") + e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil) + txC.EXPECT().Stmt(sqlStmt).Return(stmt) + txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(nil) + err := client.Upsert(txC, sqlStmt, "somekey", testObject, true) + assert.Nil(t, err) + }, + }) + tests = append(tests, testCase{description: "Upsert() with Encrypt() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + txC := NewMockTXClient(gomock.NewController(t)) + sqlStmt := &sql.Stmt{} + testObjBytes := toBytes(testObject) + e.EXPECT().Encrypt(testObjBytes).Return(nil, nil, uint32(0), fmt.Errorf("error")) + err := client.Upsert(txC, sqlStmt, "somekey", testObject, true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Upsert() with StmtExec() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + txC := NewMockTXClient(gomock.NewController(t)) + sqlStmt := &sql.Stmt{} + stmt := NewMockStmt(gomock.NewController(t)) + testObjBytes := toBytes(testObject) + testByteValue := []byte("something") + e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil) + txC.EXPECT().Stmt(sqlStmt).Return(stmt) + txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(fmt.Errorf("error")) + err := client.Upsert(txC, sqlStmt, "somekey", testObject, true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Upsert() with no errors and shouldEncrypt false", test: func(t *testing.T) { + c := SetupMockConnection(t) + d := SetupMockDecryptor(t) + e := SetupMockEncryptor(t) + + client := SetupClient(t, c, e, d) + txC := NewMockTXClient(gomock.NewController(t)) + sqlStmt := &sql.Stmt{} + stmt := NewMockStmt(gomock.NewController(t)) + var testByteValue []byte + testObjBytes := toBytes(testObject) + txC.EXPECT().Stmt(sqlStmt).Return(stmt) + txC.EXPECT().StmtExec(stmt, "somekey", testObjBytes, testByteValue, uint32(0)).Return(nil) + err := client.Upsert(txC, sqlStmt, "somekey", testObject, false) + assert.Nil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestPrepare(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + tests = append(tests, testCase{description: "Prepare() with no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + sqlStmt := &sql.Stmt{} + c.EXPECT().Prepare("something").Return(sqlStmt, nil) + + stmt := client.Prepare("something") + assert.Equal(t, sqlStmt, stmt) + }, + }) + tests = append(tests, testCase{description: "Prepare() with Connection Prepare() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + c.EXPECT().Prepare("something").Return(nil, fmt.Errorf("error")) + + assert.Panics(t, func() { client.Prepare("something") }) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestNewConnection(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + tests = append(tests, testCase{description: "NewConnection replaces file", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + c.EXPECT().Close().Return(nil) + + err := client.NewConnection() + assert.Nil(t, err) + + // Create a transaction to ensure that the file is written to disk. + txC, err := client.BeginTx(context.Background(), false) + assert.NoError(t, err) + assert.NoError(t, txC.Commit()) + + assert.FileExists(t, InformerObjectCacheDBPath) + assertFileHasPermissions(t, InformerObjectCacheDBPath, 0600) + + err = os.Remove(InformerObjectCacheDBPath) + if err != nil { + assert.Fail(t, "could not remove object cache path after test") + } + }, + }) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestCommit(t *testing.T) { + +} + +func TestRollback(t *testing.T) { + +} + +func SetupMockConnection(t *testing.T) *MockConnection { + mockC := NewMockConnection(gomock.NewController(t)) + return mockC +} + +func SetupMockEncryptor(t *testing.T) *MockEncryptor { + mockE := NewMockEncryptor(gomock.NewController(t)) + return mockE +} + +func SetupMockDecryptor(t *testing.T) *MockDecryptor { + MockD := NewMockDecryptor(gomock.NewController(t)) + return MockD +} + +func SetupMockRows(t *testing.T) *MockRows { + MockR := NewMockRows(gomock.NewController(t)) + return MockR +} + +func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) *Client { + c, _ := NewClient(connection, encryptor, decryptor) + return c +} + +func TestTouchFile(t *testing.T) { + t.Run("File doesn't exist before", func(t *testing.T) { + filename := filepath.Join(t.TempDir(), "test1.txt") + assert.NoError(t, touchFile(filename, 0600)) + assertFileHasPermissions(t, filename, 0600) + }) + + t.Run("File exists with different permissions", func(t *testing.T) { + filename := filepath.Join(t.TempDir(), "test2.txt") + assert.NoError(t, os.WriteFile(filename, []byte("test"), 0644)) + assert.NoError(t, touchFile(filename, 0600)) + assertFileHasPermissions(t, filename, 0600) + }) +} + +func assertFileHasPermissions(t *testing.T, fname string, wantPerms fs.FileMode) bool { + t.Helper() + info, err := os.Lstat(fname) + if err != nil { + if os.IsNotExist(err) { + return assert.Fail(t, fmt.Sprintf("unable to find file %q", fname)) + } + return assert.Fail(t, fmt.Sprintf("error when running os.Lstat(%q): %s", fname, err)) + } + + // Stringifying the perms makes it easier to read than a Hex comparison. + assert.Equal(t, wantPerms.String(), info.Mode().Perm().String()) + + return true +} diff --git a/pkg/sqlcache/db/db_mocks_test.go b/pkg/sqlcache/db/db_mocks_test.go new file mode 100644 index 00000000..55580ee2 --- /dev/null +++ b/pkg/sqlcache/db/db_mocks_test.go @@ -0,0 +1,370 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/db (interfaces: Rows,Connection,Encryptor,Decryptor,TXClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db Rows,Connection,Encryptor,Decryptor,TXClient +// + +// Package db is a generated GoMock package. +package db + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockRows is a mock of Rows interface. +type MockRows struct { + ctrl *gomock.Controller + recorder *MockRowsMockRecorder +} + +// MockRowsMockRecorder is the mock recorder for MockRows. +type MockRowsMockRecorder struct { + mock *MockRows +} + +// NewMockRows creates a new mock instance. +func NewMockRows(ctrl *gomock.Controller) *MockRows { + mock := &MockRows{ctrl: ctrl} + mock.recorder = &MockRowsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRows) EXPECT() *MockRowsMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRows) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRowsMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close)) +} + +// Err mocks base method. +func (m *MockRows) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockRowsMockRecorder) Err() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err)) +} + +// Next mocks base method. +func (m *MockRows) Next() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockRowsMockRecorder) Next() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next)) +} + +// Scan mocks base method. +func (m *MockRows) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) +} + +// MockConnection is a mock of Connection interface. +type MockConnection struct { + ctrl *gomock.Controller + recorder *MockConnectionMockRecorder +} + +// MockConnectionMockRecorder is the mock recorder for MockConnection. +type MockConnectionMockRecorder struct { + mock *MockConnection +} + +// NewMockConnection creates a new mock instance. +func NewMockConnection(ctrl *gomock.Controller) *MockConnection { + mock := &MockConnection{ctrl: ctrl} + mock.recorder = &MockConnectionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnection) EXPECT() *MockConnectionMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockConnection) BeginTx(arg0 context.Context, arg1 *sql.TxOptions) (*sql.Tx, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(*sql.Tx) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockConnectionMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockConnection)(nil).BeginTx), arg0, arg1) +} + +// Close mocks base method. +func (m *MockConnection) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockConnectionMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnection)(nil).Close)) +} + +// Exec mocks base method. +func (m *MockConnection) Exec(arg0 string, arg1 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockConnectionMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockConnection)(nil).Exec), varargs...) +} + +// Prepare mocks base method. +func (m *MockConnection) Prepare(arg0 string) (*sql.Stmt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockConnectionMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockConnection)(nil).Prepare), arg0) +} + +// MockEncryptor is a mock of Encryptor interface. +type MockEncryptor struct { + ctrl *gomock.Controller + recorder *MockEncryptorMockRecorder +} + +// MockEncryptorMockRecorder is the mock recorder for MockEncryptor. +type MockEncryptorMockRecorder struct { + mock *MockEncryptor +} + +// NewMockEncryptor creates a new mock instance. +func NewMockEncryptor(ctrl *gomock.Controller) *MockEncryptor { + mock := &MockEncryptor{ctrl: ctrl} + mock.recorder = &MockEncryptorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEncryptor) EXPECT() *MockEncryptorMockRecorder { + return m.recorder +} + +// Encrypt mocks base method. +func (m *MockEncryptor) Encrypt(arg0 []byte) ([]byte, []byte, uint32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encrypt", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].([]byte) + ret2, _ := ret[2].(uint32) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// Encrypt indicates an expected call of Encrypt. +func (mr *MockEncryptorMockRecorder) Encrypt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockEncryptor)(nil).Encrypt), arg0) +} + +// MockDecryptor is a mock of Decryptor interface. +type MockDecryptor struct { + ctrl *gomock.Controller + recorder *MockDecryptorMockRecorder +} + +// MockDecryptorMockRecorder is the mock recorder for MockDecryptor. +type MockDecryptorMockRecorder struct { + mock *MockDecryptor +} + +// NewMockDecryptor creates a new mock instance. +func NewMockDecryptor(ctrl *gomock.Controller) *MockDecryptor { + mock := &MockDecryptor{ctrl: ctrl} + mock.recorder = &MockDecryptorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDecryptor) EXPECT() *MockDecryptorMockRecorder { + return m.recorder +} + +// Decrypt mocks base method. +func (m *MockDecryptor) Decrypt(arg0, arg1 []byte, arg2 uint32) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decrypt", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Decrypt indicates an expected call of Decrypt. +func (mr *MockDecryptorMockRecorder) Decrypt(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockDecryptor)(nil).Decrypt), arg0, arg1, arg2) +} + +// MockTXClient is a mock of TXClient interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockTXClient) Cancel() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) +} + +// Commit mocks base method. +func (m *MockTXClient) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} + +// StmtExec mocks base method. +func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StmtExec indicates an expected call of StmtExec. +func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) +} diff --git a/pkg/sqlcache/db/transaction/transaction.go b/pkg/sqlcache/db/transaction/transaction.go new file mode 100644 index 00000000..83609143 --- /dev/null +++ b/pkg/sqlcache/db/transaction/transaction.go @@ -0,0 +1,92 @@ +/* +Package transaction provides a client for a live transaction, and interfaces for some relevant sql types. The transaction client automatically performs rollbacks on failures. +The use of this package simplifies testing for callers by making the underlying transaction mock-able. +*/ +package transaction + +import ( + "context" + "database/sql" + "github.com/sirupsen/logrus" + + "github.com/pkg/errors" +) + +// Client provides a way to interact with the underlying sql transaction. +type Client struct { + sqlTx SQLTx +} + +// SQLTx represents a sql transaction +type SQLTx interface { + Exec(query string, args ...any) (sql.Result, error) + Stmt(stmt *sql.Stmt) *sql.Stmt + Commit() error + Rollback() error +} + +// Stmt represents a sql stmt. It is used as a return type to offer some testability over returning sql's Stmt type +// because we are able to mock its outputs and do not need an actual connection. +type Stmt interface { + Exec(args ...any) (sql.Result, error) + Query(args ...any) (*sql.Rows, error) + QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) +} + +// NewClient returns a Client with the given transaction assigned. +func NewClient(tx SQLTx) *Client { + return &Client{sqlTx: tx} +} + +// Commit commits the transaction and then unlocks the database. +func (c *Client) Commit() error { + return c.sqlTx.Commit() +} + +// Exec uses the sqlTX Exec() with the given stmt and args. The transaction will be automatically rolled back if Exec() +// returns an error. +func (c *Client) Exec(stmt string, args ...any) error { + _, err := c.sqlTx.Exec(stmt, args...) + if err != nil { + return c.rollback(c.sqlTx, err) + } + return nil +} + +// Stmt adds the given sql.Stmt to the client's transaction and then returns a Stmt. An interface is being returned +// here to aid in testing callers by providing a way to configure the statement's behavior. +func (c *Client) Stmt(stmt *sql.Stmt) Stmt { + s := c.sqlTx.Stmt(stmt) + return s +} + +// StmtExec Execs the given statement with the given args. It assumes the stmt has been added to the transaction. The +// transaction is rolled back if Stmt.Exec() returns an error. +func (c *Client) StmtExec(stmt Stmt, args ...any) error { + _, err := stmt.Exec(args...) + if err != nil { + logrus.Debugf("StmtExec failed: query %s, args: %s, err: %s", stmt, args, err) + return c.rollback(c.sqlTx, err) + } + return nil +} + +// rollback handles rollbacks and wraps errors if needed +func (c *Client) rollback(tx SQLTx, err error) error { + rerr := tx.Rollback() + if rerr != nil { + return errors.Wrapf(err, "Encountered error, then encountered another error while rolling back: %v", rerr) + } + return errors.Wrapf(err, "Encountered error, successfully rolled back") +} + +// Cancel rollbacks the transaction without wrapping an error. This only needs to be called if Client has not returned +// an error yet or has not committed. Otherwise, transaction has already rolled back, or in the case of Commit() it is too +// late. +func (c *Client) Cancel() error { + rerr := c.sqlTx.Rollback() + if rerr != sql.ErrTxDone { + return rerr + } + return nil +} diff --git a/pkg/sqlcache/db/transaction/transaction_mocks_test.go b/pkg/sqlcache/db/transaction/transaction_mocks_test.go new file mode 100644 index 00000000..3bd82287 --- /dev/null +++ b/pkg/sqlcache/db/transaction/transaction_mocks_test.go @@ -0,0 +1,184 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/db/transaction (interfaces: Stmt,SQLTx) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt,SQLTx +// + +// Package transaction is a generated GoMock package. +package transaction + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStmt is a mock of Stmt interface. +type MockStmt struct { + ctrl *gomock.Controller + recorder *MockStmtMockRecorder +} + +// MockStmtMockRecorder is the mock recorder for MockStmt. +type MockStmtMockRecorder struct { + mock *MockStmt +} + +// NewMockStmt creates a new mock instance. +func NewMockStmt(ctrl *gomock.Controller) *MockStmt { + mock := &MockStmt{ctrl: ctrl} + mock.recorder = &MockStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStmt) EXPECT() *MockStmtMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) +} + +// Query mocks base method. +func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) +} + +// QueryContext mocks base method. +func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) +} + +// MockSQLTx is a mock of SQLTx interface. +type MockSQLTx struct { + ctrl *gomock.Controller + recorder *MockSQLTxMockRecorder +} + +// MockSQLTxMockRecorder is the mock recorder for MockSQLTx. +type MockSQLTxMockRecorder struct { + mock *MockSQLTx +} + +// NewMockSQLTx creates a new mock instance. +func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx { + mock := &MockSQLTx{ctrl: ctrl} + mock.recorder = &MockSQLTxMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder { + return m.recorder +} + +// Commit mocks base method. +func (m *MockSQLTx) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...) +} + +// Rollback mocks base method. +func (m *MockSQLTx) Rollback() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback") + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback)) +} + +// Stmt mocks base method. +func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0) +} diff --git a/pkg/sqlcache/db/transaction/transaction_test.go b/pkg/sqlcache/db/transaction/transaction_test.go new file mode 100644 index 00000000..aada33a7 --- /dev/null +++ b/pkg/sqlcache/db/transaction/transaction_test.go @@ -0,0 +1,182 @@ +package transaction + +import ( + "database/sql" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +//go:generate mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt,SQLTx + +func TestNewClient(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + c := NewClient(tx) + assert.Equal(t, tx, c.sqlTx) +} + +func TestCommit(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "Commit() with no errors returned from sql TX should return no error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + tx.EXPECT().Commit().Return(nil) + c := &Client{ + sqlTx: tx, + } + err := c.Commit() + assert.Nil(t, err) + }}) + tests = append(tests, testCase{description: "Commit() with error from sql TX commit() should return error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + tx.EXPECT().Commit().Return(fmt.Errorf("error")) + c := &Client{ + sqlTx: tx, + } + err := c.Commit() + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestExec(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmtStr := "some statement %s" + arg := 5 + // should be passed same statement and arg that was passed to parent function + tx.EXPECT().Exec(stmtStr, arg).Return(nil, nil) + c := &Client{ + sqlTx: tx, + } + err := c.Exec(stmtStr, arg) + assert.Nil(t, err) + }}) + tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmtStr := "some statement %s" + arg := 5 + // should be passed same statement and arg that was passed to parent function + tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error")) + tx.EXPECT().Rollback().Return(nil) + c := &Client{ + sqlTx: tx, + } + err := c.Exec(stmtStr, arg) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmtStr := "some statement %s" + arg := 5 + // should be passed same statement and arg that was passed to parent function + tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error")) + tx.EXPECT().Rollback().Return(fmt.Errorf("error")) + c := &Client{ + sqlTx: tx, + } + err := c.Exec(stmtStr, arg) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestStmt(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmt := &sql.Stmt{} + var returnedTXStmt *sql.Stmt + // should be passed same statement and arg that was passed to parent function + tx.EXPECT().Stmt(stmt).Return(returnedTXStmt) + c := &Client{ + sqlTx: tx, + } + returnedStmt := c.Stmt(stmt) + // whatever tx returned should be returned here. Nil was used because none of sql.Stmt's fields are exported so its simpler to test nil as it + // won't be equal to an empty struct + assert.Equal(t, returnedTXStmt, returnedStmt) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestStmtExec(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "StmtExec with no errors returned from Stmt should return no error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmt := NewMockStmt(gomock.NewController(t)) + arg := "something" + // should be passed same arg that was passed to parent function + stmt.EXPECT().Exec(arg).Return(nil, nil) + c := &Client{ + sqlTx: tx, + } + err := c.StmtExec(stmt, arg) + assert.Nil(t, err) + }}) + tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and no Tx Rollback() error should return error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmt := NewMockStmt(gomock.NewController(t)) + arg := "something" + // should be passed same arg that was passed to parent function + stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error")) + tx.EXPECT().Rollback().Return(nil) + c := &Client{ + sqlTx: tx, + } + err := c.StmtExec(stmt, arg) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and Tx Rollback() error should return error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmt := NewMockStmt(gomock.NewController(t)) + arg := "something" + // should be passed same arg that was passed to parent function + stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error")) + tx.EXPECT().Rollback().Return(fmt.Errorf("error2")) + c := &Client{ + sqlTx: tx, + } + err := c.StmtExec(stmt, arg) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} diff --git a/pkg/sqlcache/db/transaction_mocks_test.go b/pkg/sqlcache/db/transaction_mocks_test.go new file mode 100644 index 00000000..1cc9c874 --- /dev/null +++ b/pkg/sqlcache/db/transaction_mocks_test.go @@ -0,0 +1,184 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/db/transaction (interfaces: Stmt,SQLTx) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt,SQLTx +// + +// Package db is a generated GoMock package. +package db + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStmt is a mock of Stmt interface. +type MockStmt struct { + ctrl *gomock.Controller + recorder *MockStmtMockRecorder +} + +// MockStmtMockRecorder is the mock recorder for MockStmt. +type MockStmtMockRecorder struct { + mock *MockStmt +} + +// NewMockStmt creates a new mock instance. +func NewMockStmt(ctrl *gomock.Controller) *MockStmt { + mock := &MockStmt{ctrl: ctrl} + mock.recorder = &MockStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStmt) EXPECT() *MockStmtMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) +} + +// Query mocks base method. +func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) +} + +// QueryContext mocks base method. +func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) +} + +// MockSQLTx is a mock of SQLTx interface. +type MockSQLTx struct { + ctrl *gomock.Controller + recorder *MockSQLTxMockRecorder +} + +// MockSQLTxMockRecorder is the mock recorder for MockSQLTx. +type MockSQLTxMockRecorder struct { + mock *MockSQLTx +} + +// NewMockSQLTx creates a new mock instance. +func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx { + mock := &MockSQLTx{ctrl: ctrl} + mock.recorder = &MockSQLTxMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder { + return m.recorder +} + +// Commit mocks base method. +func (m *MockSQLTx) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...) +} + +// Rollback mocks base method. +func (m *MockSQLTx) Rollback() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback") + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback)) +} + +// Stmt mocks base method. +func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0) +} diff --git a/pkg/sqlcache/db/utility.go b/pkg/sqlcache/db/utility.go new file mode 100644 index 00000000..a8f84d29 --- /dev/null +++ b/pkg/sqlcache/db/utility.go @@ -0,0 +1,8 @@ +package db + +import "strings" + +// Sanitize returns a string that can be used in SQL as a name +func Sanitize(s string) string { + return strings.ReplaceAll(s, "\"", "") +} diff --git a/pkg/sqlcache/encryption/encrypt.go b/pkg/sqlcache/encryption/encrypt.go new file mode 100644 index 00000000..a7783ac9 --- /dev/null +++ b/pkg/sqlcache/encryption/encrypt.go @@ -0,0 +1,168 @@ +/* +Package encryption provides encryption and decryption functions, while +abstracting away key management concerns. +Uses AES-GCM encryption, with key rotation, keeping keys in memory. +*/ +package encryption + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "sync" + + "github.com/pkg/errors" +) + +var ( + ErrKeyNotFound = errors.New("data key not found") + // maxWriteCount holds the maximum amount of times the active key can be + // used, prior to it being rotated. 2^32 is the currently recommended key + // wear-out params by NIST for AES-GCM using random nonces. + maxWriteCount int64 = 1 << 32 +) + +const ( + keySize = 32 // 32 for AES-256 +) + +// Manager uses AES-GCM encryption and keeps in memory the data encryption +// keys. The active encryption key is automatically rotated once it has been +// used over a certain amount of times - defined by maxWriteCount. +type Manager struct { + dataKeys [][]byte + activeKeyCounter int64 + + // lock works as the mutual exclusion lock for dataKeys. + lock sync.RWMutex + // counterLock works as the mutual exclusion lock for activeKeyCounter. + counterLock sync.Mutex +} + +// NewManager returns Manager, which satisfies db.Encryptor and db.Decryptor +func NewManager() (*Manager, error) { + m := &Manager{ + dataKeys: [][]byte{}, + } + m.newDataEncryptionKey() + + return m, nil +} + +// Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure error is returned instead. +func (m *Manager) Encrypt(data []byte) ([]byte, []byte, uint32, error) { + dek, keyID, err := m.fetchActiveDataKey() + if err != nil { + return nil, nil, 0, err + } + aead, err := createGCMCypher(dek) + if err != nil { + return nil, nil, 0, err + } + edata, nonce, err := encrypt(aead, data) + if err != nil { + return nil, nil, 0, err + } + return edata, nonce, keyID, nil +} + +// Decrypt accepts a chunk of encrypted data, the nonce used to encrypt it and the ID of the used key (as it rotates). It returns the decrypted data or an error. +func (m *Manager) Decrypt(edata, nonce []byte, keyID uint32) ([]byte, error) { + dek, err := m.key(keyID) + if err != nil { + return nil, err + } + + aead, err := createGCMCypher(dek) + if err != nil { + return nil, errors.Wrap(err, "failed to create GCMCypher from DEK") + } + data, err := aead.Open(nil, nonce, edata, nil) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("failed to decrypt data using keyid %d", keyID)) + } + return data, nil +} + +func encrypt(aead cipher.AEAD, data []byte) ([]byte, []byte, error) { + if aead == nil { + return nil, nil, fmt.Errorf("aead is nil, cannot encrypt data") + } + nonce := make([]byte, aead.NonceSize()) + _, err := rand.Read(nonce) + if err != nil { + return nil, nil, err + } + sealed := aead.Seal(nil, nonce, data, nil) + return sealed, nonce, nil +} + +func createGCMCypher(key []byte) (cipher.AEAD, error) { + b, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aead, err := cipher.NewGCM(b) + if err != nil { + return nil, err + } + return aead, nil +} + +// fetchActiveDataKey returns the current data key and its key ID. +// Each call results in activeKeyCounter being incremented by 1. When the +// the activeKeyCounter exceeds maxWriteCount, the active data key is +// rotated - before being returned. +func (m *Manager) fetchActiveDataKey() ([]byte, uint32, error) { + m.counterLock.Lock() + defer m.counterLock.Unlock() + + m.activeKeyCounter++ + if m.activeKeyCounter >= maxWriteCount { + return m.newDataEncryptionKey() + } + + return m.activeKey() +} + +func (m *Manager) newDataEncryptionKey() ([]byte, uint32, error) { + dek := make([]byte, keySize) + _, err := rand.Read(dek) + if err != nil { + return nil, 0, err + } + + m.lock.Lock() + defer m.lock.Unlock() + + m.activeKeyCounter = 1 + + m.dataKeys = append(m.dataKeys, dek) + keyID := uint32(len(m.dataKeys) - 1) + + return dek, keyID, nil +} + +func (m *Manager) activeKey() ([]byte, uint32, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + nk := len(m.dataKeys) + if nk == 0 { + return nil, 0, ErrKeyNotFound + } + keyID := uint32(nk - 1) + + return m.dataKeys[keyID], keyID, nil +} + +func (m *Manager) key(keyID uint32) ([]byte, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + if len(m.dataKeys) <= int(keyID) { + return nil, fmt.Errorf("%w: %v", ErrKeyNotFound, keyID) + } + return m.dataKeys[keyID], nil +} diff --git a/pkg/sqlcache/encryption/encrypt_test.go b/pkg/sqlcache/encryption/encrypt_test.go new file mode 100644 index 00000000..46f3300a --- /dev/null +++ b/pkg/sqlcache/encryption/encrypt_test.go @@ -0,0 +1,327 @@ +package encryption + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewManager(t *testing.T) { + m, err := NewManager() + if err != nil { + t.FailNow() + } + assert.NotNil(t, m) +} + +func TestEncrypt(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + var tests []testCase + + tests = append(tests, testCase{description: "test encrypt with arbitrary initial key", test: func(t *testing.T) { + testDEK := []byte{83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181, 83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181} + + m, err := NewManager() + require.Nil(t, err) + + m.dataKeys[0] = testDEK + + testData := []byte("something") + cipherText, nonce, keyID, err := m.Encrypt(testData) + require.Nil(t, err) + + dek := m.dataKeys[keyID] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + decryptedData, err := aead.Open(nil, nonce, cipherText, nil) + + require.Nil(t, err) + assert.Equal(t, testData, decryptedData) + }}) + tests = append(tests, testCase{description: "test encrypt without arbitrary initial key", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + cipherText, nonce, keyID, err := m.Encrypt(testData) + require.Nil(t, err) + + dek := m.dataKeys[keyID] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + decryptedData, err := aead.Open(nil, nonce, cipherText, nil) + + require.Nil(t, err) + assert.Equal(t, testData, decryptedData) + }}) + tests = append(tests, testCase{description: "test encrypt: same data yield different cipher/nonce pair", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + cipher1, nonce1, keyID1, err := m.Encrypt(testData) + require.Nil(t, err) + assert.Len(t, cipher1, 25) + assert.Len(t, nonce1, 12) + assert.NotEmpty(t, cipher1) + assert.NotEmpty(t, nonce1) + + cipher2, nonce2, keyID2, err := m.Encrypt(testData) + require.Nil(t, err) + + assert.Equal(t, keyID1, keyID2) + assert.NotEqual(t, cipher1, cipher2, "each encrypt op must return a unique cipher") + assert.NotEqual(t, nonce1, nonce2, "each encrypt op must return a unique nonce") + }}) + tests = append(tests, testCase{description: "test encrypt with key rotation", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + cipher1, nonce1, keyID1, err := m.Encrypt(testData) + require.Nil(t, err) + assert.Len(t, cipher1, 25) + assert.Len(t, nonce1, 12) + assert.NotEmpty(t, cipher1) + assert.NotEmpty(t, nonce1) + + m.activeKeyCounter += maxWriteCount + + cipher2, nonce2, keyID2, err := m.Encrypt(testData) + require.Nil(t, err) + + assert.Equal(t, int64(1), m.activeKeyCounter) + assert.NotEqual(t, keyID1, keyID2) + assert.NotEqual(t, cipher1, cipher2, "each encrypt op must return a unique cipher") + assert.NotEqual(t, nonce1, nonce2, "each encrypt op must return a unique nonce") + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestDecrypt(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + var tests []testCase + + tests = append(tests, testCase{description: "test decrypt with arbitrary key", test: func(t *testing.T) { + testDEK := []byte{83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181, 83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181} + + m, err := NewManager() + require.Nil(t, err) + + m.dataKeys[0] = testDEK + + testData := []byte("something") + + // encrypt data out of band. + b, err := aes.NewCipher(testDEK) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + + nonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + cipherText := aead.Seal(nil, nonce, testData, nil) + + // use manager to decrypt the data. + decryptedData, err := m.Decrypt(cipherText, nonce, 0) + require.Nil(t, err) + + assert.Equal(t, testData, decryptedData) + }, + }) + tests = append(tests, testCase{description: "test decrypt without arbitrary key", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + + // encrypt data out of band. + dek := m.dataKeys[0] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + + nonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + cipherText := aead.Seal(nil, nonce, testData, nil) + + // use manager to decrypt the data. + decryptedData, err := m.Decrypt(cipherText, nonce, 0) + require.Nil(t, err) + + assert.Equal(t, testData, decryptedData) + }, + }) + tests = append(tests, testCase{description: "test decrypt with wrong data nonce should return error", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + + // encrypt data out of band. + dek := m.dataKeys[0] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + + nonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + cipherText := aead.Seal(nil, nonce, testData, nil) + + // generate random nonce. + randomNonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + // decrypted encrypted data using encrypted dek + _, err = m.Decrypt(cipherText, randomNonce, 0) + assert.NotNil(t, err) + }, + }) + + tests = append(tests, testCase{description: "test decrypt with DEK/nonce pair not used to encrypt should return error", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + + // encrypt data out of band. + dek := m.dataKeys[0] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + + nonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + cipherText := aead.Seal(nil, nonce, testData, nil) + + key, id, err := m.newDataEncryptionKey() + require.Nil(t, err) + m.dataKeys[id] = key + + plainText, err := m.Decrypt(cipherText, nonce, id) + assert.NotNil(t, err) + assert.Nil(t, plainText) + }, + }) + tests = append(tests, testCase{description: "test decrypt for non active key", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + + cipher, nonce, keyID, err := m.Encrypt(testData) + require.Nil(t, err) + + // force key rotation. + m.activeKeyCounter += maxWriteCount + _, _, newKeyID, err := m.Encrypt(nil) + require.Nil(t, err) + require.NotEqual(t, keyID, newKeyID) + + // use manager to decrypt the data. + decryptedData, err := m.Decrypt(cipher, nonce, keyID) + require.Nil(t, err) + + assert.Equal(t, testData, decryptedData) + }, + }) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +var buf = make([]byte, 8192) + +func BenchmarkEncryption(b *testing.B) { + benchEncrypt(b, 1024) + benchEncrypt(b, 4096) + benchEncrypt(b, 8192) +} + +func BenchmarkDecryption(b *testing.B) { + benchDecrypt(b, 1024) + benchDecrypt(b, 4096) + benchDecrypt(b, 8192) +} + +func benchEncrypt(b *testing.B, size int) { + m, err := NewManager() + if err != nil { + b.Fatal("failed to create manager", err) + } + // disable auto rotation to avoid skewing results. + maxWriteCount = math.MaxInt32 + + b.Run(fmt.Sprintf("encrypt-%d", size), func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(size)) + for i := 0; i < b.N; i++ { + _, _, _, err := m.Encrypt(buf[:size]) + if err != nil { + b.Fatal("error encrypting data", err) + } + } + }) +} + +func benchDecrypt(b *testing.B, size int) { + m, err := NewManager() + if err != nil { + b.Fatal("failed to create manager", err) + } + + edata, enonce, kid, err := m.Encrypt(buf[:size]) + if err != nil { + b.Fatal("failed to encrypt data", err) + } + + b.Run(fmt.Sprintf("decrypt-%d", size), func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(size)) + for i := 0; i < b.N; i++ { + _, err := m.Decrypt(edata, enonce, kid) + if err != nil { + b.Fatal("error encrypting data", err) + } + } + }) +} diff --git a/pkg/sqlcache/informer/db_mocks_test.go b/pkg/sqlcache/informer/db_mocks_test.go new file mode 100644 index 00000000..3731e3d0 --- /dev/null +++ b/pkg/sqlcache/informer/db_mocks_test.go @@ -0,0 +1,204 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/db (interfaces: TXClient,Rows) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db TXClient,Rows +// + +// Package informer is a generated GoMock package. +package informer + +import ( + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockTXClient is a mock of TXClient interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockTXClient) Cancel() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) +} + +// Commit mocks base method. +func (m *MockTXClient) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} + +// StmtExec mocks base method. +func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StmtExec indicates an expected call of StmtExec. +func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) +} + +// MockRows is a mock of Rows interface. +type MockRows struct { + ctrl *gomock.Controller + recorder *MockRowsMockRecorder +} + +// MockRowsMockRecorder is the mock recorder for MockRows. +type MockRowsMockRecorder struct { + mock *MockRows +} + +// NewMockRows creates a new mock instance. +func NewMockRows(ctrl *gomock.Controller) *MockRows { + mock := &MockRows{ctrl: ctrl} + mock.recorder = &MockRowsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRows) EXPECT() *MockRowsMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRows) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRowsMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close)) +} + +// Err mocks base method. +func (m *MockRows) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockRowsMockRecorder) Err() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err)) +} + +// Next mocks base method. +func (m *MockRows) Next() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockRowsMockRecorder) Next() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next)) +} + +// Scan mocks base method. +func (m *MockRows) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) +} diff --git a/pkg/sqlcache/informer/dynamic_mocks_test.go b/pkg/sqlcache/informer/dynamic_mocks_test.go new file mode 100644 index 00000000..07e169c1 --- /dev/null +++ b/pkg/sqlcache/informer/dynamic_mocks_test.go @@ -0,0 +1,237 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: k8s.io/client-go/dynamic (interfaces: ResourceInterface) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" +) + +// MockResourceInterface is a mock of ResourceInterface interface. +type MockResourceInterface struct { + ctrl *gomock.Controller + recorder *MockResourceInterfaceMockRecorder +} + +// MockResourceInterfaceMockRecorder is the mock recorder for MockResourceInterface. +type MockResourceInterfaceMockRecorder struct { + mock *MockResourceInterface +} + +// NewMockResourceInterface creates a new mock instance. +func NewMockResourceInterface(ctrl *gomock.Controller) *MockResourceInterface { + mock := &MockResourceInterface{ctrl: ctrl} + mock.recorder = &MockResourceInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResourceInterface) EXPECT() *MockResourceInterfaceMockRecorder { + return m.recorder +} + +// Apply mocks base method. +func (m *MockResourceInterface) Apply(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions, arg4 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3} + for _, a := range arg4 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Apply", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Apply indicates an expected call of Apply. +func (mr *MockResourceInterfaceMockRecorder) Apply(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockResourceInterface)(nil).Apply), varargs...) +} + +// ApplyStatus mocks base method. +func (m *MockResourceInterface) ApplyStatus(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplyStatus", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ApplyStatus indicates an expected call of ApplyStatus. +func (mr *MockResourceInterfaceMockRecorder) ApplyStatus(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyStatus", reflect.TypeOf((*MockResourceInterface)(nil).ApplyStatus), arg0, arg1, arg2, arg3) +} + +// Create mocks base method. +func (m *MockResourceInterface) Create(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.CreateOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Create", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create. +func (mr *MockResourceInterfaceMockRecorder) Create(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockResourceInterface)(nil).Create), varargs...) +} + +// Delete mocks base method. +func (m *MockResourceInterface) Delete(arg0 context.Context, arg1 string, arg2 v1.DeleteOptions, arg3 ...string) error { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Delete", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockResourceInterfaceMockRecorder) Delete(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockResourceInterface)(nil).Delete), varargs...) +} + +// DeleteCollection mocks base method. +func (m *MockResourceInterface) DeleteCollection(arg0 context.Context, arg1 v1.DeleteOptions, arg2 v1.ListOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteCollection", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteCollection indicates an expected call of DeleteCollection. +func (mr *MockResourceInterfaceMockRecorder) DeleteCollection(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCollection", reflect.TypeOf((*MockResourceInterface)(nil).DeleteCollection), arg0, arg1, arg2) +} + +// Get mocks base method. +func (m *MockResourceInterface) Get(arg0 context.Context, arg1 string, arg2 v1.GetOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Get", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockResourceInterfaceMockRecorder) Get(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockResourceInterface)(nil).Get), varargs...) +} + +// List mocks base method. +func (m *MockResourceInterface) List(arg0 context.Context, arg1 v1.ListOptions) (*unstructured.UnstructuredList, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", arg0, arg1) + ret0, _ := ret[0].(*unstructured.UnstructuredList) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockResourceInterfaceMockRecorder) List(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockResourceInterface)(nil).List), arg0, arg1) +} + +// Patch mocks base method. +func (m *MockResourceInterface) Patch(arg0 context.Context, arg1 string, arg2 types.PatchType, arg3 []byte, arg4 v1.PatchOptions, arg5 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3, arg4} + for _, a := range arg5 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Patch", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Patch indicates an expected call of Patch. +func (mr *MockResourceInterfaceMockRecorder) Patch(arg0, arg1, arg2, arg3, arg4 any, arg5 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3, arg4}, arg5...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Patch", reflect.TypeOf((*MockResourceInterface)(nil).Patch), varargs...) +} + +// Update mocks base method. +func (m *MockResourceInterface) Update(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Update", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update. +func (mr *MockResourceInterfaceMockRecorder) Update(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockResourceInterface)(nil).Update), varargs...) +} + +// UpdateStatus mocks base method. +func (m *MockResourceInterface) UpdateStatus(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateStatus", arg0, arg1, arg2) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateStatus indicates an expected call of UpdateStatus. +func (mr *MockResourceInterfaceMockRecorder) UpdateStatus(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStatus", reflect.TypeOf((*MockResourceInterface)(nil).UpdateStatus), arg0, arg1, arg2) +} + +// Watch mocks base method. +func (m *MockResourceInterface) Watch(arg0 context.Context, arg1 v1.ListOptions) (watch.Interface, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Watch", arg0, arg1) + ret0, _ := ret[0].(watch.Interface) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Watch indicates an expected call of Watch. +func (mr *MockResourceInterfaceMockRecorder) Watch(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockResourceInterface)(nil).Watch), arg0, arg1) +} diff --git a/pkg/sqlcache/informer/factory/db_mocks_test.go b/pkg/sqlcache/informer/factory/db_mocks_test.go new file mode 100644 index 00000000..fd5fa071 --- /dev/null +++ b/pkg/sqlcache/informer/factory/db_mocks_test.go @@ -0,0 +1,121 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/db (interfaces: TXClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db TXClient +// + +// Package factory is a generated GoMock package. +package factory + +import ( + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockTXClient is a mock of TXClient interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockTXClient) Cancel() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) +} + +// Commit mocks base method. +func (m *MockTXClient) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} + +// StmtExec mocks base method. +func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StmtExec indicates an expected call of StmtExec. +func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) +} diff --git a/pkg/sqlcache/informer/factory/dynamic_mocks_test.go b/pkg/sqlcache/informer/factory/dynamic_mocks_test.go new file mode 100644 index 00000000..29e2c0fd --- /dev/null +++ b/pkg/sqlcache/informer/factory/dynamic_mocks_test.go @@ -0,0 +1,237 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: k8s.io/client-go/dynamic (interfaces: ResourceInterface) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package factory -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface +// + +// Package factory is a generated GoMock package. +package factory + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" +) + +// MockResourceInterface is a mock of ResourceInterface interface. +type MockResourceInterface struct { + ctrl *gomock.Controller + recorder *MockResourceInterfaceMockRecorder +} + +// MockResourceInterfaceMockRecorder is the mock recorder for MockResourceInterface. +type MockResourceInterfaceMockRecorder struct { + mock *MockResourceInterface +} + +// NewMockResourceInterface creates a new mock instance. +func NewMockResourceInterface(ctrl *gomock.Controller) *MockResourceInterface { + mock := &MockResourceInterface{ctrl: ctrl} + mock.recorder = &MockResourceInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResourceInterface) EXPECT() *MockResourceInterfaceMockRecorder { + return m.recorder +} + +// Apply mocks base method. +func (m *MockResourceInterface) Apply(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions, arg4 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3} + for _, a := range arg4 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Apply", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Apply indicates an expected call of Apply. +func (mr *MockResourceInterfaceMockRecorder) Apply(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockResourceInterface)(nil).Apply), varargs...) +} + +// ApplyStatus mocks base method. +func (m *MockResourceInterface) ApplyStatus(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplyStatus", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ApplyStatus indicates an expected call of ApplyStatus. +func (mr *MockResourceInterfaceMockRecorder) ApplyStatus(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyStatus", reflect.TypeOf((*MockResourceInterface)(nil).ApplyStatus), arg0, arg1, arg2, arg3) +} + +// Create mocks base method. +func (m *MockResourceInterface) Create(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.CreateOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Create", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create. +func (mr *MockResourceInterfaceMockRecorder) Create(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockResourceInterface)(nil).Create), varargs...) +} + +// Delete mocks base method. +func (m *MockResourceInterface) Delete(arg0 context.Context, arg1 string, arg2 v1.DeleteOptions, arg3 ...string) error { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Delete", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockResourceInterfaceMockRecorder) Delete(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockResourceInterface)(nil).Delete), varargs...) +} + +// DeleteCollection mocks base method. +func (m *MockResourceInterface) DeleteCollection(arg0 context.Context, arg1 v1.DeleteOptions, arg2 v1.ListOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteCollection", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteCollection indicates an expected call of DeleteCollection. +func (mr *MockResourceInterfaceMockRecorder) DeleteCollection(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCollection", reflect.TypeOf((*MockResourceInterface)(nil).DeleteCollection), arg0, arg1, arg2) +} + +// Get mocks base method. +func (m *MockResourceInterface) Get(arg0 context.Context, arg1 string, arg2 v1.GetOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Get", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockResourceInterfaceMockRecorder) Get(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockResourceInterface)(nil).Get), varargs...) +} + +// List mocks base method. +func (m *MockResourceInterface) List(arg0 context.Context, arg1 v1.ListOptions) (*unstructured.UnstructuredList, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", arg0, arg1) + ret0, _ := ret[0].(*unstructured.UnstructuredList) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockResourceInterfaceMockRecorder) List(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockResourceInterface)(nil).List), arg0, arg1) +} + +// Patch mocks base method. +func (m *MockResourceInterface) Patch(arg0 context.Context, arg1 string, arg2 types.PatchType, arg3 []byte, arg4 v1.PatchOptions, arg5 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3, arg4} + for _, a := range arg5 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Patch", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Patch indicates an expected call of Patch. +func (mr *MockResourceInterfaceMockRecorder) Patch(arg0, arg1, arg2, arg3, arg4 any, arg5 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3, arg4}, arg5...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Patch", reflect.TypeOf((*MockResourceInterface)(nil).Patch), varargs...) +} + +// Update mocks base method. +func (m *MockResourceInterface) Update(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Update", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update. +func (mr *MockResourceInterfaceMockRecorder) Update(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockResourceInterface)(nil).Update), varargs...) +} + +// UpdateStatus mocks base method. +func (m *MockResourceInterface) UpdateStatus(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateStatus", arg0, arg1, arg2) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateStatus indicates an expected call of UpdateStatus. +func (mr *MockResourceInterfaceMockRecorder) UpdateStatus(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStatus", reflect.TypeOf((*MockResourceInterface)(nil).UpdateStatus), arg0, arg1, arg2) +} + +// Watch mocks base method. +func (m *MockResourceInterface) Watch(arg0 context.Context, arg1 v1.ListOptions) (watch.Interface, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Watch", arg0, arg1) + ret0, _ := ret[0].(watch.Interface) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Watch indicates an expected call of Watch. +func (mr *MockResourceInterfaceMockRecorder) Watch(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockResourceInterface)(nil).Watch), arg0, arg1) +} diff --git a/pkg/sqlcache/informer/factory/factory_mocks_test.go b/pkg/sqlcache/informer/factory/factory_mocks_test.go new file mode 100644 index 00000000..fa5d4739 --- /dev/null +++ b/pkg/sqlcache/informer/factory/factory_mocks_test.go @@ -0,0 +1,179 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/informer/factory (interfaces: DBClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer/factory DBClient +// + +// Package factory is a generated GoMock package. +package factory + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + db "github.com/rancher/lasso/pkg/cache/sql/db" + transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockDBClient is a mock of DBClient interface. +type MockDBClient struct { + ctrl *gomock.Controller + recorder *MockDBClientMockRecorder +} + +// MockDBClientMockRecorder is the mock recorder for MockDBClient. +type MockDBClientMockRecorder struct { + mock *MockDBClient +} + +// NewMockDBClient creates a new mock instance. +func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient { + mock := &MockDBClient{ctrl: ctrl} + mock.recorder = &MockDBClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(db.TXClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1) +} + +// CloseStmt mocks base method. +func (m *MockDBClient) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0) +} + +// NewConnection mocks base method. +func (m *MockDBClient) NewConnection() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewConnection") + ret0, _ := ret[0].(error) + return ret0 +} + +// NewConnection indicates an expected call of NewConnection. +func (mr *MockDBClientMockRecorder) NewConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockDBClient)(nil).NewConnection)) +} + +// Prepare mocks base method. +func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0) +} + +// Upsert mocks base method. +func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) +} diff --git a/pkg/sqlcache/informer/factory/informer_factory.go b/pkg/sqlcache/informer/factory/informer_factory.go new file mode 100644 index 00000000..5ce85347 --- /dev/null +++ b/pkg/sqlcache/informer/factory/informer_factory.go @@ -0,0 +1,187 @@ +/* +Package factory provides a cache factory for the sql-based cache. +*/ +package factory + +import ( + "fmt" + "os" + "sync" + "time" + + "github.com/rancher/lasso/pkg/cache/sql/db" + "github.com/rancher/lasso/pkg/cache/sql/encryption" + "github.com/rancher/lasso/pkg/cache/sql/informer" + sqlStore "github.com/rancher/lasso/pkg/cache/sql/store" + "github.com/rancher/lasso/pkg/log" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/tools/cache" +) + +// EncryptAllEnvVar is set to "true" if users want all types' data blobs to be encrypted in SQLite +// otherwise only variables in defaultEncryptedResourceTypes will have their blobs encrypted +const EncryptAllEnvVar = "CATTLE_ENCRYPT_CACHE_ALL" + +// CacheFactory builds Informer instances and keeps a cache of instances it created +type CacheFactory struct { + wg wait.Group + dbClient DBClient + stopCh chan struct{} + mutex sync.RWMutex + encryptAll bool + + newInformer newInformer + + informers map[schema.GroupVersionKind]*guardedInformer + informersMutex sync.Mutex +} + +type guardedInformer struct { + informer *informer.Informer + mutex *sync.Mutex +} + +type newInformer func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespace bool) (*informer.Informer, error) + +type DBClient interface { + informer.DBClient + sqlStore.DBClient + connector +} + +type Cache struct { + informer.ByOptionsLister +} + +type connector interface { + NewConnection() error +} + +var defaultEncryptedResourceTypes = map[schema.GroupVersionKind]struct{}{ + { + Version: "v1", + Kind: "Secret", + }: {}, +} + +// NewCacheFactory returns an informer factory instance +// This is currently called from steve via initial calls to `s.cacheFactory.CacheFor(...)` +func NewCacheFactory() (*CacheFactory, error) { + m, err := encryption.NewManager() + if err != nil { + return nil, err + } + dbClient, err := db.NewClient(nil, m, m) + if err != nil { + return nil, err + } + return &CacheFactory{ + wg: wait.Group{}, + stopCh: make(chan struct{}), + encryptAll: os.Getenv(EncryptAllEnvVar) == "true", + dbClient: dbClient, + newInformer: informer.NewInformer, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + }, nil +} + +// CacheFor returns an informer for given GVK, using sql store indexed with fields, using the specified client. For virtual fields, they must be added by the transform function +// and specified by fields to be used for later fields. +func (f *CacheFactory) CacheFor(fields [][]string, transform cache.TransformFunc, client dynamic.ResourceInterface, gvk schema.GroupVersionKind, namespaced bool, watchable bool) (Cache, error) { + // First of all block Reset() until we are done + f.mutex.RLock() + defer f.mutex.RUnlock() + + // Second, check if the informer and its accompanying informer-specific mutex exist already in the informers cache + // If not, start by creating such informer-specific mutex. That is used later to ensure no two goroutines create + // informers for the same GVK at the same type + f.informersMutex.Lock() + // Note: the informers cache is protected by informersMutex, which we don't want to hold for very long because + // that blocks CacheFor for other GVKs, hence not deferring unlock here + gi, ok := f.informers[gvk] + if !ok { + gi = &guardedInformer{ + informer: nil, + mutex: &sync.Mutex{}, + } + f.informers[gvk] = gi + } + f.informersMutex.Unlock() + + // At this point an informer-specific mutex (gi.mutex) is guaranteed to exist. Lock it + gi.mutex.Lock() + defer gi.mutex.Unlock() + + // Then: if the informer really was not created yet (first time here or previous times have errored out) + // actually create the informer + if gi.informer == nil { + start := time.Now() + log.Debugf("CacheFor STARTS creating informer for %v", gvk) + defer func() { + log.Debugf("CacheFor IS DONE creating informer for %v (took %v)", gvk, time.Now().Sub(start)) + }() + + _, encryptResourceAlways := defaultEncryptedResourceTypes[gvk] + shouldEncrypt := f.encryptAll || encryptResourceAlways + i, err := f.newInformer(client, fields, transform, gvk, f.dbClient, shouldEncrypt, namespaced) + if err != nil { + return Cache{}, err + } + + err = i.SetWatchErrorHandler(func(r *cache.Reflector, err error) { + if !watchable && errors.IsMethodNotSupported(err) { + // expected, continue without logging + return + } + cache.DefaultWatchErrorHandler(r, err) + }) + if err != nil { + return Cache{}, err + } + + f.wg.StartWithChannel(f.stopCh, i.Run) + + gi.informer = i + } + + if !cache.WaitForCacheSync(f.stopCh, gi.informer.HasSynced) { + return Cache{}, fmt.Errorf("failed to sync SQLite Informer cache for GVK %v", gvk) + } + + // At this point the informer is ready, return it + return Cache{ByOptionsLister: gi.informer}, nil +} + +// Reset closes the stopCh which stops any running informers, assigns a new stopCh, resets the GVK-informer cache, and resets +// the database connection which wipes any current sqlite database at the default location. +func (f *CacheFactory) Reset() error { + if f.dbClient == nil { + // nothing to reset + return nil + } + + // first of all wait until all CacheFor() calls that create new informers are finished. Also block any new ones + f.mutex.Lock() + defer f.mutex.Unlock() + + // now that we are alone, stop all informers created until this point + close(f.stopCh) + f.stopCh = make(chan struct{}) + f.wg.Wait() + + // and get rid of all references to those informers and their mutexes + f.informersMutex.Lock() + defer f.informersMutex.Unlock() + f.informers = make(map[schema.GroupVersionKind]*guardedInformer) + + // finally, reset the DB connection + err := f.dbClient.NewConnection() + if err != nil { + return err + } + + return nil +} diff --git a/pkg/sqlcache/informer/factory/informer_factory_test.go b/pkg/sqlcache/informer/factory/informer_factory_test.go new file mode 100644 index 00000000..0f77435b --- /dev/null +++ b/pkg/sqlcache/informer/factory/informer_factory_test.go @@ -0,0 +1,287 @@ +package factory + +import ( + "os" + "testing" + "time" + + "github.com/rancher/lasso/pkg/cache/sql/informer" + + sqlStore "github.com/rancher/lasso/pkg/cache/sql/store" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/tools/cache" +) + +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer/factory DBClient +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db TXClient +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./k8s_cache_mocks_test.go k8s.io/client-go/tools/cache SharedIndexInformer + +func TestNewCacheFactory(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "NewCacheFactory() with no errors returned, should return no errors", test: func(t *testing.T) { + f, err := NewCacheFactory() + assert.Nil(t, err) + assert.NotNil(t, f.dbClient) + assert.False(t, f.encryptAll) + }}) + tests = append(tests, testCase{description: "NewCacheFactory() with no errors returned and EncryptAllEnvVar set to true, should return no errors and have encryptAll set to true", test: func(t *testing.T) { + err := os.Setenv(EncryptAllEnvVar, "true") + assert.Nil(t, err) + f, err := NewCacheFactory() + assert.Nil(t, err) + assert.Nil(t, err) + assert.NotNil(t, f.dbClient) + assert.True(t, f.encryptAll) + }}) + // cannot run as parallel because tests involve changing env var + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestCacheFor(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, and stopCh not closed, should return no error and should call Informer.Run(). A subsequent call to CacheFor() should return same informer", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + fields := [][]string{{"something"}} + expectedGVK := schema.GroupVersionKind{} + sii := NewMockSharedIndexInformer(gomock.NewController(t)) + sii.EXPECT().HasSynced().Return(true).AnyTimes() + sii.EXPECT().Run(gomock.Any()).MinTimes(1) + sii.EXPECT().SetWatchErrorHandler(gomock.Any()) + i := &informer.Informer{ + // need to set this so Run function is not nil + SharedIndexInformer: sii, + } + expectedC := Cache{ + ByOptionsLister: i, + } + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*informer.Informer, error) { + assert.Equal(t, client, dynamicClient) + assert.Equal(t, fields, fields) + assert.Equal(t, expectedGVK, gvk) + assert.Equal(t, db, dbClient) + assert.Equal(t, false, shouldEncrypt) + return i, nil + } + f := &CacheFactory{ + dbClient: dbClient, + stopCh: make(chan struct{}), + newInformer: testNewInformer, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + } + + go func() { + // this function ensures that stopCh is open for the duration of this test but if part of a longer process it will be closed eventually + time.Sleep(5 * time.Second) + close(f.stopCh) + }() + var c Cache + var err error + c, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true) + assert.Nil(t, err) + assert.Equal(t, expectedC, c) + // this sleep is critical to the test. It ensure there has been enough time for expected function like Run to be invoked in their go routines. + time.Sleep(1 * time.Second) + c2, err := f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true) + assert.Nil(t, err) + assert.Equal(t, c, c2) + }}) + tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning false, and stopCh not closed, should call Run() and return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + fields := [][]string{{"something"}} + expectedGVK := schema.GroupVersionKind{} + + sii := NewMockSharedIndexInformer(gomock.NewController(t)) + sii.EXPECT().HasSynced().Return(false).AnyTimes() + sii.EXPECT().Run(gomock.Any()) + sii.EXPECT().SetWatchErrorHandler(gomock.Any()) + expectedI := &informer.Informer{ + // need to set this so Run function is not nil + SharedIndexInformer: sii, + } + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + assert.Equal(t, client, dynamicClient) + assert.Equal(t, fields, fields) + assert.Equal(t, expectedGVK, gvk) + assert.Equal(t, db, dbClient) + assert.Equal(t, false, shouldEncrypt) + return expectedI, nil + } + f := &CacheFactory{ + dbClient: dbClient, + stopCh: make(chan struct{}), + newInformer: testNewInformer, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + } + + go func() { + time.Sleep(1 * time.Second) + close(f.stopCh) + }() + var err error + _, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true) + assert.NotNil(t, err) + time.Sleep(2 * time.Second) + }}) + tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, and stopCh closed, should not call Run() more than once and not return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + fields := [][]string{{"something"}} + expectedGVK := schema.GroupVersionKind{} + + sii := NewMockSharedIndexInformer(gomock.NewController(t)) + sii.EXPECT().HasSynced().Return(true).AnyTimes() + // may or may not call run initially + sii.EXPECT().Run(gomock.Any()).MaxTimes(1) + sii.EXPECT().SetWatchErrorHandler(gomock.Any()) + i := &informer.Informer{ + // need to set this so Run function is not nil + SharedIndexInformer: sii, + } + expectedC := Cache{ + ByOptionsLister: i, + } + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + assert.Equal(t, client, dynamicClient) + assert.Equal(t, fields, fields) + assert.Equal(t, expectedGVK, gvk) + assert.Equal(t, db, dbClient) + assert.Equal(t, false, shouldEncrypt) + return i, nil + } + f := &CacheFactory{ + dbClient: dbClient, + stopCh: make(chan struct{}), + newInformer: testNewInformer, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + } + + close(f.stopCh) + var c Cache + var err error + c, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true) + assert.Nil(t, err) + assert.Equal(t, expectedC, c) + time.Sleep(1 * time.Second) + }}) + tests = append(tests, testCase{description: "CacheFor() with no errors returned and encryptAll set to true, should return no error and pass shouldEncrypt as true to newInformer func", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + fields := [][]string{{"something"}} + expectedGVK := schema.GroupVersionKind{} + sii := NewMockSharedIndexInformer(gomock.NewController(t)) + sii.EXPECT().HasSynced().Return(true) + sii.EXPECT().Run(gomock.Any()).MinTimes(1).AnyTimes() + sii.EXPECT().SetWatchErrorHandler(gomock.Any()) + i := &informer.Informer{ + // need to set this so Run function is not nil + SharedIndexInformer: sii, + } + expectedC := Cache{ + ByOptionsLister: i, + } + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + assert.Equal(t, client, dynamicClient) + assert.Equal(t, fields, fields) + assert.Equal(t, expectedGVK, gvk) + assert.Equal(t, db, dbClient) + assert.Equal(t, true, shouldEncrypt) + return i, nil + } + f := &CacheFactory{ + dbClient: dbClient, + stopCh: make(chan struct{}), + newInformer: testNewInformer, + encryptAll: true, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + } + + go func() { + time.Sleep(10 * time.Second) + close(f.stopCh) + }() + var c Cache + var err error + c, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true) + assert.Nil(t, err) + assert.Equal(t, expectedC, c) + time.Sleep(1 * time.Second) + }}) + tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, stopCh not closed, and transform func should return no error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + fields := [][]string{{"something"}} + expectedGVK := schema.GroupVersionKind{} + sii := NewMockSharedIndexInformer(gomock.NewController(t)) + sii.EXPECT().HasSynced().Return(true) + sii.EXPECT().Run(gomock.Any()).MinTimes(1) + sii.EXPECT().SetWatchErrorHandler(gomock.Any()) + transformFunc := func(input interface{}) (interface{}, error) { + return "someoutput", nil + } + i := &informer.Informer{ + // need to set this so Run function is not nil + SharedIndexInformer: sii, + } + expectedC := Cache{ + ByOptionsLister: i, + } + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*informer.Informer, error) { + // we can't test func == func, so instead we check if the output was as expected + input := "someinput" + ouput, err := transform(input) + assert.Nil(t, err) + outputStr, ok := ouput.(string) + assert.True(t, ok, "ouput from transform was expected to be a string") + assert.Equal(t, "someoutput", outputStr) + + assert.Equal(t, client, dynamicClient) + assert.Equal(t, fields, fields) + assert.Equal(t, expectedGVK, gvk) + assert.Equal(t, db, dbClient) + assert.Equal(t, false, shouldEncrypt) + return i, nil + } + f := &CacheFactory{ + dbClient: dbClient, + stopCh: make(chan struct{}), + newInformer: testNewInformer, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + } + + go func() { + // this function ensures that stopCh is open for the duration of this test but if part of a longer process it will be closed eventually + time.Sleep(5 * time.Second) + close(f.stopCh) + }() + var c Cache + var err error + c, err = f.CacheFor(fields, transformFunc, dynamicClient, expectedGVK, false, true) + assert.Nil(t, err) + assert.Equal(t, expectedC, c) + time.Sleep(1 * time.Second) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} diff --git a/pkg/sqlcache/informer/factory/k8s_cache_mocks_test.go b/pkg/sqlcache/informer/factory/k8s_cache_mocks_test.go new file mode 100644 index 00000000..b9c4dc35 --- /dev/null +++ b/pkg/sqlcache/informer/factory/k8s_cache_mocks_test.go @@ -0,0 +1,223 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: k8s.io/client-go/tools/cache (interfaces: SharedIndexInformer) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package factory -destination ./k8s_cache_mocks_test.go k8s.io/client-go/tools/cache SharedIndexInformer +// + +// Package factory is a generated GoMock package. +package factory + +import ( + reflect "reflect" + time "time" + + gomock "go.uber.org/mock/gomock" + cache "k8s.io/client-go/tools/cache" +) + +// MockSharedIndexInformer is a mock of SharedIndexInformer interface. +type MockSharedIndexInformer struct { + ctrl *gomock.Controller + recorder *MockSharedIndexInformerMockRecorder +} + +// MockSharedIndexInformerMockRecorder is the mock recorder for MockSharedIndexInformer. +type MockSharedIndexInformerMockRecorder struct { + mock *MockSharedIndexInformer +} + +// NewMockSharedIndexInformer creates a new mock instance. +func NewMockSharedIndexInformer(ctrl *gomock.Controller) *MockSharedIndexInformer { + mock := &MockSharedIndexInformer{ctrl: ctrl} + mock.recorder = &MockSharedIndexInformerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSharedIndexInformer) EXPECT() *MockSharedIndexInformerMockRecorder { + return m.recorder +} + +// AddEventHandler mocks base method. +func (m *MockSharedIndexInformer) AddEventHandler(arg0 cache.ResourceEventHandler) (cache.ResourceEventHandlerRegistration, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddEventHandler", arg0) + ret0, _ := ret[0].(cache.ResourceEventHandlerRegistration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddEventHandler indicates an expected call of AddEventHandler. +func (mr *MockSharedIndexInformerMockRecorder) AddEventHandler(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEventHandler", reflect.TypeOf((*MockSharedIndexInformer)(nil).AddEventHandler), arg0) +} + +// AddEventHandlerWithResyncPeriod mocks base method. +func (m *MockSharedIndexInformer) AddEventHandlerWithResyncPeriod(arg0 cache.ResourceEventHandler, arg1 time.Duration) (cache.ResourceEventHandlerRegistration, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddEventHandlerWithResyncPeriod", arg0, arg1) + ret0, _ := ret[0].(cache.ResourceEventHandlerRegistration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddEventHandlerWithResyncPeriod indicates an expected call of AddEventHandlerWithResyncPeriod. +func (mr *MockSharedIndexInformerMockRecorder) AddEventHandlerWithResyncPeriod(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEventHandlerWithResyncPeriod", reflect.TypeOf((*MockSharedIndexInformer)(nil).AddEventHandlerWithResyncPeriod), arg0, arg1) +} + +// AddIndexers mocks base method. +func (m *MockSharedIndexInformer) AddIndexers(arg0 cache.Indexers) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddIndexers", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddIndexers indicates an expected call of AddIndexers. +func (mr *MockSharedIndexInformerMockRecorder) AddIndexers(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddIndexers", reflect.TypeOf((*MockSharedIndexInformer)(nil).AddIndexers), arg0) +} + +// GetController mocks base method. +func (m *MockSharedIndexInformer) GetController() cache.Controller { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetController") + ret0, _ := ret[0].(cache.Controller) + return ret0 +} + +// GetController indicates an expected call of GetController. +func (mr *MockSharedIndexInformerMockRecorder) GetController() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetController", reflect.TypeOf((*MockSharedIndexInformer)(nil).GetController)) +} + +// GetIndexer mocks base method. +func (m *MockSharedIndexInformer) GetIndexer() cache.Indexer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetIndexer") + ret0, _ := ret[0].(cache.Indexer) + return ret0 +} + +// GetIndexer indicates an expected call of GetIndexer. +func (mr *MockSharedIndexInformerMockRecorder) GetIndexer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIndexer", reflect.TypeOf((*MockSharedIndexInformer)(nil).GetIndexer)) +} + +// GetStore mocks base method. +func (m *MockSharedIndexInformer) GetStore() cache.Store { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStore") + ret0, _ := ret[0].(cache.Store) + return ret0 +} + +// GetStore indicates an expected call of GetStore. +func (mr *MockSharedIndexInformerMockRecorder) GetStore() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStore", reflect.TypeOf((*MockSharedIndexInformer)(nil).GetStore)) +} + +// HasSynced mocks base method. +func (m *MockSharedIndexInformer) HasSynced() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasSynced") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasSynced indicates an expected call of HasSynced. +func (mr *MockSharedIndexInformerMockRecorder) HasSynced() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasSynced", reflect.TypeOf((*MockSharedIndexInformer)(nil).HasSynced)) +} + +// IsStopped mocks base method. +func (m *MockSharedIndexInformer) IsStopped() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsStopped") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsStopped indicates an expected call of IsStopped. +func (mr *MockSharedIndexInformerMockRecorder) IsStopped() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsStopped", reflect.TypeOf((*MockSharedIndexInformer)(nil).IsStopped)) +} + +// LastSyncResourceVersion mocks base method. +func (m *MockSharedIndexInformer) LastSyncResourceVersion() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LastSyncResourceVersion") + ret0, _ := ret[0].(string) + return ret0 +} + +// LastSyncResourceVersion indicates an expected call of LastSyncResourceVersion. +func (mr *MockSharedIndexInformerMockRecorder) LastSyncResourceVersion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastSyncResourceVersion", reflect.TypeOf((*MockSharedIndexInformer)(nil).LastSyncResourceVersion)) +} + +// RemoveEventHandler mocks base method. +func (m *MockSharedIndexInformer) RemoveEventHandler(arg0 cache.ResourceEventHandlerRegistration) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveEventHandler", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveEventHandler indicates an expected call of RemoveEventHandler. +func (mr *MockSharedIndexInformerMockRecorder) RemoveEventHandler(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveEventHandler", reflect.TypeOf((*MockSharedIndexInformer)(nil).RemoveEventHandler), arg0) +} + +// Run mocks base method. +func (m *MockSharedIndexInformer) Run(arg0 <-chan struct{}) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Run", arg0) +} + +// Run indicates an expected call of Run. +func (mr *MockSharedIndexInformerMockRecorder) Run(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSharedIndexInformer)(nil).Run), arg0) +} + +// SetTransform mocks base method. +func (m *MockSharedIndexInformer) SetTransform(arg0 cache.TransformFunc) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetTransform", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetTransform indicates an expected call of SetTransform. +func (mr *MockSharedIndexInformerMockRecorder) SetTransform(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTransform", reflect.TypeOf((*MockSharedIndexInformer)(nil).SetTransform), arg0) +} + +// SetWatchErrorHandler mocks base method. +func (m *MockSharedIndexInformer) SetWatchErrorHandler(arg0 cache.WatchErrorHandler) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWatchErrorHandler", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWatchErrorHandler indicates an expected call of SetWatchErrorHandler. +func (mr *MockSharedIndexInformerMockRecorder) SetWatchErrorHandler(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWatchErrorHandler", reflect.TypeOf((*MockSharedIndexInformer)(nil).SetWatchErrorHandler), arg0) +} diff --git a/pkg/sqlcache/informer/indexer.go b/pkg/sqlcache/informer/indexer.go new file mode 100644 index 00000000..14305339 --- /dev/null +++ b/pkg/sqlcache/informer/indexer.go @@ -0,0 +1,264 @@ +package informer + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" + "sync" + + "github.com/rancher/lasso/pkg/cache/sql/db" + "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + "k8s.io/client-go/tools/cache" +) + +const ( + selectQueryFmt = ` + SELECT object, objectnonce, dekid FROM "%[1]s" + WHERE key IN ( + SELECT key FROM "%[1]s_indices" + WHERE name = ? AND value IN (?%s) + ) + ` + createTableFmt = `CREATE TABLE IF NOT EXISTS "%[1]s_indices" ( + name TEXT NOT NULL, + value TEXT NOT NULL, + key TEXT NOT NULL REFERENCES "%[1]s"(key) ON DELETE CASCADE, + PRIMARY KEY (name, value, key) + )` + createIndexFmt = `CREATE INDEX IF NOT EXISTS "%[1]s_indices_index" ON "%[1]s_indices"(name, value)` + + deleteIndicesFmt = `DELETE FROM "%s_indices" WHERE key = ?` + addIndexFmt = `INSERT INTO "%s_indices" (name, value, key) VALUES (?, ?, ?) ON CONFLICT DO NOTHING` + listByIndexFmt = `SELECT object, objectnonce, dekid FROM "%[1]s" + WHERE key IN ( + SELECT key FROM "%[1]s_indices" + WHERE name = ? AND value = ? + )` + listKeyByIndexFmt = `SELECT DISTINCT key FROM "%s_indices" WHERE name = ? AND value = ?` + listIndexValuesFmt = `SELECT DISTINCT value FROM "%s_indices" WHERE name = ?` +) + +// Indexer is a SQLite-backed cache.Indexer which builds upon Store adding an index table +type Indexer struct { + Store + indexers cache.Indexers + indexersLock sync.RWMutex + + deleteIndicesQuery string + addIndexQuery string + listByIndexQuery string + listKeysByIndexQuery string + listIndexValuesQuery string + + deleteIndicesStmt *sql.Stmt + addIndexStmt *sql.Stmt + listByIndexStmt *sql.Stmt + listKeysByIndexStmt *sql.Stmt + listIndexValuesStmt *sql.Stmt +} + +var _ cache.Indexer = (*Indexer)(nil) + +type Store interface { + DBClient + cache.Store + + GetByKey(key string) (item any, exists bool, err error) + GetName() string + RegisterAfterUpsert(f func(key string, obj any, tx db.TXClient) error) + RegisterAfterDelete(f func(key string, tx db.TXClient) error) + GetShouldEncrypt() bool + GetType() reflect.Type +} + +type DBClient interface { + BeginTx(ctx context.Context, forWriting bool) (db.TXClient, error) + QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) + ReadObjects(rows db.Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) + ReadStrings(rows db.Rows) ([]string, error) + ReadInt(rows db.Rows) (int, error) + Prepare(stmt string) *sql.Stmt + CloseStmt(stmt db.Closable) error +} + +// NewIndexer returns a cache.Indexer backed by SQLite for objects of the given example type +func NewIndexer(indexers cache.Indexers, s Store) (*Indexer, error) { + tx, err := s.BeginTx(context.Background(), true) + if err != nil { + return nil, err + } + dbName := db.Sanitize(s.GetName()) + createTableQuery := fmt.Sprintf(createTableFmt, dbName) + err = tx.Exec(createTableQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createTableQuery, Err: err} + } + createIndexQuery := fmt.Sprintf(createIndexFmt, dbName) + err = tx.Exec(createIndexQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createIndexQuery, Err: err} + } + err = tx.Commit() + if err != nil { + return nil, err + } + + i := &Indexer{ + Store: s, + indexers: indexers, + } + i.RegisterAfterUpsert(i.AfterUpsert) + + i.deleteIndicesQuery = fmt.Sprintf(deleteIndicesFmt, db.Sanitize(s.GetName())) + i.addIndexQuery = fmt.Sprintf(addIndexFmt, db.Sanitize(s.GetName())) + i.listByIndexQuery = fmt.Sprintf(listByIndexFmt, db.Sanitize(s.GetName())) + i.listKeysByIndexQuery = fmt.Sprintf(listKeyByIndexFmt, db.Sanitize(s.GetName())) + i.listIndexValuesQuery = fmt.Sprintf(listIndexValuesFmt, db.Sanitize(s.GetName())) + + i.deleteIndicesStmt = s.Prepare(i.deleteIndicesQuery) + i.addIndexStmt = s.Prepare(i.addIndexQuery) + i.listByIndexStmt = s.Prepare(i.listByIndexQuery) + i.listKeysByIndexStmt = s.Prepare(i.listKeysByIndexQuery) + i.listIndexValuesStmt = s.Prepare(i.listIndexValuesQuery) + + return i, nil +} + +/* Core methods */ + +// AfterUpsert updates indices of an object +func (i *Indexer) AfterUpsert(key string, obj any, tx db.TXClient) error { + // delete all + err := tx.StmtExec(tx.Stmt(i.deleteIndicesStmt), key) + if err != nil { + return &db.QueryError{QueryString: i.deleteIndicesQuery, Err: err} + } + + // re-insert all values + i.indexersLock.RLock() + defer i.indexersLock.RUnlock() + for indexName, indexFunc := range i.indexers { + values, err := indexFunc(obj) + if err != nil { + return err + } + + for _, value := range values { + err = tx.StmtExec(tx.Stmt(i.addIndexStmt), indexName, value, key) + if err != nil { + return &db.QueryError{QueryString: i.addIndexQuery, Err: err} + } + } + } + return nil +} + +/* Satisfy cache.Indexer */ + +// Index returns a list of items that match the given object on the index function +func (i *Indexer) Index(indexName string, obj any) ([]any, error) { + i.indexersLock.RLock() + defer i.indexersLock.RUnlock() + indexFunc := i.indexers[indexName] + if indexFunc == nil { + return nil, fmt.Errorf("index with name %s does not exist", indexName) + } + + values, err := indexFunc(obj) + if err != nil { + return nil, err + } + + if len(values) == 0 { + return nil, nil + } + + // typical case + if len(values) == 1 { + return i.ByIndex(indexName, values[0]) + } + + // atypical case - more than one value to lookup + // HACK: sql.Statement.Query does not allow to pass slices in as of go 1.19 - create an ad-hoc statement + query := fmt.Sprintf(selectQueryFmt, db.Sanitize(i.GetName()), strings.Repeat(", ?", len(values)-1)) + stmt := i.Prepare(query) + defer i.CloseStmt(stmt) + // HACK: Query will accept []any but not []string + params := []any{indexName} + for _, value := range values { + params = append(params, value) + } + + rows, err := i.QueryForRows(context.TODO(), stmt, params...) + if err != nil { + return nil, &db.QueryError{QueryString: query, Err: err} + } + return i.ReadObjects(rows, i.GetType(), i.GetShouldEncrypt()) +} + +// ByIndex returns the stored objects whose set of indexed values +// for the named index includes the given indexed value +func (i *Indexer) ByIndex(indexName, indexedValue string) ([]any, error) { + rows, err := i.QueryForRows(context.TODO(), i.listByIndexStmt, indexName, indexedValue) + if err != nil { + return nil, &db.QueryError{QueryString: i.listByIndexQuery, Err: err} + } + return i.ReadObjects(rows, i.GetType(), i.GetShouldEncrypt()) +} + +// IndexKeys returns a list of the Store keys of the objects whose indexed values in the given index include the given indexed value +func (i *Indexer) IndexKeys(indexName, indexedValue string) ([]string, error) { + i.indexersLock.RLock() + defer i.indexersLock.RUnlock() + indexFunc := i.indexers[indexName] + if indexFunc == nil { + return nil, fmt.Errorf("Index with name %s does not exist", indexName) + } + + rows, err := i.QueryForRows(context.TODO(), i.listKeysByIndexStmt, indexName, indexedValue) + if err != nil { + return nil, &db.QueryError{QueryString: i.listKeysByIndexQuery, Err: err} + } + return i.ReadStrings(rows) +} + +// ListIndexFuncValues wraps safeListIndexFuncValues and panics in case of I/O errors +func (i *Indexer) ListIndexFuncValues(name string) []string { + result, err := i.safeListIndexFuncValues(name) + if err != nil { + panic(fmt.Errorf("unexpected error in safeListIndexFuncValues: %w", err)) + } + return result +} + +// safeListIndexFuncValues returns all the indexed values of the given index +func (i *Indexer) safeListIndexFuncValues(indexName string) ([]string, error) { + rows, err := i.QueryForRows(context.TODO(), i.listIndexValuesStmt, indexName) + if err != nil { + return nil, &db.QueryError{QueryString: i.listIndexValuesQuery, Err: err} + } + return i.ReadStrings(rows) +} + +// GetIndexers returns the indexers +func (i *Indexer) GetIndexers() cache.Indexers { + i.indexersLock.RLock() + defer i.indexersLock.RUnlock() + return i.indexers +} + +// AddIndexers adds more indexers to this Store. If you call this after you already have data +// in the Store, the results are undefined. +func (i *Indexer) AddIndexers(newIndexers cache.Indexers) error { + i.indexersLock.Lock() + defer i.indexersLock.Unlock() + if i.indexers == nil { + i.indexers = make(map[string]cache.IndexFunc) + } + for k, v := range newIndexers { + i.indexers[k] = v + } + return nil +} diff --git a/pkg/sqlcache/informer/indexer_test.go b/pkg/sqlcache/informer/indexer_test.go new file mode 100644 index 00000000..a861efb6 --- /dev/null +++ b/pkg/sqlcache/informer/indexer_test.go @@ -0,0 +1,614 @@ +/* +Copyright 2023 SUSE LLC + +Adapted from client-go, Copyright 2014 The Kubernetes Authors. +*/ + +package informer + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "k8s.io/client-go/tools/cache" +) + +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer Store +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db TXClient,Rows + +type testStoreObject struct { + Id string + Val string +} + +func TestNewIndexer(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "NewIndexer() with no errors returned from Store or TXClient, should return no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + storeName := "someStoreName" + store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil) + client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil) + client.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(fmt.Sprintf(deleteIndicesFmt, storeName)) + store.EXPECT().Prepare(fmt.Sprintf(addIndexFmt, storeName)) + store.EXPECT().Prepare(fmt.Sprintf(listByIndexFmt, storeName, storeName)) + store.EXPECT().Prepare(fmt.Sprintf(listKeyByIndexFmt, storeName)) + store.EXPECT().Prepare(fmt.Sprintf(listIndexValuesFmt, storeName)) + indexer, err := NewIndexer(indexers, store) + assert.Nil(t, err) + assert.Equal(t, cache.Indexers(indexers), indexer.indexers) + }}) + tests = append(tests, testCase{description: "NewIndexer() with Store Begin() error, should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + store.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error")) + _, err := NewIndexer(indexers, store) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewIndexer() with TXClient Exec() error on first call to Exec(), should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + storeName := "someStoreName" + store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(fmt.Errorf("error")) + _, err := NewIndexer(indexers, store) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewIndexer() with TXClient Exec() error on second call to Exec(), should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + storeName := "someStoreName" + store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil) + client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(fmt.Errorf("error")) + _, err := NewIndexer(indexers, store) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewIndexer() with TXClient Commit() error, should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + storeName := "someStoreName" + store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil) + client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil) + client.EXPECT().Commit().Return(fmt.Errorf("error")) + _, err := NewIndexer(indexers, store) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestAfterUpsert(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "AfterUpsert() with no errors returned from TXClient should return no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + deleteStmt := &sql.Stmt{} + addStmt := &sql.Stmt{} + objKey := "key" + indexer := &Indexer{ + Store: store, + deleteIndicesStmt: deleteStmt, + addIndexStmt: addStmt, + indexers: map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + key := "somekey" + client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt) + client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(nil) + client.EXPECT().Stmt(indexer.addIndexStmt).Return(indexer.addIndexStmt) + client.EXPECT().StmtExec(indexer.addIndexStmt, "a", objKey, key).Return(nil) + testObject := testStoreObject{Id: "something", Val: "a"} + err := indexer.AfterUpsert(key, testObject, client) + assert.Nil(t, err) + }}) + tests = append(tests, testCase{description: "AfterUpsert() with error returned from TXClient StmtExec() should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + deleteStmt := &sql.Stmt{} + addStmt := &sql.Stmt{} + objKey := "key" + indexer := &Indexer{ + Store: store, + deleteIndicesStmt: deleteStmt, + addIndexStmt: addStmt, + indexers: map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + key := "somekey" + client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt) + client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(fmt.Errorf("error")) + testObject := testStoreObject{Id: "something", Val: "a"} + err := indexer.AfterUpsert(key, testObject, client) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "AfterUpsert() with error returned from TXClient second StmtExec() call should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + deleteStmt := &sql.Stmt{} + addStmt := &sql.Stmt{} + objKey := "key" + indexer := &Indexer{ + Store: store, + deleteIndicesStmt: deleteStmt, + addIndexStmt: addStmt, + indexers: map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + key := "somekey" + client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt) + client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(nil) + client.EXPECT().Stmt(indexer.addIndexStmt).Return(indexer.addIndexStmt) + client.EXPECT().StmtExec(indexer.addIndexStmt, "a", objKey, key).Return(fmt.Errorf("error")) + testObject := testStoreObject{Id: "something", Val: "a"} + err := indexer.AfterUpsert(key, testObject, client) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestIndex(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "Index() with no errors returned from store and 1 object returned by ReadObjects(), should return one obj and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil) + objs, err := indexer.Index(indexName, testObject) + assert.Nil(t, err) + assert.Equal(t, []any{testObject}, objs) + }}) + tests = append(tests, testCase{description: "Index() with no errors returned from store and multiple objects returned by ReadObjects(), should return multiple objects and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject, testObject}, nil) + objs, err := indexer.Index(indexName, testObject) + assert.Nil(t, err) + assert.Equal(t, []any{testObject, testObject}, objs) + }}) + tests = append(tests, testCase{description: "Index() with no errors returned from store and no objects returned by ReadObjects(), should return no objects and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{}, nil) + objs, err := indexer.Index(indexName, testObject) + assert.Nil(t, err) + assert.Equal(t, []any{}, objs) + }}) + tests = append(tests, testCase{description: "Index() where index name is not in indexers, should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + _, err := indexer.Index("someotherindexname", testObject) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "Index() with an error returned from store QueryForRows, should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(nil, fmt.Errorf("error")) + _, err := indexer.Index(indexName, testObject) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "Index() with an errors returned from store ReadObjects(), should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, fmt.Errorf("error")) + _, err := indexer.Index(indexName, testObject) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "Index() with no errors returned from store and multiple keys returned from index func, should return one obj and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey, objKey + "2"}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().GetName().Return("name") + stmt := &sql.Stmt{} + store.EXPECT().Prepare(fmt.Sprintf(selectQueryFmt, "name", ", ?")).Return(stmt) + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey, objKey+"2").Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil) + store.EXPECT().CloseStmt(stmt).Return(nil) + objs, err := indexer.Index(indexName, testObject) + assert.Nil(t, err) + assert.Equal(t, []any{testObject}, objs) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestByIndex(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "IndexBy() with no errors returned from store and 1 object returned by ReadObjects(), should return one obj and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil) + objs, err := indexer.ByIndex(indexName, objKey) + assert.Nil(t, err) + assert.Equal(t, []any{testObject}, objs) + }}) + tests = append(tests, testCase{description: "IndexBy() with no errors returned from store and multiple objects returned by ReadObjects(), should return multiple objects and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject, testObject}, nil) + objs, err := indexer.ByIndex(indexName, objKey) + assert.Nil(t, err) + assert.Equal(t, []any{testObject, testObject}, objs) + }}) + tests = append(tests, testCase{description: "IndexBy() with no errors returned from store and no objects returned by ReadObjects(), should return no objects and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{}, nil) + objs, err := indexer.ByIndex(indexName, objKey) + assert.Nil(t, err) + assert.Equal(t, []any{}, objs) + }}) + tests = append(tests, testCase{description: "IndexBy() with an error returned from store QueryForRows, should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(nil, fmt.Errorf("error")) + _, err := indexer.ByIndex(indexName, objKey) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "IndexBy() with an errors returned from store ReadObjects(), should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, fmt.Errorf("error")) + _, err := indexer.ByIndex(indexName, objKey) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestListIndexFuncValues(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "ListIndexFuncvalues() with no errors returned from store and 1 object returned by ReadObjects(), should return one obj and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(rows, nil) + store.EXPECT().ReadStrings(rows).Return([]string{"somestrings"}, nil) + vals := indexer.ListIndexFuncValues(indexName) + assert.Equal(t, []string{"somestrings"}, vals) + }}) + tests = append(tests, testCase{description: "ListIndexFuncvalues() with QueryForRows() error returned from store, should panic", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + listStmt := &sql.Stmt{} + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(nil, fmt.Errorf("error")) + assert.Panics(t, func() { indexer.ListIndexFuncValues(indexName) }) + }}) + tests = append(tests, testCase{description: "ListIndexFuncvalues() with ReadStrings() error returned from store, should panic", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(rows, nil) + store.EXPECT().ReadStrings(rows).Return([]string{"somestrings"}, fmt.Errorf("error")) + assert.Panics(t, func() { indexer.ListIndexFuncValues(indexName) }) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestGetIndexers(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "GetIndexers() should return indexers fron indexers field", test: func(t *testing.T) { + objKey := "key" + expectedIndexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + indexer := &Indexer{ + indexers: expectedIndexers, + } + indexers := indexer.GetIndexers() + assert.Equal(t, cache.Indexers(expectedIndexers), indexers) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestAddIndexers(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "GetIndexers() should return indexers fron indexers field", test: func(t *testing.T) { + objKey := "key" + expectedIndexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + indexer := &Indexer{} + err := indexer.AddIndexers(expectedIndexers) + assert.Nil(t, err) + assert.ObjectsAreEqual(cache.Indexers(expectedIndexers), indexer.indexers) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} diff --git a/pkg/sqlcache/informer/informer.go b/pkg/sqlcache/informer/informer.go new file mode 100644 index 00000000..b893713a --- /dev/null +++ b/pkg/sqlcache/informer/informer.go @@ -0,0 +1,94 @@ +/* +package sql provides an Informer and Indexer that uses SQLite as a store, instead of an in-memory store like a map. +*/ + +package informer + +import ( + "context" + "time" + + "github.com/rancher/lasso/pkg/cache/sql/partition" + sqlStore "github.com/rancher/lasso/pkg/cache/sql/store" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/tools/cache" +) + +// Informer is a SQLite-backed cache.SharedIndexInformer that can execute queries on listprocessor structs +type Informer struct { + cache.SharedIndexInformer + ByOptionsLister +} + +type ByOptionsLister interface { + ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) +} + +// this is set to a var so that it can be overridden by test code for mocking purposes +var newInformer = cache.NewSharedIndexInformer + +// NewInformer returns a new SQLite-backed Informer for the type specified by schema in unstructured.Unstructured form +// using the specified client +func NewInformer(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*Informer, error) { + listWatcher := &cache.ListWatch{ + ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { + a, err := client.List(context.Background(), options) + return a, err + }, + WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { + return client.Watch(context.Background(), options) + }, + } + + example := &unstructured.Unstructured{} + example.SetGroupVersionKind(gvk) + + // avoids the informer to periodically resync (re-list) its resources + // currently it is a work hypothesis that, when interacting with the UI, this should not be needed + resyncPeriod := time.Duration(0) + + sii := newInformer(listWatcher, example, resyncPeriod, cache.Indexers{}) + if transform != nil { + if err := sii.SetTransform(transform); err != nil { + return nil, err + } + } + + name := informerNameFromGVK(gvk) + + s, err := sqlStore.NewStore(example, cache.DeletionHandlingMetaNamespaceKeyFunc, db, shouldEncrypt, name) + if err != nil { + return nil, err + } + loi, err := NewListOptionIndexer(fields, s, namespaced) + if err != nil { + return nil, err + } + + // HACK: replace the default informer's indexer with the SQL based one + UnsafeSet(sii, "indexer", loi) + + return &Informer{ + SharedIndexInformer: sii, + ByOptionsLister: loi, + }, nil +} + +// ListByOptions returns objects according to the specified list options and partitions. +// Specifically: +// - an unstructured list of resources belonging to any of the specified partitions +// - the total number of resources (returned list might be a subset depending on pagination options in lo) +// - a continue token, if there are more pages after the returned one +// - an error instead of all of the above if anything went wrong +func (i *Informer) ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) { + return i.ByOptionsLister.ListByOptions(ctx, lo, partitions, namespace) +} + +func informerNameFromGVK(gvk schema.GroupVersionKind) string { + return gvk.Group + "_" + gvk.Version + "_" + gvk.Kind +} diff --git a/pkg/sqlcache/informer/informer_mocks_test.go b/pkg/sqlcache/informer/informer_mocks_test.go new file mode 100644 index 00000000..ae6bf4b6 --- /dev/null +++ b/pkg/sqlcache/informer/informer_mocks_test.go @@ -0,0 +1,59 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/informer (interfaces: ByOptionsLister) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer ByOptionsLister +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + reflect "reflect" + + partition "github.com/rancher/lasso/pkg/cache/sql/partition" + gomock "go.uber.org/mock/gomock" + unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" +) + +// MockByOptionsLister is a mock of ByOptionsLister interface. +type MockByOptionsLister struct { + ctrl *gomock.Controller + recorder *MockByOptionsListerMockRecorder +} + +// MockByOptionsListerMockRecorder is the mock recorder for MockByOptionsLister. +type MockByOptionsListerMockRecorder struct { + mock *MockByOptionsLister +} + +// NewMockByOptionsLister creates a new mock instance. +func NewMockByOptionsLister(ctrl *gomock.Controller) *MockByOptionsLister { + mock := &MockByOptionsLister{ctrl: ctrl} + mock.recorder = &MockByOptionsListerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockByOptionsLister) EXPECT() *MockByOptionsListerMockRecorder { + return m.recorder +} + +// ListByOptions mocks base method. +func (m *MockByOptionsLister) ListByOptions(arg0 context.Context, arg1 ListOptions, arg2 []partition.Partition, arg3 string) (*unstructured.UnstructuredList, int, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListByOptions", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*unstructured.UnstructuredList) + ret1, _ := ret[1].(int) + ret2, _ := ret[2].(string) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// ListByOptions indicates an expected call of ListByOptions. +func (mr *MockByOptionsListerMockRecorder) ListByOptions(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListByOptions", reflect.TypeOf((*MockByOptionsLister)(nil).ListByOptions), arg0, arg1, arg2, arg3) +} diff --git a/pkg/sqlcache/informer/informer_test.go b/pkg/sqlcache/informer/informer_test.go new file mode 100644 index 00000000..5199bdc8 --- /dev/null +++ b/pkg/sqlcache/informer/informer_test.go @@ -0,0 +1,351 @@ +package informer + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/tools/cache" +) + +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer ByOptionsLister +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/store DBClient + +func TestNewInformer(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "NewInformer() with no errors returned, should return no error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() + + // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + // NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(context.Background(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + informer, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) + assert.Nil(t, err) + assert.NotNil(t, informer.ByOptionsLister) + assert.NotNil(t, informer.SharedIndexInformer) + }}) + tests = append(tests, testCase{description: "NewInformer() with errors returned from NewStore(), should return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewInformer() with errors returned from NewIndexer(), should return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() + + // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewInformer() with errors returned from NewListOptionIndexer(), should return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() + + // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + // NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewInformer() with transform func", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + mockInformer := mockInformer{} + testNewInformer := func(lw cache.ListerWatcher, + exampleObject runtime.Object, + defaultEventHandlerResyncPeriod time.Duration, + indexers cache.Indexers) cache.SharedIndexInformer { + return &mockInformer + } + newInformer = testNewInformer + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() + + // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + // NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + transformFunc := func(input interface{}) (interface{}, error) { + return "someoutput", nil + } + informer, err := NewInformer(dynamicClient, fields, transformFunc, gvk, dbClient, false, true) + assert.Nil(t, err) + assert.NotNil(t, informer.ByOptionsLister) + assert.NotNil(t, informer.SharedIndexInformer) + assert.NotNil(t, mockInformer.transformFunc) + + // we can't test func == func, so instead we check if the output was as expected + input := "someinput" + ouput, err := mockInformer.transformFunc(input) + assert.Nil(t, err) + outputStr, ok := ouput.(string) + assert.True(t, ok, "ouput from transform was expected to be a string") + assert.Equal(t, "someoutput", outputStr) + + newInformer = cache.NewSharedIndexInformer + }}) + tests = append(tests, testCase{description: "NewInformer() unable to set transform func", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + mockInformer := mockInformer{ + setTranformErr: fmt.Errorf("some error"), + } + testNewInformer := func(lw cache.ListerWatcher, + exampleObject runtime.Object, + defaultEventHandlerResyncPeriod time.Duration, + indexers cache.Indexers) cache.SharedIndexInformer { + return &mockInformer + } + newInformer = testNewInformer + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + transformFunc := func(input interface{}) (interface{}, error) { + return "someoutput", nil + } + _, err := NewInformer(dynamicClient, fields, transformFunc, gvk, dbClient, false, true) + assert.Error(t, err) + newInformer = cache.NewSharedIndexInformer + }}) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestInformerListByOptions(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "ListByOptions() with no errors returned, should return no error and return value from indexer's ListByOptions()", test: func(t *testing.T) { + indexer := NewMockByOptionsLister(gomock.NewController(t)) + informer := &Informer{ + ByOptionsLister: indexer, + } + lo := ListOptions{} + var partitions []partition.Partition + ns := "somens" + expectedList := &unstructured.UnstructuredList{ + Object: map[string]interface{}{"s": 2}, + Items: []unstructured.Unstructured{{ + Object: map[string]interface{}{"s": 2}, + }}, + } + expectedTotal := len(expectedList.Items) + expectedContinueToken := "123" + indexer.EXPECT().ListByOptions(context.TODO(), lo, partitions, ns).Return(expectedList, expectedTotal, expectedContinueToken, nil) + list, total, continueToken, err := informer.ListByOptions(context.TODO(), lo, partitions, ns) + assert.Nil(t, err) + assert.Equal(t, expectedList, list) + assert.Equal(t, len(expectedList.Items), total) + assert.Equal(t, expectedContinueToken, continueToken) + }}) + tests = append(tests, testCase{description: "ListByOptions() with indexer ListByOptions error, should return error", test: func(t *testing.T) { + indexer := NewMockByOptionsLister(gomock.NewController(t)) + informer := &Informer{ + ByOptionsLister: indexer, + } + lo := ListOptions{} + var partitions []partition.Partition + ns := "somens" + indexer.EXPECT().ListByOptions(context.TODO(), lo, partitions, ns).Return(nil, 0, "", fmt.Errorf("error")) + _, _, _, err := informer.ListByOptions(context.TODO(), lo, partitions, ns) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +// Note: SQLite based caching uses an Informer that unsafely sets the Indexer as the ability to set it is not present +// in client-go at the moment. Long term, we look forward contribute a patch to client-go to make that configurable. +// Until then, we are adding this canary test that will panic in case the indexer cannot be set. +func TestUnsafeSet(t *testing.T) { + listWatcher := &cache.ListWatch{ + ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { + return &unstructured.UnstructuredList{}, nil + }, + WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { + return dummyWatch{}, nil + }, + } + + sii := cache.NewSharedIndexInformer(listWatcher, &unstructured.Unstructured{}, 0, cache.Indexers{}) + + // will panic if SharedIndexInformer stops having a *Indexer field called "indexer" + UnsafeSet(sii, "indexer", &Indexer{}) +} + +type dummyWatch struct{} + +func (dummyWatch) Stop() { +} + +func (dummyWatch) ResultChan() <-chan watch.Event { + result := make(chan watch.Event) + defer close(result) + return result +} + +// mockInformer is a mock of cache.SharedIndexInformer. Unlike other types, we can't generate this using mockgen because we use a unsafeSet to replace the +// indexer field, which is a struct field. This won't exist on the mock, producing an error. So we need to implement our own mock which actually has this field. +type mockInformer struct { + transformFunc cache.TransformFunc + setTranformErr error + indexer cache.Indexer +} + +func (m *mockInformer) AddEventHandler(handler cache.ResourceEventHandler) (cache.ResourceEventHandlerRegistration, error) { + return nil, nil +} +func (m *mockInformer) AddEventHandlerWithResyncPeriod(handler cache.ResourceEventHandler, resyncPeriod time.Duration) (cache.ResourceEventHandlerRegistration, error) { + return nil, nil +} +func (m *mockInformer) RemoveEventHandler(handle cache.ResourceEventHandlerRegistration) error { + return nil +} +func (m *mockInformer) GetStore() cache.Store { return nil } +func (m *mockInformer) GetController() cache.Controller { return nil } +func (m *mockInformer) Run(stopCh <-chan struct{}) {} +func (m *mockInformer) HasSynced() bool { return false } +func (m *mockInformer) LastSyncResourceVersion() string { return "" } +func (m *mockInformer) SetWatchErrorHandler(handler cache.WatchErrorHandler) error { return nil } +func (m *mockInformer) IsStopped() bool { return false } +func (m *mockInformer) AddIndexers(indexers cache.Indexers) error { return nil } +func (m *mockInformer) GetIndexer() cache.Indexer { return nil } +func (m *mockInformer) SetTransform(handler cache.TransformFunc) error { + m.transformFunc = handler + return m.setTranformErr +} diff --git a/pkg/sqlcache/informer/listoption_indexer.go b/pkg/sqlcache/informer/listoption_indexer.go new file mode 100644 index 00000000..3290cf70 --- /dev/null +++ b/pkg/sqlcache/informer/listoption_indexer.go @@ -0,0 +1,853 @@ +package informer + +import ( + "context" + "database/sql" + "encoding/gob" + "errors" + "fmt" + "regexp" + "sort" + "strconv" + "strings" + + "github.com/sirupsen/logrus" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/client-go/tools/cache" + + "github.com/rancher/lasso/pkg/cache/sql/db" + "github.com/rancher/lasso/pkg/cache/sql/partition" +) + +// ListOptionIndexer extends Indexer by allowing queries based on ListOption +type ListOptionIndexer struct { + *Indexer + + namespaced bool + indexedFields []string + + addFieldQuery string + deleteFieldQuery string + upsertLabelsQuery string + deleteLabelsQuery string + + addFieldStmt *sql.Stmt + deleteFieldStmt *sql.Stmt + upsertLabelsStmt *sql.Stmt + deleteLabelsStmt *sql.Stmt +} + +var ( + defaultIndexedFields = []string{"metadata.name", "metadata.creationTimestamp"} + defaultIndexNamespaced = "metadata.namespace" + subfieldRegex = regexp.MustCompile(`([a-zA-Z]+)|(\[[a-zA-Z./]+])|(\[[0-9]+])`) + + InvalidColumnErr = errors.New("supplied column is invalid") +) + +const ( + matchFmt = `%%%s%%` + strictMatchFmt = `%s` + escapeBackslashDirective = ` ESCAPE '\'` // The leading space is crucial for unit tests only ' + createFieldsTableFmt = `CREATE TABLE "%s_fields" ( + key TEXT NOT NULL PRIMARY KEY, + %s + )` + createFieldsIndexFmt = `CREATE INDEX "%s_%s_index" ON "%s_fields"("%s")` + + failedToGetFromSliceFmt = "[listoption indexer] failed to get subfield [%s] from slice items: %w" + + createLabelsTableFmt = `CREATE TABLE IF NOT EXISTS "%s_labels" ( + key TEXT NOT NULL REFERENCES "%s"(key) ON DELETE CASCADE, + label TEXT NOT NULL, + value TEXT NOT NULL, + PRIMARY KEY (key, label) + )` + createLabelsTableIndexFmt = `CREATE INDEX IF NOT EXISTS "%s_labels_index" ON "%s_labels"(label, value)` + + upsertLabelsStmtFmt = `REPLACE INTO "%s_labels"(key, label, value) VALUES (?, ?, ?)` + deleteLabelsStmtFmt = `DELETE FROM "%s_labels" WHERE KEY = ?` +) + +// NewListOptionIndexer returns a SQLite-backed cache.Indexer of unstructured.Unstructured Kubernetes resources of a certain GVK +// ListOptionIndexer is also able to satisfy ListOption queries on indexed (sub)fields. +// Fields are specified as slices (e.g. "metadata.resourceVersion" is ["metadata", "resourceVersion"]) +func NewListOptionIndexer(fields [][]string, s Store, namespaced bool) (*ListOptionIndexer, error) { + // necessary in order to gob/ungob unstructured.Unstructured objects + gob.Register(map[string]interface{}{}) + gob.Register([]interface{}{}) + + i, err := NewIndexer(cache.Indexers{}, s) + if err != nil { + return nil, err + } + + var indexedFields []string + for _, f := range defaultIndexedFields { + indexedFields = append(indexedFields, f) + } + if namespaced { + indexedFields = append(indexedFields, defaultIndexNamespaced) + } + for _, f := range fields { + indexedFields = append(indexedFields, toColumnName(f)) + } + + l := &ListOptionIndexer{ + Indexer: i, + namespaced: namespaced, + indexedFields: indexedFields, + } + l.RegisterAfterUpsert(l.addIndexFields) + l.RegisterAfterUpsert(l.addLabels) + l.RegisterAfterDelete(l.deleteIndexFields) + l.RegisterAfterDelete(l.deleteLabels) + columnDefs := make([]string, len(indexedFields)) + for index, field := range indexedFields { + column := fmt.Sprintf(`"%s" TEXT`, field) + columnDefs[index] = column + } + + tx, err := l.BeginTx(context.Background(), true) + if err != nil { + return nil, err + } + dbName := db.Sanitize(i.GetName()) + err = tx.Exec(fmt.Sprintf(createFieldsTableFmt, dbName, strings.Join(columnDefs, ", "))) + if err != nil { + return nil, err + } + + columns := make([]string, len(indexedFields)) + qmarks := make([]string, len(indexedFields)) + setStatements := make([]string, len(indexedFields)) + + for index, field := range indexedFields { + // create index for field + err = tx.Exec(fmt.Sprintf(createFieldsIndexFmt, dbName, field, dbName, field)) + if err != nil { + return nil, err + } + + // format field into column for prepared statement + column := fmt.Sprintf(`"%s"`, field) + columns[index] = column + + // add placeholder for column's value in prepared statement + qmarks[index] = "?" + + // add formatted set statement for prepared statement + setStatement := fmt.Sprintf(`"%s" = excluded."%s"`, field, field) + setStatements[index] = setStatement + } + createLabelsTableQuery := fmt.Sprintf(createLabelsTableFmt, dbName, dbName) + err = tx.Exec(createLabelsTableQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createLabelsTableQuery, Err: err} + } + + createLabelsTableIndexQuery := fmt.Sprintf(createLabelsTableIndexFmt, dbName, dbName) + err = tx.Exec(createLabelsTableIndexQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createLabelsTableIndexQuery, Err: err} + } + + err = tx.Commit() + if err != nil { + return nil, err + } + + l.addFieldQuery = fmt.Sprintf( + `INSERT INTO "%s_fields"(key, %s) VALUES (?, %s) ON CONFLICT DO UPDATE SET %s`, + dbName, + strings.Join(columns, ", "), + strings.Join(qmarks, ", "), + strings.Join(setStatements, ", "), + ) + l.deleteFieldQuery = fmt.Sprintf(`DELETE FROM "%s_fields" WHERE key = ?`, dbName) + + l.addFieldStmt = l.Prepare(l.addFieldQuery) + l.deleteFieldStmt = l.Prepare(l.deleteFieldQuery) + + l.upsertLabelsQuery = fmt.Sprintf(upsertLabelsStmtFmt, dbName) + l.deleteLabelsQuery = fmt.Sprintf(deleteLabelsStmtFmt, dbName) + l.upsertLabelsStmt = l.Prepare(l.upsertLabelsQuery) + l.deleteLabelsStmt = l.Prepare(l.deleteLabelsQuery) + + return l, nil +} + +/* Core methods */ + +// addIndexFields saves sortable/filterable fields into tables +func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx db.TXClient) error { + args := []any{key} + for _, field := range l.indexedFields { + value, err := getField(obj, field) + if err != nil { + logrus.Errorf("cannot index object of type [%s] with key [%s] for indexer [%s]: %v", l.GetType().String(), key, l.GetName(), err) + cErr := tx.Cancel() + if cErr != nil { + return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err) + } + return err + } + switch typedValue := value.(type) { + case nil: + args = append(args, "") + case int, bool, string: + args = append(args, fmt.Sprint(typedValue)) + case []string: + args = append(args, strings.Join(typedValue, "|")) + default: + err2 := fmt.Errorf("field %v has a non-supported type value: %v", field, value) + cErr := tx.Cancel() + if cErr != nil { + return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err2) + } + return err2 + } + } + + err := tx.StmtExec(tx.Stmt(l.addFieldStmt), args...) + if err != nil { + return &db.QueryError{QueryString: l.addFieldQuery, Err: err} + } + return nil +} + +// labels are stored in tables that shadow the underlying object table for each GVK +func (l *ListOptionIndexer) addLabels(key string, obj any, tx db.TXClient) error { + k8sObj, ok := obj.(*unstructured.Unstructured) + if !ok { + return fmt.Errorf("addLabels: unexpected object type, expected unstructured.Unstructured: %v", obj) + } + incomingLabels := k8sObj.GetLabels() + for k, v := range incomingLabels { + err := tx.StmtExec(tx.Stmt(l.upsertLabelsStmt), key, k, v) + if err != nil { + return &db.QueryError{QueryString: l.upsertLabelsQuery, Err: err} + } + } + return nil +} + +func (l *ListOptionIndexer) deleteIndexFields(key string, tx db.TXClient) error { + args := []any{key} + + err := tx.StmtExec(tx.Stmt(l.deleteFieldStmt), args...) + if err != nil { + return &db.QueryError{QueryString: l.deleteFieldQuery, Err: err} + } + return nil +} + +func (l *ListOptionIndexer) deleteLabels(key string, tx db.TXClient) error { + err := tx.StmtExec(tx.Stmt(l.deleteLabelsStmt), key) + if err != nil { + return &db.QueryError{QueryString: l.deleteLabelsQuery, Err: err} + } + return nil +} + +// ListByOptions returns objects according to the specified list options and partitions. +// Specifically: +// - an unstructured list of resources belonging to any of the specified partitions +// - the total number of resources (returned list might be a subset depending on pagination options in lo) +// - a continue token, if there are more pages after the returned one +// - an error instead of all of the above if anything went wrong +func (l *ListOptionIndexer) ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) { + queryInfo, err := l.constructQuery(lo, partitions, namespace, db.Sanitize(l.GetName())) + if err != nil { + return nil, 0, "", err + } + return l.executeQuery(ctx, queryInfo) +} + +// QueryInfo is a helper-struct that is used to represent the core query and parameters when converting +// a filter from the UI into a sql query +type QueryInfo struct { + query string + params []any + countQuery string + countParams []any + limit int + offset int +} + +func (l *ListOptionIndexer) constructQuery(lo ListOptions, partitions []partition.Partition, namespace string, dbName string) (*QueryInfo, error) { + queryInfo := &QueryInfo{} + queryHasLabelFilter := hasLabelFilter(lo.Filters) + + // First, what kind of filtering will we be doing? + // 1- Intro: SELECT and JOIN clauses + // There's a 1:1 correspondence between a base table and its _Fields table + // but it's possible that a key has no associated labels, so if we're doing a + // non-existence test on labels we need to do a LEFT OUTER JOIN + distinctModifier := "" + if queryHasLabelFilter { + distinctModifier = " DISTINCT" + } + query := fmt.Sprintf(`SELECT%s o.object, o.objectnonce, o.dekid FROM "%s" o`, distinctModifier, dbName) + query += "\n " + query += fmt.Sprintf(`JOIN "%s_fields" f ON o.key = f.key`, dbName) + if queryHasLabelFilter { + for i, orFilters := range lo.Filters { + if hasLabelFilter([]OrFilter{orFilters}) { + query += "\n " + // Make the lt index 1-based for readability + query += fmt.Sprintf(`LEFT OUTER JOIN "%s_labels" lt%d ON o.key = lt%d.key`, dbName, i+1, i+1) + } + } + } + params := []any{} + + // 2- Filtering: WHERE clauses (from lo.Filters) + whereClauses := []string{} + for i, orFilters := range lo.Filters { + orClause, orParams, err := l.buildORClauseFromFilters(i+1, orFilters, dbName) + if err != nil { + return queryInfo, err + } + if orClause == "" { + continue + } + whereClauses = append(whereClauses, orClause) + params = append(params, orParams...) + } + + // WHERE clauses (from namespace) + if namespace != "" && namespace != "*" { + whereClauses = append(whereClauses, fmt.Sprintf(`f."metadata.namespace" = ?`)) + params = append(params, namespace) + } + + // WHERE clauses (from partitions and their corresponding parameters) + partitionClauses := []string{} + for _, thisPartition := range partitions { + if thisPartition.Passthrough { + // nothing to do, no extra filtering to apply by definition + } else { + singlePartitionClauses := []string{} + + // filter by namespace + if thisPartition.Namespace != "" && thisPartition.Namespace != "*" { + singlePartitionClauses = append(singlePartitionClauses, fmt.Sprintf(`f."metadata.namespace" = ?`)) + params = append(params, thisPartition.Namespace) + } + + // optionally filter by names + if !thisPartition.All { + names := thisPartition.Names + + if len(names) == 0 { + // degenerate case, there will be no results + singlePartitionClauses = append(singlePartitionClauses, "FALSE") + } else { + singlePartitionClauses = append(singlePartitionClauses, fmt.Sprintf(`f."metadata.name" IN (?%s)`, strings.Repeat(", ?", len(thisPartition.Names)-1))) + // sort for reproducibility + sortedNames := thisPartition.Names.UnsortedList() + sort.Strings(sortedNames) + for _, name := range sortedNames { + params = append(params, name) + } + } + } + + if len(singlePartitionClauses) > 0 { + partitionClauses = append(partitionClauses, strings.Join(singlePartitionClauses, " AND ")) + } + } + } + if len(partitions) == 0 { + // degenerate case, there will be no results + whereClauses = append(whereClauses, "FALSE") + } + if len(partitionClauses) == 1 { + whereClauses = append(whereClauses, partitionClauses[0]) + } + if len(partitionClauses) > 1 { + whereClauses = append(whereClauses, "(\n ("+strings.Join(partitionClauses, ") OR\n (")+")\n)") + } + + if len(whereClauses) > 0 { + query += "\n WHERE\n " + for index, clause := range whereClauses { + query += fmt.Sprintf("(%s)", clause) + if index == len(whereClauses)-1 { + break + } + query += " AND\n " + } + } + + // before proceeding, save a copy of the query and params without LIMIT/OFFSET/ORDER info + // for COUNTing all results later + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM (%s)", query) + countParams := params[:] + + // 3- Sorting: ORDER BY clauses (from lo.Sort) + if len(lo.Sort.Fields) != len(lo.Sort.Orders) { + return nil, fmt.Errorf("sort fields length %d != sort orders length %d", len(lo.Sort.Fields), len(lo.Sort.Orders)) + } + if len(lo.Sort.Fields) > 0 { + orderByClauses := []string{} + for i, field := range lo.Sort.Fields { + columnName := toColumnName(field) + if err := l.validateColumn(columnName); err != nil { + return queryInfo, err + } + + direction := "ASC" + if lo.Sort.Orders[i] == DESC { + direction = "DESC" + } + orderByClauses = append(orderByClauses, fmt.Sprintf(`f."%s" %s`, columnName, direction)) + } + query += "\n ORDER BY " + query += strings.Join(orderByClauses, ", ") + } else { + // make sure one default order is always picked + if l.namespaced { + query += "\n ORDER BY f.\"metadata.namespace\" ASC, f.\"metadata.name\" ASC " + } else { + query += "\n ORDER BY f.\"metadata.name\" ASC " + } + } + + // 4- Pagination: LIMIT clause (from lo.Pagination and/or lo.ChunkSize/lo.Resume) + + limitClause := "" + // take the smallest limit between lo.Pagination and lo.ChunkSize + limit := lo.Pagination.PageSize + if limit == 0 || (lo.ChunkSize > 0 && lo.ChunkSize < limit) { + limit = lo.ChunkSize + } + if limit > 0 { + limitClause = "\n LIMIT ?" + params = append(params, limit) + } + + // OFFSET clause (from lo.Pagination and/or lo.Resume) + offsetClause := "" + offset := 0 + if lo.Resume != "" { + offsetInt, err := strconv.Atoi(lo.Resume) + if err != nil { + return queryInfo, err + } + offset = offsetInt + } + if lo.Pagination.Page >= 1 { + offset += lo.Pagination.PageSize * (lo.Pagination.Page - 1) + } + if offset > 0 { + offsetClause = "\n OFFSET ?" + params = append(params, offset) + } + if limit > 0 || offset > 0 { + query += limitClause + query += offsetClause + queryInfo.countQuery = countQuery + queryInfo.countParams = countParams + queryInfo.limit = limit + queryInfo.offset = offset + } + // Otherwise leave these as default values and the executor won't do pagination work + + logrus.Debugf("ListOptionIndexer prepared statement: %v", query) + logrus.Debugf("Params: %v", params) + queryInfo.query = query + queryInfo.params = params + + return queryInfo, nil +} + +func (l *ListOptionIndexer) executeQuery(ctx context.Context, queryInfo *QueryInfo) (*unstructured.UnstructuredList, int, string, error) { + stmt := l.Prepare(queryInfo.query) + defer l.CloseStmt(stmt) + + tx, err := l.BeginTx(ctx, false) + if err != nil { + return nil, 0, "", err + } + + txStmt := tx.Stmt(stmt) + rows, err := txStmt.QueryContext(ctx, queryInfo.params...) + if err != nil { + if cerr := tx.Cancel(); cerr != nil { + return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) + } + return nil, 0, "", &db.QueryError{QueryString: queryInfo.query, Err: err} + } + items, err := l.ReadObjects(rows, l.GetType(), l.GetShouldEncrypt()) + if err != nil { + if cerr := tx.Cancel(); cerr != nil { + return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) + } + return nil, 0, "", err + } + + total := len(items) + if queryInfo.countQuery != "" { + countStmt := l.Prepare(queryInfo.countQuery) + defer l.CloseStmt(countStmt) + txStmt := tx.Stmt(countStmt) + rows, err := txStmt.QueryContext(ctx, queryInfo.countParams...) + if err != nil { + if cerr := tx.Cancel(); cerr != nil { + return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) + } + return nil, 0, "", &db.QueryError{QueryString: queryInfo.countQuery, Err: err} + } + total, err = l.ReadInt(rows) + if err != nil { + if cerr := tx.Cancel(); cerr != nil { + return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) + } + return nil, 0, "", fmt.Errorf("error reading query results: %w", err) + } + } + if err := tx.Commit(); err != nil { + return nil, 0, "", err + } + + continueToken := "" + limit := queryInfo.limit + offset := queryInfo.offset + if limit > 0 && offset+len(items) < total { + continueToken = fmt.Sprintf("%d", offset+limit) + } + + return toUnstructuredList(items), total, continueToken, nil +} + +func (l *ListOptionIndexer) validateColumn(column string) error { + for _, v := range l.indexedFields { + if v == column { + return nil + } + } + return fmt.Errorf("column is invalid [%s]: %w", column, InvalidColumnErr) +} + +// buildORClause creates an SQLite compatible query that ORs conditions built from passed filters +func (l *ListOptionIndexer) buildORClauseFromFilters(index int, orFilters OrFilter, dbName string) (string, []any, error) { + var params []any + clauses := make([]string, 0, len(orFilters.Filters)) + var newParams []any + var newClause string + var err error + + for _, filter := range orFilters.Filters { + if isLabelFilter(&filter) { + newClause, newParams, err = l.getLabelFilter(index, filter, dbName) + } else { + newClause, newParams, err = l.getFieldFilter(filter) + } + if err != nil { + return "", nil, err + } + clauses = append(clauses, newClause) + params = append(params, newParams...) + } + switch len(clauses) { + case 0: + return "", params, nil + case 1: + return clauses[0], params, nil + } + return fmt.Sprintf("(%s)", strings.Join(clauses, ") OR (")), params, nil +} + +// Possible ops from the k8s parser: +// KEY = and == (same) VALUE +// KEY != VALUE +// KEY exists [] # ,KEY, => this filter +// KEY ! [] # ,!KEY, => assert KEY doesn't exist +// KEY in VALUES +// KEY notin VALUES + +func (l *ListOptionIndexer) getFieldFilter(filter Filter) (string, []any, error) { + opString := "" + escapeString := "" + columnName := toColumnName(filter.Field) + if err := l.validateColumn(columnName); err != nil { + return "", nil, err + } + switch filter.Op { + case Eq: + if filter.Partial { + opString = "LIKE" + escapeString = escapeBackslashDirective + } else { + opString = "=" + } + clause := fmt.Sprintf(`f."%s" %s ?%s`, columnName, opString, escapeString) + return clause, []any{formatMatchTarget(filter)}, nil + case NotEq: + if filter.Partial { + opString = "NOT LIKE" + escapeString = escapeBackslashDirective + } else { + opString = "!=" + } + clause := fmt.Sprintf(`f."%s" %s ?%s`, columnName, opString, escapeString) + return clause, []any{formatMatchTarget(filter)}, nil + + case Lt, Gt: + sym, target, err := prepareComparisonParameters(filter.Op, filter.Matches[0]) + if err != nil { + return "", nil, err + } + clause := fmt.Sprintf(`f."%s" %s ?`, columnName, sym) + return clause, []any{target}, nil + + case Exists, NotExists: + return "", nil, errors.New("NULL and NOT NULL tests aren't supported for non-label queries") + + case In: + fallthrough + case NotIn: + target := "()" + if len(filter.Matches) > 0 { + target = fmt.Sprintf("(?%s)", strings.Repeat(", ?", len(filter.Matches)-1)) + } + opString = "IN" + if filter.Op == NotIn { + opString = "NOT IN" + } + clause := fmt.Sprintf(`f."%s" %s %s`, columnName, opString, target) + matches := make([]any, len(filter.Matches)) + for i, match := range filter.Matches { + matches[i] = match + } + return clause, matches, nil + } + + return "", nil, fmt.Errorf("unrecognized operator: %s", opString) +} + +func (l *ListOptionIndexer) getLabelFilter(index int, filter Filter, dbName string) (string, []any, error) { + opString := "" + escapeString := "" + matchFmtToUse := strictMatchFmt + labelName := filter.Field[2] + switch filter.Op { + case Eq: + if filter.Partial { + opString = "LIKE" + escapeString = escapeBackslashDirective + matchFmtToUse = matchFmt + } else { + opString = "=" + } + clause := fmt.Sprintf(`lt%d.label = ? AND lt%d.value %s ?%s`, index, index, opString, escapeString) + return clause, []any{labelName, formatMatchTargetWithFormatter(filter.Matches[0], matchFmtToUse)}, nil + + case NotEq: + if filter.Partial { + opString = "NOT LIKE" + escapeString = escapeBackslashDirective + matchFmtToUse = matchFmt + } else { + opString = "!=" + } + subFilter := Filter{ + Field: filter.Field, + Op: NotExists, + } + existenceClause, subParams, err := l.getLabelFilter(index, subFilter, dbName) + if err != nil { + return "", nil, err + } + clause := fmt.Sprintf(`(%s) OR (lt%d.label = ? AND lt%d.value %s ?%s)`, existenceClause, index, index, opString, escapeString) + params := append(subParams, labelName, formatMatchTargetWithFormatter(filter.Matches[0], matchFmtToUse)) + return clause, params, nil + + case Lt, Gt: + sym, target, err := prepareComparisonParameters(filter.Op, filter.Matches[0]) + if err != nil { + return "", nil, err + } + clause := fmt.Sprintf(`lt%d.label = ? AND lt%d.value %s ?`, index, index, sym) + return clause, []any{labelName, target}, nil + + case Exists: + clause := fmt.Sprintf(`lt%d.label = ?`, index) + return clause, []any{labelName}, nil + + case NotExists: + clause := fmt.Sprintf(`o.key NOT IN (SELECT o1.key FROM "%s" o1 + JOIN "%s_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "%s_labels" lt%di1 ON o1.key = lt%di1.key + WHERE lt%di1.label = ?)`, dbName, dbName, dbName, index, index, index) + return clause, []any{labelName}, nil + + case In: + target := "(?" + if len(filter.Matches) > 0 { + target += strings.Repeat(", ?", len(filter.Matches)-1) + } + target += ")" + clause := fmt.Sprintf(`lt%d.label = ? AND lt%d.value IN %s`, index, index, target) + matches := make([]any, len(filter.Matches)+1) + matches[0] = labelName + for i, match := range filter.Matches { + matches[i+1] = match + } + return clause, matches, nil + + case NotIn: + target := "(?" + if len(filter.Matches) > 0 { + target += strings.Repeat(", ?", len(filter.Matches)-1) + } + target += ")" + subFilter := Filter{ + Field: filter.Field, + Op: NotExists, + } + existenceClause, subParams, err := l.getLabelFilter(index, subFilter, dbName) + if err != nil { + return "", nil, err + } + clause := fmt.Sprintf(`(%s) OR (lt%d.label = ? AND lt%d.value NOT IN %s)`, existenceClause, index, index, target) + matches := append(subParams, labelName) + for _, match := range filter.Matches { + matches = append(matches, match) + } + return clause, matches, nil + } + return "", nil, fmt.Errorf("unrecognized operator: %s", opString) +} + +func prepareComparisonParameters(op Op, target string) (string, float64, error) { + num, err := strconv.ParseFloat(target, 32) + if err != nil { + return "", 0, err + } + switch op { + case Lt: + return "<", num, nil + case Gt: + return ">", num, nil + } + return "", 0, fmt.Errorf("unrecognized operator when expecting '<' or '>': '%s'", op) +} + +func formatMatchTarget(filter Filter) string { + format := strictMatchFmt + if filter.Partial { + format = matchFmt + } + return formatMatchTargetWithFormatter(filter.Matches[0], format) +} + +func formatMatchTargetWithFormatter(match string, format string) string { + // To allow matches on the backslash itself, the character needs to be replaced first. + // Otherwise, it will undo the following replacements. + match = strings.ReplaceAll(match, `\`, `\\`) + match = strings.ReplaceAll(match, `_`, `\_`) + match = strings.ReplaceAll(match, `%`, `\%`) + return fmt.Sprintf(format, match) +} + +// toColumnName returns the column name corresponding to a field expressed as string slice +func toColumnName(s []string) string { + return db.Sanitize(strings.Join(s, ".")) +} + +// getField extracts the value of a field expressed as a string path from an unstructured object +func getField(a any, field string) (any, error) { + subFields := extractSubFields(field) + o, ok := a.(*unstructured.Unstructured) + if !ok { + return nil, fmt.Errorf("unexpected object type, expected unstructured.Unstructured: %v", a) + } + + var obj interface{} + var found bool + var err error + obj = o.Object + for i, subField := range subFields { + switch t := obj.(type) { + case map[string]interface{}: + subField = strings.TrimSuffix(strings.TrimPrefix(subField, "["), "]") + obj, found, err = unstructured.NestedFieldNoCopy(t, subField) + if err != nil { + return nil, err + } + if !found { + // particularly with labels/annotation indexes, it is totally possible that some objects won't have these, + // so either we this is not an error state or it could be an error state with a type that callers can check for + return nil, nil + } + case []interface{}: + if strings.HasPrefix(subField, "[") && strings.HasSuffix(subField, "]") { + key, err := strconv.Atoi(strings.TrimSuffix(strings.TrimPrefix(subField, "["), "]")) + if err != nil { + return nil, fmt.Errorf("[listoption indexer] failed to convert subfield [%s] to int in listoption index: %w", subField, err) + } + if key >= len(t) { + return nil, fmt.Errorf("[listoption indexer] given index is too large for slice of len %d", len(t)) + } + obj = fmt.Sprintf("%v", t[key]) + } else if i == len(subFields)-1 { + result := make([]string, len(t)) + for index, v := range t { + itemVal, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf(failedToGetFromSliceFmt, subField, err) + } + itemStr, ok := itemVal[subField].(string) + if !ok { + return nil, fmt.Errorf(failedToGetFromSliceFmt, subField, err) + } + result[index] = itemStr + } + return result, nil + } + default: + return nil, fmt.Errorf("[listoption indexer] failed to parse subfields: %v", subFields) + } + } + return obj, nil +} + +func extractSubFields(fields string) []string { + subfields := make([]string, 0) + for _, subField := range subfieldRegex.FindAllString(fields, -1) { + subfields = append(subfields, strings.TrimSuffix(subField, ".")) + } + return subfields +} + +func isLabelFilter(f *Filter) bool { + return len(f.Field) >= 2 && f.Field[0] == "metadata" && f.Field[1] == "labels" +} + +func hasLabelFilter(filters []OrFilter) bool { + for _, outerFilter := range filters { + for _, filter := range outerFilter.Filters { + if isLabelFilter(&filter) { + return true + } + } + } + return false +} + +// toUnstructuredList turns a slice of unstructured objects into an unstructured.UnstructuredList +func toUnstructuredList(items []any) *unstructured.UnstructuredList { + objectItems := make([]map[string]any, len(items)) + result := &unstructured.UnstructuredList{ + Items: make([]unstructured.Unstructured, len(items)), + Object: map[string]interface{}{"items": objectItems}, + } + for i, item := range items { + result.Items[i] = *item.(*unstructured.Unstructured) + objectItems[i] = item.(*unstructured.Unstructured).Object + } + return result +} diff --git a/pkg/sqlcache/informer/listoption_indexer_test.go b/pkg/sqlcache/informer/listoption_indexer_test.go new file mode 100644 index 00000000..fe7a038f --- /dev/null +++ b/pkg/sqlcache/informer/listoption_indexer_test.go @@ -0,0 +1,1478 @@ +/* +Copyright 2023 SUSE LLC + +Adapted from client-go, Copyright 2014 The Kubernetes Authors. +*/ + +package informer + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/sets" +) + +func TestNewListOptionIndexer(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + tests = append(tests, testCase{description: "NewListOptionIndexer() with no errors returned, should return no error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + // create field table + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) + // create field table indexes + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableFmt, id, id)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableIndexFmt, id, id)).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + loi, err := NewListOptionIndexer(fields, store, true) + assert.Nil(t, err) + assert.NotNil(t, loi) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error returned from NewIndexer(), should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, false) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error returned from Begin(), should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, false) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error from Exec() when creating fields table, should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error from create-labels, should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableFmt, id, id)).Return(fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error from Commit(), should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableFmt, id, id)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableIndexFmt, id, id)).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, true) + assert.NotNil(t, err) + }}) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestListByOptions(t *testing.T) { + type testCase struct { + description string + listOptions ListOptions + partitions []partition.Partition + ns string + expectedCountStmt string + expectedCountStmtArgs []any + expectedStmt string + expectedStmtArgs []any + expectedList *unstructured.UnstructuredList + returnList []any + expectedContToken string + expectedErr error + } + + testObject := testStoreObject{Id: "something", Val: "a"} + unstrTestObjectMap, err := runtime.DefaultUnstructuredConverter.ToUnstructured(&testObject) + assert.Nil(t, err) + + var tests []testCase + tests = append(tests, testCase{ + description: "ListByOptions() with no errors returned, should not return an error", + listOptions: ListOptions{}, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC `, + returnList: []any{}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions() with an empty filter, should not return an error", + listOptions: ListOptions{ + Filters: []OrFilter{{[]Filter{}}}, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with ChunkSize set should set limit in prepared sql.Stmt", + listOptions: ListOptions{ChunkSize: 2}, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC + LIMIT ?`, + expectedStmtArgs: []interface{}{2}, + expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE))`, + expectedCountStmtArgs: []any{}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Resume set should set offset in prepared sql.Stmt", + listOptions: ListOptions{Resume: "4"}, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC + OFFSET ?`, + expectedStmtArgs: []interface{}{4}, + expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE))`, + expectedCountStmtArgs: []any{}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with 1 OrFilter set with 1 filter should select where that filter is true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"somevalue"}, + Op: Eq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.somefield" LIKE ? ESCAPE '\') AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"%somevalue%"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with 1 OrFilter set with 1 filter with Op set top NotEq should select where that filter is not true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"somevalue"}, + Op: NotEq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.somefield" NOT LIKE ? ESCAPE '\') AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"%somevalue%"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with 1 OrFilter set with 1 filter with Partial set to true should select where that partial match on that filter's value is true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"somevalue"}, + Op: Eq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.somefield" LIKE ? ESCAPE '\') AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"%somevalue%"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with 1 OrFilter set with multiple filters should select where any of those filters are true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"somevalue"}, + Op: Eq, + Partial: true, + }, + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"someothervalue"}, + Op: Eq, + Partial: true, + }, + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"somethirdvalue"}, + Op: NotEq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + ((f."metadata.somefield" LIKE ? ESCAPE '\') OR (f."metadata.somefield" LIKE ? ESCAPE '\') OR (f."metadata.somefield" NOT LIKE ? ESCAPE '\')) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"%somevalue%", "%someothervalue%", "%somethirdvalue%"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with multiple OrFilters set should select where all OrFilters contain one filter that is true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"value1"}, + Op: Eq, + Partial: false, + }, + { + Field: []string{"status", "someotherfield"}, + Matches: []string{"value2"}, + Op: NotEq, + Partial: false, + }, + }, + }, + { + Filters: []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"value3"}, + Op: Eq, + Partial: false, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "test4", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + ((f."metadata.somefield" = ?) OR (f."status.someotherfield" != ?)) AND + (f."metadata.somefield" = ?) AND + (f."metadata.namespace" = ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"value1", "value2", "value3", "test4"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with labels filter should select the label in the prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "guard.cattle.io"}, + Matches: []string{"lodgepole"}, + Op: Eq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "test41", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value LIKE ? ESCAPE '\') AND + (f."metadata.namespace" = ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"guard.cattle.io", "%lodgepole%", "test41"}, + returnList: []any{}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ListByOptions with two labels filters should use a self-join", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "cows"}, + Matches: []string{"milk"}, + Op: Eq, + Partial: false, + }, + }, + }, + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "horses"}, + Matches: []string{"saddles"}, + Op: Eq, + Partial: false, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "test42", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + LEFT OUTER JOIN "something_labels" lt2 ON o.key = lt2.key + WHERE + (lt1.label = ? AND lt1.value = ?) AND + (lt2.label = ? AND lt2.value = ?) AND + (f."metadata.namespace" = ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"cows", "milk", "horses", "saddles", "test42"}, + returnList: []any{}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ListByOptions with a mix of one label and one non-label query can still self-join", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "cows"}, + Matches: []string{"butter"}, + Op: Eq, + Partial: false, + }, + }, + }, + { + Filters: []Filter{ + { + Field: []string{"metadata", "somefield"}, + Matches: []string{"wheat"}, + Op: Eq, + Partial: false, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "test43", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value = ?) AND + (f."metadata.somefield" = ?) AND + (f."metadata.namespace" = ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"cows", "butter", "wheat", "test43"}, + returnList: []any{}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ListByOptions with only one Sort.Field set should sort on that field only, in ascending order in prepared sql.Stmt", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}}, + Orders: []SortOrder{ASC}, + }, + }, + partitions: []partition.Partition{}, + ns: "test5", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.namespace" = ?) AND + (FALSE) + ORDER BY f."metadata.somefield" ASC`, + expectedStmtArgs: []any{"test5"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "sort one field descending", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}}, + Orders: []SortOrder{DESC}, + }, + }, + partitions: []partition.Partition{}, + ns: "test5a", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.namespace" = ?) AND + (FALSE) + ORDER BY f."metadata.somefield" DESC`, + expectedStmtArgs: []any{"test5a"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ListByOptions sorting on two fields should sort on the first field in ascending order first and then sort on the second field in ascending order in prepared sql.Stmt", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}, {"status", "someotherfield"}}, + Orders: []SortOrder{ASC, ASC}, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.somefield" ASC, f."status.someotherfield" ASC`, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ListByOptions sorting on two fields should sort on the first field in descending order first and then sort on the second field in ascending order in prepared sql.Stmt", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}, {"status", "someotherfield"}}, + Orders: []SortOrder{DESC, ASC}, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.somefield" DESC, f."status.someotherfield" ASC`, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ListByOptions sorting when # fields != # sort orders should return an error", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}, {"status", "someotherfield"}}, + Orders: []SortOrder{DESC, ASC, ASC}, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.somefield" DESC, f."status.someotherfield" ASC`, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: fmt.Errorf("sort fields length 2 != sort orders length 3"), + }) + + tests = append(tests, testCase{ + description: "ListByOptions with Pagination.PageSize set should set limit to PageSize in prepared sql.Stmt", + listOptions: ListOptions{ + Pagination: Pagination{ + PageSize: 10, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC + LIMIT ?`, + expectedStmtArgs: []any{10}, + expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE))`, + expectedCountStmtArgs: []any{}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Pagination.Page and no PageSize set should not add anything to prepared sql.Stmt", + listOptions: ListOptions{ + Pagination: Pagination{ + Page: 2, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC `, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Pagination.Page and PageSize set limit to PageSize and offset to PageSize * (Page - 1) in prepared sql.Stmt", + listOptions: ListOptions{ + Pagination: Pagination{ + PageSize: 10, + Page: 2, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC + LIMIT ? + OFFSET ?`, + expectedStmtArgs: []any{10, 10}, + + expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE))`, + expectedCountStmtArgs: []any{}, + + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with a Namespace Partition should select only items where metadata.namespace is equal to Namespace and all other conditions are met in prepared sql.Stmt", + partitions: []partition.Partition{ + { + Namespace: "somens", + }, + }, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.namespace" = ? AND FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"somens"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with a All Partition should select all items that meet all other conditions in prepared sql.Stmt", + partitions: []partition.Partition{ + { + All: true, + }, + }, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + ORDER BY f."metadata.name" ASC `, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with a Passthrough Partition should select all items that meet all other conditions prepared sql.Stmt", + partitions: []partition.Partition{ + { + Passthrough: true, + }, + }, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + ORDER BY f."metadata.name" ASC `, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with a Names Partition should select only items where metadata.name equals an items in Names and all other conditions are met in prepared sql.Stmt", + partitions: []partition.Partition{ + { + Names: sets.New[string]("someid", "someotherid"), + }, + }, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.name" IN (?, ?)) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"someid", "someotherid"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + stmts := NewMockStmt(gomock.NewController(t)) + i := &Indexer{ + Store: store, + } + lii := &ListOptionIndexer{ + Indexer: i, + indexedFields: []string{"metadata.somefield", "status.someotherfield"}, + } + if test.description == "ListByOptions with multiple OrFilters set should select where all OrFilters contain one filter that is true in prepared sql.Stmt" { + fmt.Printf("stop here") + } + queryInfo, err := lii.constructQuery(test.listOptions, test.partitions, test.ns, "something") + if test.expectedErr != nil { + assert.Equal(t, test.expectedErr, err) + return + } + assert.Nil(t, err) + assert.Equal(t, test.expectedStmt, queryInfo.query) + if test.expectedStmtArgs == nil { + test.expectedStmtArgs = []any{} + } + assert.Equal(t, test.expectedStmtArgs, queryInfo.params) + assert.Equal(t, test.expectedCountStmt, queryInfo.countQuery) + assert.Equal(t, test.expectedCountStmtArgs, queryInfo.countParams) + + stmt := &sql.Stmt{} + rows := &sql.Rows{} + objType := reflect.TypeOf(testObject) + store.EXPECT().BeginTx(gomock.Any(), false).Return(txClient, nil) + txClient.EXPECT().Stmt(gomock.Any()).Return(stmts).AnyTimes() + store.EXPECT().Prepare(test.expectedStmt).Do(func(a ...any) { + fmt.Println(a) + }).Return(stmt) + if args := test.expectedStmtArgs; args != nil { + stmts.EXPECT().QueryContext(gomock.Any(), gomock.Any()).Return(rows, nil).AnyTimes() + } else if strings.Contains(test.expectedStmt, "LIMIT") { + stmts.EXPECT().QueryContext(gomock.Any(), args...).Return(rows, nil) + txClient.EXPECT().Stmt(gomock.Any()).Return(stmts) + stmts.EXPECT().QueryContext(gomock.Any()).Return(rows, nil) + } else { + stmts.EXPECT().QueryContext(gomock.Any()).Return(rows, nil) + } + store.EXPECT().GetType().Return(objType) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, objType, false).Return(test.returnList, nil) + store.EXPECT().CloseStmt(stmt).Return(nil) + + if test.expectedCountStmt != "" { + store.EXPECT().Prepare(test.expectedCountStmt).Return(stmt) + //store.EXPECT().QueryForRows(context.TODO(), stmt, test.expectedCountStmtArgs...).Return(rows, nil) + store.EXPECT().ReadInt(rows).Return(len(test.expectedList.Items), nil) + store.EXPECT().CloseStmt(stmt).Return(nil) + } + txClient.EXPECT().Commit() + list, total, contToken, err := lii.executeQuery(context.TODO(), queryInfo) + if test.expectedErr == nil { + assert.Nil(t, err) + } else { + assert.Equal(t, test.expectedErr, err) + } + assert.Equal(t, test.expectedList, list) + assert.Equal(t, len(test.expectedList.Items), total) + assert.Equal(t, test.expectedContToken, contToken) + }) + } +} + +func TestConstructQuery(t *testing.T) { + type testCase struct { + description string + listOptions ListOptions + partitions []partition.Partition + ns string + expectedCountStmt string + expectedCountStmtArgs []any + expectedStmt string + expectedStmtArgs []any + expectedErr error + } + + var tests []testCase + tests = append(tests, testCase{ + description: "TestConstructQuery: handles IN statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "queryField1"}, + Matches: []string{"somevalue"}, + Op: In, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.queryField1" IN (?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"somevalue"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles NOT-IN statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "queryField1"}, + Matches: []string{"somevalue"}, + Op: NotIn, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.queryField1" NOT IN (?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"somevalue"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles EXISTS statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "queryField1"}, + Op: Exists, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedErr: errors.New("NULL and NOT NULL tests aren't supported for non-label queries"), + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles NOT-EXISTS statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "queryField1"}, + Op: NotExists, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedErr: errors.New("NULL and NOT NULL tests aren't supported for non-label queries"), + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles == statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelEqualFull"}, + Matches: []string{"somevalue"}, + Op: Eq, + Partial: false, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value = ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelEqualFull", "somevalue"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles == statements for label statements, match partial", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelEqualPartial"}, + Matches: []string{"somevalue"}, + Op: Eq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value LIKE ? ESCAPE '\') AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelEqualPartial", "%somevalue%"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles != statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelNotEqualFull"}, + Matches: []string{"somevalue"}, + Op: NotEq, + Partial: false, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + ((o.key NOT IN (SELECT o1.key FROM "something" o1 + JOIN "something_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "something_labels" lt1i1 ON o1.key = lt1i1.key + WHERE lt1i1.label = ?)) OR (lt1.label = ? AND lt1.value != ?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelNotEqualFull", "labelNotEqualFull", "somevalue"}, + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "TestConstructQuery: handles != statements for label statements, match partial", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelNotEqualPartial"}, + Matches: []string{"somevalue"}, + Op: NotEq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + ((o.key NOT IN (SELECT o1.key FROM "something" o1 + JOIN "something_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "something_labels" lt1i1 ON o1.key = lt1i1.key + WHERE lt1i1.label = ?)) OR (lt1.label = ? AND lt1.value NOT LIKE ? ESCAPE '\')) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelNotEqualPartial", "labelNotEqualPartial", "%somevalue%"}, + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "TestConstructQuery: handles multiple != statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "notEqual1"}, + Matches: []string{"value1"}, + Op: NotEq, + Partial: false, + }, + }, + }, + { + []Filter{ + { + Field: []string{"metadata", "labels", "notEqual2"}, + Matches: []string{"value2"}, + Op: NotEq, + Partial: false, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + LEFT OUTER JOIN "something_labels" lt2 ON o.key = lt2.key + WHERE + ((o.key NOT IN (SELECT o1.key FROM "something" o1 + JOIN "something_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "something_labels" lt1i1 ON o1.key = lt1i1.key + WHERE lt1i1.label = ?)) OR (lt1.label = ? AND lt1.value != ?)) AND + ((o.key NOT IN (SELECT o1.key FROM "something" o1 + JOIN "something_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "something_labels" lt2i1 ON o1.key = lt2i1.key + WHERE lt2i1.label = ?)) OR (lt2.label = ? AND lt2.value != ?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"notEqual1", "notEqual1", "value1", "notEqual2", "notEqual2", "value2"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles IN statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelIN"}, + Matches: []string{"somevalue1", "someValue2"}, + Op: In, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value IN (?, ?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelIN", "somevalue1", "someValue2"}, + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "TestConstructQuery: handles NOTIN statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelNOTIN"}, + Matches: []string{"somevalue1", "someValue2"}, + Op: NotIn, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + ((o.key NOT IN (SELECT o1.key FROM "something" o1 + JOIN "something_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "something_labels" lt1i1 ON o1.key = lt1i1.key + WHERE lt1i1.label = ?)) OR (lt1.label = ? AND lt1.value NOT IN (?, ?))) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelNOTIN", "labelNOTIN", "somevalue1", "someValue2"}, + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "TestConstructQuery: handles EXISTS statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelEXISTS"}, + Matches: []string{}, + Op: Exists, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelEXISTS"}, + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "TestConstructQuery: handles NOTEXISTS statements for label statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "labelNOTEXISTS"}, + Matches: []string{}, + Op: NotExists, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (o.key NOT IN (SELECT o1.key FROM "something" o1 + JOIN "something_fields" f1 ON o1.key = f1.key + LEFT OUTER JOIN "something_labels" lt1i1 ON o1.key = lt1i1.key + WHERE lt1i1.label = ?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"labelNOTEXISTS"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles LessThan statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "numericThing"}, + Matches: []string{"5"}, + Op: Lt, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value < ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"numericThing", float64(5)}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: handles GreaterThan statements", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "labels", "numericThing"}, + Matches: []string{"35"}, + Op: Gt, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + (lt1.label = ? AND lt1.value > ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"numericThing", float64(35)}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "multiple filters with a positive label test and a negative non-label test still outer-join", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "junta"}, + Matches: []string{"esther"}, + Op: Eq, + Partial: true, + }, + { + Field: []string{"metadata", "queryField1"}, + Matches: []string{"golgi"}, + Op: NotEq, + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + WHERE + ((lt1.label = ? AND lt1.value LIKE ? ESCAPE '\') OR (f."metadata.queryField1" NOT LIKE ? ESCAPE '\')) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"junta", "%esther%", "%golgi%"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "multiple filters and or-filters with a positive label test and a negative non-label test still outer-join and have correct AND/ORs", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "nectar"}, + Matches: []string{"stash"}, + Op: Eq, + Partial: true, + }, + { + Field: []string{"metadata", "queryField1"}, + Matches: []string{"landlady"}, + Op: NotEq, + Partial: false, + }, + }, + }, + { + Filters: []Filter{ + { + Field: []string{"metadata", "labels", "lawn"}, + Matches: []string{"reba", "coil"}, + Op: In, + }, + { + Field: []string{"metadata", "queryField1"}, + Op: Gt, + Matches: []string{"2"}, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT DISTINCT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + LEFT OUTER JOIN "something_labels" lt1 ON o.key = lt1.key + LEFT OUTER JOIN "something_labels" lt2 ON o.key = lt2.key + WHERE + ((lt1.label = ? AND lt1.value LIKE ? ESCAPE '\') OR (f."metadata.queryField1" != ?)) AND + ((lt2.label = ? AND lt2.value IN (?, ?)) OR (f."metadata.queryField1" > ?)) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"nectar", "%stash%", "landlady", "lawn", "reba", "coil", float64(2)}, + expectedErr: nil, + }) + + tests = append(tests, testCase{ + description: "ConstructQuery: sorting when # fields < # sort orders should return an error", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}, {"status", "someotherfield"}}, + Orders: []SortOrder{DESC, ASC, ASC}, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: "", + expectedStmtArgs: []any{}, + expectedErr: fmt.Errorf("sort fields length 2 != sort orders length 3"), + }) + + tests = append(tests, testCase{ + description: "ConstructQuery: sorting when # fields > # sort orders should return an error", + listOptions: ListOptions{ + Sort: Sort{ + Fields: [][]string{{"metadata", "somefield"}, {"status", "someotherfield"}, {"metadata", "labels", "a1"}, {"metadata", "labels", "a2"}}, + Orders: []SortOrder{DESC, ASC, ASC}, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: "", + expectedStmtArgs: []any{}, + expectedErr: fmt.Errorf("sort fields length 4 != sort orders length 3"), + }) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + i := &Indexer{ + Store: store, + } + lii := &ListOptionIndexer{ + Indexer: i, + indexedFields: []string{"metadata.queryField1", "status.queryField2"}, + } + queryInfo, err := lii.constructQuery(test.listOptions, test.partitions, test.ns, "something") + if test.expectedErr != nil { + assert.Equal(t, test.expectedErr, err) + return + } + assert.Nil(t, err) + assert.Equal(t, test.expectedStmt, queryInfo.query) + assert.Equal(t, test.expectedStmtArgs, queryInfo.params) + assert.Equal(t, test.expectedCountStmt, queryInfo.countQuery) + assert.Equal(t, test.expectedCountStmtArgs, queryInfo.countParams) + }) + } +} diff --git a/pkg/sqlcache/informer/listoptions.go b/pkg/sqlcache/informer/listoptions.go new file mode 100644 index 00000000..71bf6f6e --- /dev/null +++ b/pkg/sqlcache/informer/listoptions.go @@ -0,0 +1,68 @@ +package informer + +type Op string + +const ( + Eq Op = "=" + NotEq Op = "!=" + Exists Op = "Exists" + NotExists Op = "NotExists" + In Op = "In" + NotIn Op = "NotIn" + Lt Op = "Lt" + Gt Op = "Gt" +) + +// SortOrder represents whether the list should be ascending or descending. +type SortOrder int + +const ( + // ASC stands for ascending order. + ASC SortOrder = iota + // DESC stands for descending (reverse) order. + DESC +) + +// ListOptions represents the query parameters that may be included in a list request. +type ListOptions struct { + ChunkSize int + Resume string + Filters []OrFilter + Sort Sort + Pagination Pagination +} + +// Filter represents a field to filter by. +// A subfield in an object is represented in a request query using . notation, e.g. 'metadata.name'. +// The subfield is internally represented as a slice, e.g. [metadata, name]. +// Complex subfields need to be expressed with square brackets, as in `metadata.labels[zombo.com/moose]`, +// but are mapped to the string slice ["metadata", "labels", "zombo.com/moose"] +// +// If more than one value is given for the `Match` field, we do an "IN ()" test +type Filter struct { + Field []string + Matches []string + Op Op + Partial bool +} + +// OrFilter represents a set of possible fields to filter by, where an item may match any filter in the set to be included in the result. +type OrFilter struct { + Filters []Filter +} + +// Sort represents the criteria to sort on. +// The subfield to sort by is represented in a request query using . notation, e.g. 'metadata.name'. +// The subfield is internally represented as a slice, e.g. [metadata, name]. +// The order is represented by prefixing the sort key by '-', e.g. sort=-metadata.name. +// e.g. To sort internal clusters first followed by clusters in alpha order: sort=-spec.internal,spec.displayName +type Sort struct { + Fields [][]string + Orders []SortOrder +} + +// Pagination represents how to return paginated results. +type Pagination struct { + PageSize int + Page int +} diff --git a/pkg/sqlcache/informer/shared_informer_hack.go b/pkg/sqlcache/informer/shared_informer_hack.go new file mode 100644 index 00000000..c11889c9 --- /dev/null +++ b/pkg/sqlcache/informer/shared_informer_hack.go @@ -0,0 +1,22 @@ +package informer + +import ( + "reflect" + "unsafe" +) + +// UnsafeSet replaces the passed object's field value with the passed value. +func UnsafeSet(object any, field string, value any) { + rs := reflect.ValueOf(object).Elem() + rf := rs.FieldByName(field) + wrf := reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem() + wrf.Set(reflect.ValueOf(value)) +} + +// UnsafeGet returns the value of the passed object's for the passed field. +func UnsafeGet(object any, field string) any { + rs := reflect.ValueOf(object).Elem() + rf := rs.FieldByName(field) + wrf := reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem() + return wrf.Interface() +} diff --git a/pkg/sqlcache/informer/shared_informer_test.go b/pkg/sqlcache/informer/shared_informer_test.go new file mode 100644 index 00000000..bd143647 --- /dev/null +++ b/pkg/sqlcache/informer/shared_informer_test.go @@ -0,0 +1,325 @@ +/* +Copyright 2023 SUSE LLC + +Adapted from client-go, Copyright 2014 The Kubernetes Authors. +*/ + +package informer + +import ( + "fmt" + "k8s.io/client-go/tools/cache" + "strings" + "sync" + "testing" + "time" + + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/wait" + fcache "k8s.io/client-go/tools/cache/testing" + testingclock "k8s.io/utils/clock/testing" +) + +type testListener struct { + lock sync.RWMutex + resyncPeriod time.Duration + expectedItemNames sets.Set[string] + receivedItemNames []string + name string +} + +func newTestListener(name string, resyncPeriod time.Duration, expected ...string) *testListener { + l := &testListener{ + resyncPeriod: resyncPeriod, + expectedItemNames: sets.New[string](expected...), + name: name, + } + return l +} + +func (l *testListener) OnAdd(obj interface{}, isInInitialList bool) { + l.handle(obj) +} + +func (l *testListener) OnUpdate(old, new interface{}) { + l.handle(new) +} + +func (l *testListener) OnDelete(obj interface{}) { +} + +func (l *testListener) handle(obj interface{}) { + key, _ := cache.MetaNamespaceKeyFunc(obj) + fmt.Printf("%s: handle: %v\n", l.name, key) + l.lock.Lock() + defer l.lock.Unlock() + + objectMeta, _ := meta.Accessor(obj) + l.receivedItemNames = append(l.receivedItemNames, objectMeta.GetName()) +} + +func (l *testListener) ok() bool { + fmt.Println("polling") + err := wait.PollImmediate(100*time.Millisecond, 2*time.Second, func() (bool, error) { + if l.satisfiedExpectations() { + return true, nil + } + return false, nil + }) + if err != nil { + return false + } + + // wait just a bit to allow any unexpected stragglers to come in + fmt.Println("sleeping") + time.Sleep(1 * time.Second) + fmt.Println("final check") + return l.satisfiedExpectations() +} + +func (l *testListener) satisfiedExpectations() bool { + l.lock.RLock() + defer l.lock.RUnlock() + + return sets.New[string](l.receivedItemNames...).Equal(l.expectedItemNames) +} + +func TestListenerResyncPeriods(t *testing.T) { + // source simulates an apiserver object endpoint. + source := fcache.NewFakeControllerSource() + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}) + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2"}}) + + // create the shared informer and resync every 1s + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + + clock := testingclock.NewFakeClock(time.Now()) + UnsafeSet(informer, "clock", clock) + UnsafeSet(UnsafeGet(informer, "processor"), "clock", clock) + + // listener 1, never resync + listener1 := newTestListener("listener1", 0, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listener1, listener1.resyncPeriod) + + // listener 2, resync every 2s + listener2 := newTestListener("listener2", 2*time.Second, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listener2, listener2.resyncPeriod) + + // listener 3, resync every 3s + listener3 := newTestListener("listener3", 3*time.Second, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listener3, listener3.resyncPeriod) + listeners := []*testListener{listener1, listener2, listener3} + + stop := make(chan struct{}) + defer close(stop) + + go informer.Run(stop) + + // ensure all listeners got the initial List + for _, listener := range listeners { + if !listener.ok() { + t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames) + } + } + + // reset + for _, listener := range listeners { + listener.receivedItemNames = []string{} + } + + // advance so listener2 gets a resync + clock.Step(2 * time.Second) + + // make sure listener2 got the resync + if !listener2.ok() { + t.Errorf("%s: expected %v, got %v", listener2.name, listener2.expectedItemNames, listener2.receivedItemNames) + } + + // wait a bit to give errant items a chance to go to 1 and 3 + time.Sleep(1 * time.Second) + + // make sure listeners 1 and 3 got nothing + if len(listener1.receivedItemNames) != 0 { + t.Errorf("listener1: should not have resynced (got %d)", len(listener1.receivedItemNames)) + } + if len(listener3.receivedItemNames) != 0 { + t.Errorf("listener3: should not have resynced (got %d)", len(listener3.receivedItemNames)) + } + + // reset + for _, listener := range listeners { + listener.receivedItemNames = []string{} + } + + // advance so listener3 gets a resync + clock.Step(1 * time.Second) + + // make sure listener3 got the resync + if !listener3.ok() { + t.Errorf("%s: expected %v, got %v", listener3.name, listener3.expectedItemNames, listener3.receivedItemNames) + } + + // wait a bit to give errant items a chance to go to 1 and 2 + time.Sleep(1 * time.Second) + + // make sure listeners 1 and 2 got nothing + if len(listener1.receivedItemNames) != 0 { + t.Errorf("listener1: should not have resynced (got %d)", len(listener1.receivedItemNames)) + } + if len(listener2.receivedItemNames) != 0 { + t.Errorf("listener2: should not have resynced (got %d)", len(listener2.receivedItemNames)) + } +} + +// verify that https://github.com/kubernetes/kubernetes/issues/59822 is fixed +func TestSharedInformerInitializationRace(t *testing.T) { + source := fcache.NewFakeControllerSource() + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + listener := newTestListener("raceListener", 0) + + stop := make(chan struct{}) + go informer.AddEventHandlerWithResyncPeriod(listener, listener.resyncPeriod) + go informer.Run(stop) + close(stop) +} + +// TestSharedInformerWatchDisruption simulates a watch that was closed +// with updates to the store during that time. We ensure that handlers with +// resync and no resync see the expected state. +func TestSharedInformerWatchDisruption(t *testing.T) { + // source simulates an apiserver object endpoint. + source := fcache.NewFakeControllerSource() + + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1", UID: "pod1", ResourceVersion: "1"}}) + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2", UID: "pod2", ResourceVersion: "2"}}) + + // create the shared informer and resync every 1s + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + + clock := testingclock.NewFakeClock(time.Now()) + UnsafeSet(informer, "clock", clock) + UnsafeSet(UnsafeGet(informer, "processor"), "clock", clock) + + // listener, never resync + listenerNoResync := newTestListener("listenerNoResync", 0, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listenerNoResync, listenerNoResync.resyncPeriod) + + listenerResync := newTestListener("listenerResync", 1*time.Second, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listenerResync, listenerResync.resyncPeriod) + listeners := []*testListener{listenerNoResync, listenerResync} + + stop := make(chan struct{}) + defer close(stop) + + go informer.Run(stop) + + for _, listener := range listeners { + if !listener.ok() { + t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames) + } + } + + // Add pod3, bump pod2 but don't broadcast it, so that the change will be seen only on relist + source.AddDropWatch(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod3", UID: "pod3", ResourceVersion: "3"}}) + source.ModifyDropWatch(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2", UID: "pod2", ResourceVersion: "4"}}) + + // Ensure that nobody saw any changes + for _, listener := range listeners { + if !listener.ok() { + t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames) + } + } + + for _, listener := range listeners { + listener.receivedItemNames = []string{} + } + + listenerNoResync.expectedItemNames = sets.New[string]("pod2", "pod3") + listenerResync.expectedItemNames = sets.New[string]("pod1", "pod2", "pod3") + + // This calls shouldSync, which deletes noResync from the list of syncingListeners + clock.Step(1 * time.Second) + + // Simulate a connection loss (or even just a too-old-watch) + source.ResetWatch() + + // Wait long enough for the reflector to exit and the backoff function to start waiting + // on the fake clock, otherwise advancing the fake clock will have no effect. + // TODO: Make this deterministic by counting the number of waiters on FakeClock + time.Sleep(10 * time.Millisecond) + + // Advance the clock to cause the backoff wait to expire. + clock.Step(1601 * time.Millisecond) + + // Wait long enough for backoff to invoke ListWatch a second time and distribute events + // to listeners. + time.Sleep(10 * time.Millisecond) + + for _, listener := range listeners { + if !listener.ok() { + t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames) + } + } +} + +func TestSharedInformerErrorHandling(t *testing.T) { + source := fcache.NewFakeControllerSource() + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}) + source.ListError = fmt.Errorf("Access Denied") + + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + + errCh := make(chan error) + _ = informer.SetWatchErrorHandler(func(_ *cache.Reflector, err error) { + errCh <- err + }) + + stop := make(chan struct{}) + go informer.Run(stop) + + select { + case err := <-errCh: + if !strings.Contains(err.Error(), "Access Denied") { + t.Errorf("Expected 'Access Denied' error. Actual: %v", err) + } + case <-time.After(time.Second): + t.Errorf("Timeout waiting for error handler call") + } + close(stop) +} + +func TestSharedInformerTransformer(t *testing.T) { + // source simulates an apiserver object endpoint. + source := fcache.NewFakeControllerSource() + + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1", UID: "pod1", ResourceVersion: "1"}}) + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2", UID: "pod2", ResourceVersion: "2"}}) + + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + informer.SetTransform(func(obj interface{}) (interface{}, error) { + if pod, ok := obj.(*v1.Pod); ok { + name := pod.GetName() + + if upper := strings.ToUpper(name); upper != name { + copied := pod.DeepCopyObject().(*v1.Pod) + copied.SetName(upper) + return copied, nil + } + } + return obj, nil + }) + + listenerTransformer := newTestListener("listenerTransformer", 0, "POD1", "POD2") + informer.AddEventHandler(listenerTransformer) + + stop := make(chan struct{}) + go informer.Run(stop) + defer close(stop) + + if !listenerTransformer.ok() { + t.Errorf("%s: expected %v, got %v", listenerTransformer.name, listenerTransformer.expectedItemNames, listenerTransformer.receivedItemNames) + } +} diff --git a/pkg/sqlcache/informer/sql_mocks_test.go b/pkg/sqlcache/informer/sql_mocks_test.go new file mode 100644 index 00000000..c172320f --- /dev/null +++ b/pkg/sqlcache/informer/sql_mocks_test.go @@ -0,0 +1,347 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/informer (interfaces: Store) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer Store +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + db "github.com/rancher/lasso/pkg/cache/sql/db" + transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockStore is a mock of Store interface. +type MockStore struct { + ctrl *gomock.Controller + recorder *MockStoreMockRecorder +} + +// MockStoreMockRecorder is the mock recorder for MockStore. +type MockStoreMockRecorder struct { + mock *MockStore +} + +// NewMockStore creates a new mock instance. +func NewMockStore(ctrl *gomock.Controller) *MockStore { + mock := &MockStore{ctrl: ctrl} + mock.recorder = &MockStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStore) EXPECT() *MockStoreMockRecorder { + return m.recorder +} + +// Add mocks base method. +func (m *MockStore) Add(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Add", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Add indicates an expected call of Add. +func (mr *MockStoreMockRecorder) Add(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockStore)(nil).Add), arg0) +} + +// BeginTx mocks base method. +func (m *MockStore) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(db.TXClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockStoreMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockStore)(nil).BeginTx), arg0, arg1) +} + +// CloseStmt mocks base method. +func (m *MockStore) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockStoreMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockStore)(nil).CloseStmt), arg0) +} + +// Delete mocks base method. +func (m *MockStore) Delete(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockStoreMockRecorder) Delete(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete), arg0) +} + +// Get mocks base method. +func (m *MockStore) Get(arg0 any) (any, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Get indicates an expected call of Get. +func (mr *MockStoreMockRecorder) Get(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStore)(nil).Get), arg0) +} + +// GetByKey mocks base method. +func (m *MockStore) GetByKey(arg0 string) (any, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetByKey", arg0) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetByKey indicates an expected call of GetByKey. +func (mr *MockStoreMockRecorder) GetByKey(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByKey", reflect.TypeOf((*MockStore)(nil).GetByKey), arg0) +} + +// GetName mocks base method. +func (m *MockStore) GetName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetName") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetName indicates an expected call of GetName. +func (mr *MockStoreMockRecorder) GetName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetName", reflect.TypeOf((*MockStore)(nil).GetName)) +} + +// GetShouldEncrypt mocks base method. +func (m *MockStore) GetShouldEncrypt() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetShouldEncrypt") + ret0, _ := ret[0].(bool) + return ret0 +} + +// GetShouldEncrypt indicates an expected call of GetShouldEncrypt. +func (mr *MockStoreMockRecorder) GetShouldEncrypt() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetShouldEncrypt", reflect.TypeOf((*MockStore)(nil).GetShouldEncrypt)) +} + +// GetType mocks base method. +func (m *MockStore) GetType() reflect.Type { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetType") + ret0, _ := ret[0].(reflect.Type) + return ret0 +} + +// GetType indicates an expected call of GetType. +func (mr *MockStoreMockRecorder) GetType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetType", reflect.TypeOf((*MockStore)(nil).GetType)) +} + +// List mocks base method. +func (m *MockStore) List() []any { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List") + ret0, _ := ret[0].([]any) + return ret0 +} + +// List indicates an expected call of List. +func (mr *MockStoreMockRecorder) List() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockStore)(nil).List)) +} + +// ListKeys mocks base method. +func (m *MockStore) ListKeys() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListKeys") + ret0, _ := ret[0].([]string) + return ret0 +} + +// ListKeys indicates an expected call of ListKeys. +func (mr *MockStoreMockRecorder) ListKeys() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListKeys", reflect.TypeOf((*MockStore)(nil).ListKeys)) +} + +// Prepare mocks base method. +func (m *MockStore) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockStoreMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockStore)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockStore) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockStoreMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockStore)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockStore) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockStoreMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockStore)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockStore) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockStoreMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockStore)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockStore) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockStoreMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockStore)(nil).ReadStrings), arg0) +} + +// RegisterAfterDelete mocks base method. +func (m *MockStore) RegisterAfterDelete(arg0 func(string, db.TXClient) error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterAfterDelete", arg0) +} + +// RegisterAfterDelete indicates an expected call of RegisterAfterDelete. +func (mr *MockStoreMockRecorder) RegisterAfterDelete(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterDelete", reflect.TypeOf((*MockStore)(nil).RegisterAfterDelete), arg0) +} + +// RegisterAfterUpsert mocks base method. +func (m *MockStore) RegisterAfterUpsert(arg0 func(string, any, db.TXClient) error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterAfterUpsert", arg0) +} + +// RegisterAfterUpsert indicates an expected call of RegisterAfterUpsert. +func (mr *MockStoreMockRecorder) RegisterAfterUpsert(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterUpsert", reflect.TypeOf((*MockStore)(nil).RegisterAfterUpsert), arg0) +} + +// Replace mocks base method. +func (m *MockStore) Replace(arg0 []any, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Replace", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Replace indicates an expected call of Replace. +func (mr *MockStoreMockRecorder) Replace(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Replace", reflect.TypeOf((*MockStore)(nil).Replace), arg0, arg1) +} + +// Resync mocks base method. +func (m *MockStore) Resync() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Resync") + ret0, _ := ret[0].(error) + return ret0 +} + +// Resync indicates an expected call of Resync. +func (mr *MockStoreMockRecorder) Resync() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resync", reflect.TypeOf((*MockStore)(nil).Resync)) +} + +// Update mocks base method. +func (m *MockStore) Update(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Update indicates an expected call of Update. +func (mr *MockStoreMockRecorder) Update(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockStore)(nil).Update), arg0) +} diff --git a/pkg/sqlcache/informer/store_mocks_test.go b/pkg/sqlcache/informer/store_mocks_test.go new file mode 100644 index 00000000..fcb1a90d --- /dev/null +++ b/pkg/sqlcache/informer/store_mocks_test.go @@ -0,0 +1,165 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/store (interfaces: DBClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/store DBClient +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + db "github.com/rancher/lasso/pkg/cache/sql/db" + transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockDBClient is a mock of DBClient interface. +type MockDBClient struct { + ctrl *gomock.Controller + recorder *MockDBClientMockRecorder +} + +// MockDBClientMockRecorder is the mock recorder for MockDBClient. +type MockDBClientMockRecorder struct { + mock *MockDBClient +} + +// NewMockDBClient creates a new mock instance. +func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient { + mock := &MockDBClient{ctrl: ctrl} + mock.recorder = &MockDBClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(db.TXClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1) +} + +// CloseStmt mocks base method. +func (m *MockDBClient) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0) +} + +// Prepare mocks base method. +func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0) +} + +// Upsert mocks base method. +func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) +} diff --git a/pkg/sqlcache/informer/tx_mocks_test.go b/pkg/sqlcache/informer/tx_mocks_test.go new file mode 100644 index 00000000..e482df9b --- /dev/null +++ b/pkg/sqlcache/informer/tx_mocks_test.go @@ -0,0 +1,99 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/db/transaction (interfaces: Stmt) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./pkg/cache/sql/informer/tx_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStmt is a mock of Stmt interface. +type MockStmt struct { + ctrl *gomock.Controller + recorder *MockStmtMockRecorder +} + +// MockStmtMockRecorder is the mock recorder for MockStmt. +type MockStmtMockRecorder struct { + mock *MockStmt +} + +// NewMockStmt creates a new mock instance. +func NewMockStmt(ctrl *gomock.Controller) *MockStmt { + mock := &MockStmt{ctrl: ctrl} + mock.recorder = &MockStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStmt) EXPECT() *MockStmtMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) +} + +// Query mocks base method. +func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) +} + +// QueryContext mocks base method. +func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) +} diff --git a/pkg/sqlcache/integration_test.go b/pkg/sqlcache/integration_test.go new file mode 100644 index 00000000..3c752967 --- /dev/null +++ b/pkg/sqlcache/integration_test.go @@ -0,0 +1,365 @@ +package sql + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/suite" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/cache" + "sigs.k8s.io/controller-runtime/pkg/envtest" + + "github.com/rancher/lasso/pkg/cache/sql/informer" + "github.com/rancher/lasso/pkg/cache/sql/informer/factory" + "github.com/rancher/lasso/pkg/cache/sql/partition" +) + +const testNamespace = "sql-test" + +var defaultPartition = partition.Partition{ + All: true, +} + +type IntegrationSuite struct { + suite.Suite + testEnv envtest.Environment + clientset kubernetes.Clientset + restCfg rest.Config +} + +func (i *IntegrationSuite) SetupSuite() { + i.testEnv = envtest.Environment{} + restCfg, err := i.testEnv.Start() + i.Require().NoError(err, "error when starting env test - this is likely because setup-envtest wasn't done. Check the README for more information") + i.restCfg = *restCfg + clientset, err := kubernetes.NewForConfig(restCfg) + i.Require().NoError(err) + i.clientset = *clientset + testNs := v1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: testNamespace, + }, + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _, err = i.clientset.CoreV1().Namespaces().Create(ctx, &testNs, metav1.CreateOptions{}) + i.Require().NoError(err) +} + +func (i *IntegrationSuite) TearDownSuite() { + err := i.testEnv.Stop() + i.Require().NoError(err) +} + +func (i *IntegrationSuite) TestSQLCacheFilters() { + fields := [][]string{{`metadata`, `annotations[somekey]`}} + require := i.Require() + configMapWithAnnotations := func(name string, annotations map[string]string) v1.ConfigMap { + return v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: testNamespace, + Annotations: annotations, + }, + } + } + createConfigMaps := func(configMaps ...v1.ConfigMap) { + for _, configMap := range configMaps { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + configMapClient := i.clientset.CoreV1().ConfigMaps(testNamespace) + _, err := configMapClient.Create(ctx, &configMap, metav1.CreateOptions{}) + require.NoError(err) + // avoiding defer in a for loop + cancel() + } + } + + // we create some configmaps before the cache starts and some after so that we can test the initial list and + // subsequent watches to make sure both work + // matches the filter for somekey == somevalue + matches := configMapWithAnnotations("matches-filter", map[string]string{"somekey": "somevalue"}) + // partial match for somekey == somevalue (different suffix) + partialMatches := configMapWithAnnotations("partial-matches", map[string]string{"somekey": "somevaluehere"}) + specialCharacterMatch := configMapWithAnnotations("special-character-matches", map[string]string{"somekey": "c%%l_value"}) + backSlashCharacterMatch := configMapWithAnnotations("backslash-character-matches", map[string]string{"somekey": `my\windows\path`}) + createConfigMaps(matches, partialMatches, specialCharacterMatch, backSlashCharacterMatch) + + cache, cacheFactory, err := i.createCacheAndFactory(fields, nil) + require.NoError(err) + defer cacheFactory.Reset() + + // doesn't match the filter for somekey == somevalue + notMatches := configMapWithAnnotations("not-matches-filter", map[string]string{"somekey": "notequal"}) + // has no annotations, shouldn't match any filter + missing := configMapWithAnnotations("missing", nil) + createConfigMaps(notMatches, missing) + + configMapNames := []string{matches.Name, partialMatches.Name, notMatches.Name, missing.Name, specialCharacterMatch.Name, backSlashCharacterMatch.Name} + err = i.waitForCacheReady(configMapNames, testNamespace, cache) + require.NoError(err) + + orFiltersForFilters := func(filters ...informer.Filter) []informer.OrFilter { + return []informer.OrFilter{ + { + Filters: filters, + }, + } + } + tests := []struct { + name string + filters []informer.OrFilter + wantNames []string + }{ + { + name: "matches filter", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.Eq, + Partial: false, + }), + wantNames: []string{"matches-filter"}, + }, + { + name: "partial matches filter", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.Eq, + Partial: true, + }), + wantNames: []string{"matches-filter", "partial-matches"}, + }, + { + name: "no matches for filter with underscore as it is interpreted literally", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalu_"}, + Op: informer.Eq, + Partial: true, + }), + wantNames: nil, + }, + { + name: "no matches for filter with percent sign as it is interpreted literally", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalu%"}, + Op: informer.Eq, + Partial: true, + }), + wantNames: nil, + }, + { + name: "match with special characters", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"c%%l_value"}, + Op: informer.Eq, + Partial: true, + }), + wantNames: []string{"special-character-matches"}, + }, + { + name: "match with literal backslash character", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{`my\windows\path`}, + Op: informer.Eq, + Partial: true, + }), + wantNames: []string{"backslash-character-matches"}, + }, + { + name: "not eq filter", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.NotEq, + Partial: false, + }), + wantNames: []string{"partial-matches", "not-matches-filter", "missing", "special-character-matches", "backslash-character-matches"}, + }, + { + name: "partial not eq filter", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.NotEq, + Partial: true, + }), + wantNames: []string{"not-matches-filter", "missing", "special-character-matches", "backslash-character-matches"}, + }, + { + name: "multiple or filters match", + filters: orFiltersForFilters( + informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.Eq, + Partial: true, + }, + informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"notequal"}, + Op: informer.Eq, + Partial: false, + }, + ), + wantNames: []string{"matches-filter", "partial-matches", "not-matches-filter"}, + }, + { + name: "or filters on different fields", + filters: orFiltersForFilters( + informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.Eq, + Partial: true, + }, + informer.Filter{ + Field: []string{`metadata`, `name`}, + Matches: []string{"missing"}, + Op: informer.Eq, + Partial: false, + }, + ), + wantNames: []string{"matches-filter", "partial-matches", "missing"}, + }, + { + name: "and filters, both must match", + filters: []informer.OrFilter{ + { + Filters: []informer.Filter{ + { + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"somevalue"}, + Op: informer.Eq, + Partial: true, + }, + }, + }, + { + Filters: []informer.Filter{ + { + Field: []string{`metadata`, `name`}, + Matches: []string{"matches-filter"}, + Op: informer.Eq, + Partial: false, + }, + }, + }, + }, + wantNames: []string{"matches-filter"}, + }, + { + name: "no matches", + filters: orFiltersForFilters( + informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Matches: []string{"valueNotRepresented"}, + Op: informer.Eq, + Partial: false, + }, + ), + wantNames: []string{}, + }, + } + + for _, test := range tests { + test := test + i.Run(test.name, func() { + options := informer.ListOptions{ + Filters: test.filters, + } + partitions := []partition.Partition{defaultPartition} + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + cfgMaps, total, continueToken, err := cache.ListByOptions(ctx, options, partitions, testNamespace) + i.Require().NoError(err) + // since there's no additional pages, the continue token should be empty + i.Require().Equal("", continueToken) + i.Require().NotNil(cfgMaps) + // assert instead of require so that we can see the full evaluation of # of resources returned + i.Assert().Equal(len(test.wantNames), total) + i.Assert().Len(cfgMaps.Items, len(test.wantNames)) + requireNames := sets.Set[string]{} + requireNames.Insert(test.wantNames...) + gotNames := sets.Set[string]{} + for _, configMap := range cfgMaps.Items { + gotNames.Insert(configMap.GetName()) + } + i.Require().True(requireNames.Equal(gotNames), "wanted %v, got %v", requireNames, gotNames) + }) + } +} + +func (i *IntegrationSuite) createCacheAndFactory(fields [][]string, transformFunc cache.TransformFunc) (*factory.Cache, *factory.CacheFactory, error) { + cacheFactory, err := factory.NewCacheFactory() + if err != nil { + return nil, nil, fmt.Errorf("unable to make factory: %w", err) + } + dynamicClient, err := dynamic.NewForConfig(&i.restCfg) + if err != nil { + return nil, nil, fmt.Errorf("unable to make dynamicClient: %w", err) + } + configMapGVK := schema.GroupVersionKind{ + Group: "", + Version: "v1", + Kind: "ConfigMap", + } + configMapGVR := schema.GroupVersionResource{ + Group: "", + Version: "v1", + Resource: "configmaps", + } + dynamicResource := dynamicClient.Resource(configMapGVR).Namespace(testNamespace) + cache, err := cacheFactory.CacheFor(fields, transformFunc, dynamicResource, configMapGVK, true, true) + if err != nil { + return nil, nil, fmt.Errorf("unable to make cache: %w", err) + } + return &cache, cacheFactory, nil +} + +func (i *IntegrationSuite) waitForCacheReady(readyResourceNames []string, namespace string, cache *factory.Cache) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + return wait.PollUntilContextCancel(ctx, time.Millisecond*100, true, func(ctx context.Context) (done bool, err error) { + var options informer.ListOptions + partitions := []partition.Partition{defaultPartition} + cacheCtx, cacheCancel := context.WithTimeout(ctx, time.Second*5) + defer cacheCancel() + currentResources, total, _, err := cache.ListByOptions(cacheCtx, options, partitions, namespace) + if err != nil { + // note that we don't return the error since that would stop the polling + return false, nil + } + if total != len(readyResourceNames) { + return false, nil + } + wantNames := sets.Set[string]{} + wantNames.Insert(readyResourceNames...) + gotNames := sets.Set[string]{} + for _, current := range currentResources.Items { + name := current.GetName() + if !wantNames.Has(name) { + return true, fmt.Errorf("got resource %s which wasn't expected", name) + } + gotNames.Insert(name) + } + return wantNames.Equal(gotNames), nil + }) +} + +func TestIntegrationSuite(t *testing.T) { + suite.Run(t, new(IntegrationSuite)) +} diff --git a/pkg/sqlcache/partition/partition.go b/pkg/sqlcache/partition/partition.go new file mode 100644 index 00000000..dd21a60e --- /dev/null +++ b/pkg/sqlcache/partition/partition.go @@ -0,0 +1,24 @@ +/* +Package partition represents listing parameters. They can be used to specify which namespaces a caller would like included +in a response, or which specific objects they are looking for. +*/ +package partition + +import ( + "k8s.io/apimachinery/pkg/util/sets" +) + +// Partition represents filtering of a request's results +type Partition struct { + // if true, do not apply any filtering, return all results. Overrides all other fields + Passthrough bool + + // if non-empty, only resources in the specified namespaces will be returned + Namespace string + + // if true, return all results, while still honoring Namespace. Overrides Names + All bool + + // if non-empty, only resources with matching names will be returned + Names sets.Set[string] +} diff --git a/pkg/sqlcache/store/db_mocks_test.go b/pkg/sqlcache/store/db_mocks_test.go new file mode 100644 index 00000000..489a1c1c --- /dev/null +++ b/pkg/sqlcache/store/db_mocks_test.go @@ -0,0 +1,204 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/db (interfaces: TXClient,Rows) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db TXClient,Rows +// + +// Package store is a generated GoMock package. +package store + +import ( + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockTXClient is a mock of TXClient interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockTXClient) Cancel() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) +} + +// Commit mocks base method. +func (m *MockTXClient) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} + +// StmtExec mocks base method. +func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StmtExec indicates an expected call of StmtExec. +func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) +} + +// MockRows is a mock of Rows interface. +type MockRows struct { + ctrl *gomock.Controller + recorder *MockRowsMockRecorder +} + +// MockRowsMockRecorder is the mock recorder for MockRows. +type MockRowsMockRecorder struct { + mock *MockRows +} + +// NewMockRows creates a new mock instance. +func NewMockRows(ctrl *gomock.Controller) *MockRows { + mock := &MockRows{ctrl: ctrl} + mock.recorder = &MockRowsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRows) EXPECT() *MockRowsMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRows) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRowsMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close)) +} + +// Err mocks base method. +func (m *MockRows) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockRowsMockRecorder) Err() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err)) +} + +// Next mocks base method. +func (m *MockRows) Next() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockRowsMockRecorder) Next() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next)) +} + +// Scan mocks base method. +func (m *MockRows) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) +} diff --git a/pkg/sqlcache/store/store.go b/pkg/sqlcache/store/store.go new file mode 100644 index 00000000..a89d57bf --- /dev/null +++ b/pkg/sqlcache/store/store.go @@ -0,0 +1,360 @@ +/* +Package store contains the sql backed store. It persists objects to a sqlite database. +*/ +package store + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "github.com/rancher/lasso/pkg/cache/sql/db" + "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + "github.com/rancher/lasso/pkg/log" + "k8s.io/client-go/tools/cache" + _ "modernc.org/sqlite" +) + +const ( + upsertStmtFmt = `REPLACE INTO "%s"(key, object, objectnonce, dekid) VALUES (?, ?, ?, ?)` + deleteStmtFmt = `DELETE FROM "%s" WHERE key = ?` + getStmtFmt = `SELECT object, objectnonce, dekid FROM "%s" WHERE key = ?` + listStmtFmt = `SELECT object, objectnonce, dekid FROM "%s"` + listKeysStmtFmt = `SELECT key FROM "%s"` + createTableFmt = `CREATE TABLE IF NOT EXISTS "%s" ( + key TEXT UNIQUE NOT NULL PRIMARY KEY, + object BLOB, + objectnonce BLOB, + dekid INTEGER + )` +) + +// Store is a SQLite-backed cache.Store +type Store struct { + DBClient + + name string + typ reflect.Type + keyFunc cache.KeyFunc + shouldEncrypt bool + + upsertQuery string + deleteQuery string + getQuery string + listQuery string + listKeysQuery string + + upsertStmt *sql.Stmt + deleteStmt *sql.Stmt + getStmt *sql.Stmt + listStmt *sql.Stmt + listKeysStmt *sql.Stmt + + afterUpsert []func(key string, obj any, tx db.TXClient) error + afterDelete []func(key string, tx db.TXClient) error +} + +// Test that Store implements cache.Indexer +var _ cache.Store = (*Store)(nil) + +type DBClient interface { + BeginTx(ctx context.Context, forWriting bool) (db.TXClient, error) + Prepare(stmt string) *sql.Stmt + QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) + ReadObjects(rows db.Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) + ReadStrings(rows db.Rows) ([]string, error) + ReadInt(rows db.Rows) (int, error) + Upsert(tx db.TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error + CloseStmt(closable db.Closable) error +} + +// NewStore creates a SQLite-backed cache.Store for objects of the given example type +func NewStore(example any, keyFunc cache.KeyFunc, c DBClient, shouldEncrypt bool, name string) (*Store, error) { + s := &Store{ + name: name, + typ: reflect.TypeOf(example), + DBClient: c, + keyFunc: keyFunc, + shouldEncrypt: shouldEncrypt, + afterUpsert: []func(key string, obj any, tx db.TXClient) error{}, + afterDelete: []func(key string, tx db.TXClient) error{}, + } + + // once multiple informerfactories are needed, this can accept the case where table already exists error is received + txC, err := s.BeginTx(context.Background(), true) + if err != nil { + return nil, err + } + dbName := db.Sanitize(s.name) + createTableQuery := fmt.Sprintf(createTableFmt, dbName) + err = txC.Exec(createTableQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createTableQuery, Err: err} + } + + err = txC.Commit() + if err != nil { + return nil, err + } + + s.upsertQuery = fmt.Sprintf(upsertStmtFmt, dbName) + s.deleteQuery = fmt.Sprintf(deleteStmtFmt, dbName) + s.getQuery = fmt.Sprintf(getStmtFmt, dbName) + s.listQuery = fmt.Sprintf(listStmtFmt, dbName) + s.listKeysQuery = fmt.Sprintf(listKeysStmtFmt, dbName) + + s.upsertStmt = s.Prepare(s.upsertQuery) + s.deleteStmt = s.Prepare(s.deleteQuery) + s.getStmt = s.Prepare(s.getQuery) + s.listStmt = s.Prepare(s.listQuery) + s.listKeysStmt = s.Prepare(s.listKeysQuery) + + return s, nil +} + +/* Core methods */ +// upsert saves an obj with its key, or updates key with obj if it exists in this Store +func (s *Store) upsert(key string, obj any) error { + tx, err := s.BeginTx(context.Background(), true) + if err != nil { + return err + } + + err = s.Upsert(tx, s.upsertStmt, key, obj, s.shouldEncrypt) + if err != nil { + return &db.QueryError{QueryString: s.upsertQuery, Err: err} + } + + err = s.runAfterUpsert(key, obj, tx) + if err != nil { + return err + } + + return tx.Commit() +} + +// deleteByKey deletes the object associated with key, if it exists in this Store +func (s *Store) deleteByKey(key string) error { + tx, err := s.BeginTx(context.Background(), true) + if err != nil { + return err + } + + err = tx.StmtExec(tx.Stmt(s.deleteStmt), key) + if err != nil { + return &db.QueryError{QueryString: s.deleteQuery, Err: err} + } + + err = s.runAfterDelete(key, tx) + if err != nil { + return err + } + + return tx.Commit() +} + +// GetByKey returns the object associated with the given object's key +func (s *Store) GetByKey(key string) (item any, exists bool, err error) { + rows, err := s.QueryForRows(context.TODO(), s.getStmt, key) + if err != nil { + return nil, false, &db.QueryError{QueryString: s.getQuery, Err: err} + } + result, err := s.ReadObjects(rows, s.typ, s.shouldEncrypt) + if err != nil { + return nil, false, err + } + + if len(result) == 0 { + return nil, false, nil + } + + return result[0], true, nil +} + +/* Satisfy cache.Store */ + +// Add saves an obj, or updates it if it exists in this Store +func (s *Store) Add(obj any) error { + key, err := s.keyFunc(obj) + if err != nil { + return err + } + + err = s.upsert(key, obj) + if err != nil { + log.Errorf("Error in Store.Add for type %v: %v", s.name, err) + return err + } + return nil +} + +// Update saves an obj, or updates it if it exists in this Store +func (s *Store) Update(obj any) error { + return s.Add(obj) +} + +// Delete deletes the given object, if it exists in this Store +func (s *Store) Delete(obj any) error { + key, err := s.keyFunc(obj) + if err != nil { + return err + } + err = s.deleteByKey(key) + if err != nil { + log.Errorf("Error in Store.Delete for type %v: %v", s.name, err) + return err + } + return nil +} + +// List returns a list of all the currently known objects +// Note: I/O errors will panic this function, as the interface signature does not allow returning errors +func (s *Store) List() []any { + rows, err := s.QueryForRows(context.TODO(), s.listStmt) + if err != nil { + panic(&db.QueryError{QueryString: s.listQuery, Err: err}) + } + result, err := s.ReadObjects(rows, s.typ, s.shouldEncrypt) + if err != nil { + panic(fmt.Errorf("error in Store.List: %w", err)) + } + return result +} + +// ListKeys returns a list of all the keys currently in this Store +// Note: Atm it doesn't appear returning nil in the case of an error has any detrimental effects. An error is not +// uncommon enough nor does it appear to necessitate a panic. +func (s *Store) ListKeys() []string { + rows, err := s.QueryForRows(context.TODO(), s.listKeysStmt) + if err != nil { + fmt.Printf("Unexpected error in store.ListKeys: while executing query: %s got error: %v", s.listKeysQuery, err) + return []string{} + } + result, err := s.ReadStrings(rows) + if err != nil { + fmt.Printf("Unexpected error in store.ListKeys: %v\n", err) + return []string{} + } + return result +} + +// Get returns the object with the same key as obj +func (s *Store) Get(obj any) (item any, exists bool, err error) { + key, err := s.keyFunc(obj) + if err != nil { + return nil, false, err + } + + return s.GetByKey(key) +} + +// Replace will delete the contents of the Store, using instead the given list +func (s *Store) Replace(objects []any, _ string) error { + objectMap := map[string]any{} + + for _, object := range objects { + key, err := s.keyFunc(object) + if err != nil { + return err + } + objectMap[key] = object + } + return s.replaceByKey(objectMap) +} + +// replaceByKey will delete the contents of the Store, using instead the given key to obj map +func (s *Store) replaceByKey(objects map[string]any) error { + txC, err := s.BeginTx(context.Background(), true) + if err != nil { + return err + } + + txCListKeys := txC.Stmt(s.listKeysStmt) + + rows, err := s.QueryForRows(context.TODO(), txCListKeys) + if err != nil { + return err + } + keys, err := s.ReadStrings(rows) + if err != nil { + return err + } + + for _, key := range keys { + err = txC.StmtExec(txC.Stmt(s.deleteStmt), key) + if err != nil { + return err + } + err = s.runAfterDelete(key, txC) + if err != nil { + return err + } + } + + for key, obj := range objects { + err = s.Upsert(txC, s.upsertStmt, key, obj, s.shouldEncrypt) + if err != nil { + return err + } + err = s.runAfterUpsert(key, obj, txC) + if err != nil { + return err + } + } + + return txC.Commit() +} + +// Resync is a no-op and is deprecated +func (s *Store) Resync() error { + return nil +} + +/* Utilities */ + +// RegisterAfterUpsert registers a func to be called after each upsert +func (s *Store) RegisterAfterUpsert(f func(key string, obj any, txC db.TXClient) error) { + s.afterUpsert = append(s.afterUpsert, f) +} + +func (s *Store) GetName() string { + return s.name +} + +func (s *Store) GetShouldEncrypt() bool { + return s.shouldEncrypt +} + +func (s *Store) GetType() reflect.Type { + return s.typ +} + +// keep +// runAfterUpsert executes functions registered to run after upsert +func (s *Store) runAfterUpsert(key string, obj any, txC db.TXClient) error { + for _, f := range s.afterUpsert { + err := f(key, obj, txC) + if err != nil { + return err + } + } + return nil +} + +// RegisterAfterDelete registers a func to be called after each deletion +func (s *Store) RegisterAfterDelete(f func(key string, txC db.TXClient) error) { + s.afterDelete = append(s.afterDelete, f) +} + +// keep +// runAfterDelete executes functions registered to run after upsert +func (s *Store) runAfterDelete(key string, txC db.TXClient) error { + for _, f := range s.afterDelete { + err := f(key, txC) + if err != nil { + return err + } + } + return nil +} diff --git a/pkg/sqlcache/store/store_mocks_test.go b/pkg/sqlcache/store/store_mocks_test.go new file mode 100644 index 00000000..1ae793c8 --- /dev/null +++ b/pkg/sqlcache/store/store_mocks_test.go @@ -0,0 +1,165 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/store (interfaces: DBClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/store DBClient +// + +// Package store is a generated GoMock package. +package store + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + db "github.com/rancher/lasso/pkg/cache/sql/db" + transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockDBClient is a mock of DBClient interface. +type MockDBClient struct { + ctrl *gomock.Controller + recorder *MockDBClientMockRecorder +} + +// MockDBClientMockRecorder is the mock recorder for MockDBClient. +type MockDBClientMockRecorder struct { + mock *MockDBClient +} + +// NewMockDBClient creates a new mock instance. +func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient { + mock := &MockDBClient{ctrl: ctrl} + mock.recorder = &MockDBClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(db.TXClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1) +} + +// CloseStmt mocks base method. +func (m *MockDBClient) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0) +} + +// Prepare mocks base method. +func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0) +} + +// Upsert mocks base method. +func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) +} diff --git a/pkg/sqlcache/store/store_test.go b/pkg/sqlcache/store/store_test.go new file mode 100644 index 00000000..dba11367 --- /dev/null +++ b/pkg/sqlcache/store/store_test.go @@ -0,0 +1,646 @@ +/* +Copyright 2023 SUSE LLC + +Adapted from client-go, Copyright 2014 The Kubernetes Authors. +*/ + +package store + +// Mocks for this test are generated with the following command. +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/store DBClient +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db TXClient,Rows +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + + "github.com/rancher/lasso/pkg/cache/sql/db" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func testStoreKeyFunc(obj interface{}) (string, error) { + return obj.(testStoreObject).Id, nil +} + +type testStoreObject struct { + Id string + Val string +} + +func TestAdd(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + testObject := testStoreObject{Id: "something", Val: "a"} + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Add with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(nil) + err := store.Add(testObject) + assert.Nil(t, err) + // dbclient beginerr + }, + }) + + tests = append(tests, testCase{description: "Add with no DB Client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Commit().Return(nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + + var count int + store.afterUpsert = append(store.afterUpsert, func(key string, object any, tx db.TXClient) error { + count++ + return nil + }) + err := store.Add(testObject) + assert.Nil(t, err) + assert.Equal(t, count, 1) + }, + }) + + tests = append(tests, testCase{description: "Add with no DB Client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error { + return fmt.Errorf("error") + }) + err := store.Add(testObject) + assert.NotNil(t, err) + // dbclient beginerr + }, + }) + + tests = append(tests, testCase{description: "Add with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("failed")) + + store := SetupStore(t, c, shouldEncrypt) + err := store.Add(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Add with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) + err := store.Add(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Add with DB Client Upsert() error with following Rollback() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) + err := store.Add(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Add with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(fmt.Errorf("failed")) + + err := store.Add(testObject) + assert.NotNil(t, err) + }}) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Update updates the given object in the accumulator associated with the given object's key +func TestUpdate(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + testObject := testStoreObject{Id: "something", Val: "a"} + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Update with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(nil) + err := store.Update(testObject) + assert.Nil(t, err) + // dbclient beginerr + }, + }) + + tests = append(tests, testCase{description: "Update with no DB Client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(nil) + + var count int + store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error { + count++ + return nil + }) + err := store.Update(testObject) + assert.Nil(t, err) + assert.Equal(t, count, 1) + }, + }) + + tests = append(tests, testCase{description: "Update with no DB Client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + + store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error { + return fmt.Errorf("error") + }) + err := store.Update(testObject) + assert.NotNil(t, err) + }, + }) + + tests = append(tests, testCase{description: "Update with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("failed")) + + store := SetupStore(t, c, shouldEncrypt) + err := store.Update(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Update with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) + err := store.Update(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Update with DB Client Upsert() error with following Rollback() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) + err := store.Update(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Update with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(fmt.Errorf("failed")) + + err := store.Update(testObject) + assert.NotNil(t, err) + }}) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Delete deletes the given object from the accumulator associated with the given object's key +func TestDelete(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + testObject := testStoreObject{Id: "something", Val: "a"} + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Delete with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil) + txC.EXPECT().Commit().Return(nil) + err := store.Delete(testObject) + assert.Nil(t, err) + }, + }) + tests = append(tests, testCase{description: "Delete with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error")) + // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + err := store.Delete(testObject) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Delete with TX Client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(fmt.Errorf("error")) + // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + err := store.Delete(testObject) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Delete with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + // tx.EXPECT(). + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil) + txC.EXPECT().Commit().Return(fmt.Errorf("error")) + err := store.Delete(testObject) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// List returns a list of all the currently non-empty accumulators +func TestList(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + testObject := testStoreObject{Id: "something", Val: "a"} + + var tests []testCase + + tests = append(tests, testCase{description: "List with no DB Client errors and no items", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listStmt).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{}, nil) + items := store.List() + assert.Len(t, items, 0) + }, + }) + tests = append(tests, testCase{description: "List with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + fakeItemsToReturn := []any{"something1", 2, false} + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listStmt).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(fakeItemsToReturn, nil) + items := store.List() + assert.Equal(t, fakeItemsToReturn, items) + }, + }) + tests = append(tests, testCase{description: "List with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listStmt).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(nil, fmt.Errorf("error")) + defer func() { + recover() + }() + _ = store.List() + assert.Fail(t, "Store list should panic when ReadObjects returns an error") + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// ListKeys returns a list of all the keys currently associated with non-empty accumulators +func TestListKeys(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + + tests = append(tests, testCase{description: "ListKeys with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{"a", "b", "c"}, nil) + keys := store.ListKeys() + assert.Len(t, keys, 3) + }, + }) + + tests = append(tests, testCase{description: "ListKeys with DB Client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error")) + keys := store.ListKeys() + assert.Len(t, keys, 0) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Get returns the accumulator associated with the given object's key +func TestGet(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + testObject := testStoreObject{Id: "something", Val: "a"} + tests = append(tests, testCase{description: "Get with no DB Client errors and object exists", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{testObject}, nil) + item, exists, err := store.Get(testObject) + assert.Nil(t, err) + assert.Equal(t, item, testObject) + assert.True(t, exists) + }, + }) + tests = append(tests, testCase{description: "Get with no DB Client errors and object does not exist", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{}, nil) + item, exists, err := store.Get(testObject) + assert.Nil(t, err) + assert.Equal(t, item, nil) + assert.False(t, exists) + }, + }) + tests = append(tests, testCase{description: "Get with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(nil, fmt.Errorf("error")) + _, _, err := store.Get(testObject) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// GetByKey returns the accumulator associated with the given key +func TestGetByKey(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + testObject := testStoreObject{Id: "something", Val: "a"} + tests = append(tests, testCase{description: "GetByKey with no DB Client errors and item exists", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{testObject}, nil) + item, exists, err := store.GetByKey(testObject.Id) + assert.Nil(t, err) + assert.Equal(t, item, testObject) + assert.True(t, exists) + }, + }) + tests = append(tests, testCase{description: "GetByKey with no DB Client errors and item does not exist", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{}, nil) + item, exists, err := store.GetByKey(testObject.Id) + assert.Nil(t, err) + assert.Equal(t, nil, item) + assert.False(t, exists) + }, + }) + tests = append(tests, testCase{description: "GetByKey with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(nil, fmt.Errorf("error")) + _, _, err := store.GetByKey(testObject.Id) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Replace will delete the contents of the store, using instead the +// given list. Store takes ownership of the list, you should not reference +// it after calling this function. +func TestReplace(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + testObject := testStoreObject{Id: "something", Val: "a"} + tests = append(tests, testCase{description: "Replace with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id) + c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt) + txC.EXPECT().Commit() + err := store.Replace([]any{testObject}, testObject.Id) + assert.Nil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with no DB Client errors and no items", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{}, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt) + txC.EXPECT().Commit() + err := store.Replace([]any{testObject}, testObject.Id) + assert.Nil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with no DB Client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with TX Client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil) + c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt).Return(fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Resync is meaningless in the terms appearing here but has +// meaning in some implementations that have non-trivial +// additional behavior (e.g., DeltaFIFO). +func TestResync(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + tests = append(tests, testCase{description: "Resync shouldn't call the client, panic, or do anything else", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + err := store.Resync() + assert.Nil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +func SetupMockDB(t *testing.T) (*MockDBClient, *MockTXClient) { + dbC := NewMockDBClient(gomock.NewController(t)) // add functionality once store expectation are known + txC := NewMockTXClient(gomock.NewController(t)) + // stmt := NewMockStmt(gomock.NewController()) + txC.EXPECT().Exec(fmt.Sprintf(createTableFmt, "testStoreObject")).Return(nil) + txC.EXPECT().Commit().Return(nil) + dbC.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + + // use stmt mock here + dbC.EXPECT().Prepare(fmt.Sprintf(upsertStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(deleteStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(getStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(listStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(listKeysStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + + return dbC, txC +} +func SetupStore(t *testing.T, client *MockDBClient, shouldEncrypt bool) *Store { + store, err := NewStore(testStoreObject{}, testStoreKeyFunc, client, shouldEncrypt, "testStoreObject") + if err != nil { + t.Error(err) + } + return store +} diff --git a/pkg/sqlcache/store/tx_mocks_test.go b/pkg/sqlcache/store/tx_mocks_test.go new file mode 100644 index 00000000..6fd228ff --- /dev/null +++ b/pkg/sqlcache/store/tx_mocks_test.go @@ -0,0 +1,99 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/db/transaction (interfaces: Stmt) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt +// + +// Package store is a generated GoMock package. +package store + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStmt is a mock of Stmt interface. +type MockStmt struct { + ctrl *gomock.Controller + recorder *MockStmtMockRecorder +} + +// MockStmtMockRecorder is the mock recorder for MockStmt. +type MockStmtMockRecorder struct { + mock *MockStmt +} + +// NewMockStmt creates a new mock instance. +func NewMockStmt(ctrl *gomock.Controller) *MockStmt { + mock := &MockStmt{ctrl: ctrl} + mock.recorder = &MockStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStmt) EXPECT() *MockStmtMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) +} + +// Query mocks base method. +func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) +} + +// QueryContext mocks base method. +func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) +} From 15020f2fe65b0b7047f0c84191b2c0d104b2a804 Mon Sep 17 00:00:00 2001 From: Tom Lebreux Date: Wed, 15 Jan 2025 23:15:31 -0500 Subject: [PATCH 02/10] Rename import from github.com/rancher/lasso/pkg/cache/sql to github.com/rancher/steve/pkg/sqlcache --- pkg/sqlcache/Readme.md | 4 ++-- pkg/sqlcache/db/client.go | 2 +- pkg/sqlcache/db/client_test.go | 4 ++-- pkg/sqlcache/db/db_mocks_test.go | 6 +++--- pkg/sqlcache/db/transaction/transaction_mocks_test.go | 4 ++-- pkg/sqlcache/db/transaction/transaction_test.go | 2 +- pkg/sqlcache/db/transaction_mocks_test.go | 4 ++-- pkg/sqlcache/informer/db_mocks_test.go | 6 +++--- pkg/sqlcache/informer/factory/db_mocks_test.go | 6 +++--- pkg/sqlcache/informer/factory/factory_mocks_test.go | 8 ++++---- pkg/sqlcache/informer/factory/informer_factory.go | 8 ++++---- pkg/sqlcache/informer/factory/informer_factory_test.go | 8 ++++---- pkg/sqlcache/informer/indexer.go | 4 ++-- pkg/sqlcache/informer/indexer_test.go | 4 ++-- pkg/sqlcache/informer/informer.go | 4 ++-- pkg/sqlcache/informer/informer_mocks_test.go | 6 +++--- pkg/sqlcache/informer/informer_test.go | 6 +++--- pkg/sqlcache/informer/listoption_indexer.go | 4 ++-- pkg/sqlcache/informer/listoption_indexer_test.go | 2 +- pkg/sqlcache/informer/sql_mocks_test.go | 8 ++++---- pkg/sqlcache/informer/store_mocks_test.go | 8 ++++---- pkg/sqlcache/informer/tx_mocks_test.go | 4 ++-- pkg/sqlcache/integration_test.go | 6 +++--- pkg/sqlcache/store/db_mocks_test.go | 6 +++--- pkg/sqlcache/store/store.go | 4 ++-- pkg/sqlcache/store/store_mocks_test.go | 8 ++++---- pkg/sqlcache/store/store_test.go | 8 ++++---- pkg/sqlcache/store/tx_mocks_test.go | 4 ++-- pkg/stores/sqlpartition/listprocessor/processor.go | 4 ++-- pkg/stores/sqlpartition/listprocessor/processor_test.go | 4 ++-- pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go | 4 ++-- pkg/stores/sqlpartition/partition_mocks_test.go | 2 +- pkg/stores/sqlpartition/partitioner.go | 2 +- pkg/stores/sqlpartition/partitioner_test.go | 2 +- pkg/stores/sqlpartition/store.go | 2 +- pkg/stores/sqlpartition/store_test.go | 2 +- pkg/stores/sqlproxy/proxy_mocks_test.go | 6 +++--- pkg/stores/sqlproxy/proxy_store.go | 6 +++--- pkg/stores/sqlproxy/proxy_store_test.go | 8 ++++---- pkg/stores/sqlproxy/sql_informer_mocks_test.go | 8 ++++---- 40 files changed, 99 insertions(+), 99 deletions(-) diff --git a/pkg/sqlcache/Readme.md b/pkg/sqlcache/Readme.md index a0c4ce66..a71e3840 100644 --- a/pkg/sqlcache/Readme.md +++ b/pkg/sqlcache/Readme.md @@ -58,8 +58,8 @@ intended to be used as a way of enforcing RBAC. package main import( "k8s.io/client-go/dynamic" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/informer/factory" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/informer/factory" ) func main() { diff --git a/pkg/sqlcache/db/client.go b/pkg/sqlcache/db/client.go index 441e2a2a..be5f5b16 100644 --- a/pkg/sqlcache/db/client.go +++ b/pkg/sqlcache/db/client.go @@ -16,7 +16,7 @@ import ( "sync" "github.com/pkg/errors" - "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" _ "modernc.org/sqlite" ) diff --git a/pkg/sqlcache/db/client_test.go b/pkg/sqlcache/db/client_test.go index 8adf74c3..8b7951f1 100644 --- a/pkg/sqlcache/db/client_test.go +++ b/pkg/sqlcache/db/client_test.go @@ -16,8 +16,8 @@ import ( ) // Mocks for this test are generated with the following command. -//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db Rows,Connection,Encryptor,Decryptor,TXClient -//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt,SQLTx +//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient +//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx type testStoreObject struct { Id string diff --git a/pkg/sqlcache/db/db_mocks_test.go b/pkg/sqlcache/db/db_mocks_test.go index 55580ee2..54199ba4 100644 --- a/pkg/sqlcache/db/db_mocks_test.go +++ b/pkg/sqlcache/db/db_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/db (interfaces: Rows,Connection,Encryptor,Decryptor,TXClient) +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Rows,Connection,Encryptor,Decryptor,TXClient) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db Rows,Connection,Encryptor,Decryptor,TXClient +// mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient // // Package db is a generated GoMock package. @@ -14,7 +14,7 @@ import ( sql "database/sql" reflect "reflect" - transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) diff --git a/pkg/sqlcache/db/transaction/transaction_mocks_test.go b/pkg/sqlcache/db/transaction/transaction_mocks_test.go index 3bd82287..0d7fdaa7 100644 --- a/pkg/sqlcache/db/transaction/transaction_mocks_test.go +++ b/pkg/sqlcache/db/transaction/transaction_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/db/transaction (interfaces: Stmt,SQLTx) +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt,SQLTx +// mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx // // Package transaction is a generated GoMock package. diff --git a/pkg/sqlcache/db/transaction/transaction_test.go b/pkg/sqlcache/db/transaction/transaction_test.go index aada33a7..0ede5d2e 100644 --- a/pkg/sqlcache/db/transaction/transaction_test.go +++ b/pkg/sqlcache/db/transaction/transaction_test.go @@ -9,7 +9,7 @@ import ( "go.uber.org/mock/gomock" ) -//go:generate mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt,SQLTx +//go:generate mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx func TestNewClient(t *testing.T) { tx := NewMockSQLTx(gomock.NewController(t)) diff --git a/pkg/sqlcache/db/transaction_mocks_test.go b/pkg/sqlcache/db/transaction_mocks_test.go index 1cc9c874..1cac5caf 100644 --- a/pkg/sqlcache/db/transaction_mocks_test.go +++ b/pkg/sqlcache/db/transaction_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/db/transaction (interfaces: Stmt,SQLTx) +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt,SQLTx +// mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx // // Package db is a generated GoMock package. diff --git a/pkg/sqlcache/informer/db_mocks_test.go b/pkg/sqlcache/informer/db_mocks_test.go index 3731e3d0..7d2c81ce 100644 --- a/pkg/sqlcache/informer/db_mocks_test.go +++ b/pkg/sqlcache/informer/db_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/db (interfaces: TXClient,Rows) +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient,Rows) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db TXClient,Rows +// mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows // // Package informer is a generated GoMock package. @@ -13,7 +13,7 @@ import ( sql "database/sql" reflect "reflect" - transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) diff --git a/pkg/sqlcache/informer/factory/db_mocks_test.go b/pkg/sqlcache/informer/factory/db_mocks_test.go index fd5fa071..9ac55bb3 100644 --- a/pkg/sqlcache/informer/factory/db_mocks_test.go +++ b/pkg/sqlcache/informer/factory/db_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/db (interfaces: TXClient) +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db TXClient +// mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient // // Package factory is a generated GoMock package. @@ -13,7 +13,7 @@ import ( sql "database/sql" reflect "reflect" - transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) diff --git a/pkg/sqlcache/informer/factory/factory_mocks_test.go b/pkg/sqlcache/informer/factory/factory_mocks_test.go index fa5d4739..a7adab6a 100644 --- a/pkg/sqlcache/informer/factory/factory_mocks_test.go +++ b/pkg/sqlcache/informer/factory/factory_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/informer/factory (interfaces: DBClient) +// Source: github.com/rancher/steve/pkg/sqlcache/informer/factory (interfaces: DBClient) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer/factory DBClient +// mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer/factory DBClient // // Package factory is a generated GoMock package. @@ -14,8 +14,8 @@ import ( sql "database/sql" reflect "reflect" - db "github.com/rancher/lasso/pkg/cache/sql/db" - transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + db "github.com/rancher/steve/pkg/sqlcache/db" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) diff --git a/pkg/sqlcache/informer/factory/informer_factory.go b/pkg/sqlcache/informer/factory/informer_factory.go index 5ce85347..61559771 100644 --- a/pkg/sqlcache/informer/factory/informer_factory.go +++ b/pkg/sqlcache/informer/factory/informer_factory.go @@ -9,10 +9,10 @@ import ( "sync" "time" - "github.com/rancher/lasso/pkg/cache/sql/db" - "github.com/rancher/lasso/pkg/cache/sql/encryption" - "github.com/rancher/lasso/pkg/cache/sql/informer" - sqlStore "github.com/rancher/lasso/pkg/cache/sql/store" + "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/encryption" + "github.com/rancher/steve/pkg/sqlcache/informer" + sqlStore "github.com/rancher/steve/pkg/sqlcache/store" "github.com/rancher/lasso/pkg/log" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime/schema" diff --git a/pkg/sqlcache/informer/factory/informer_factory_test.go b/pkg/sqlcache/informer/factory/informer_factory_test.go index 0f77435b..e3b96562 100644 --- a/pkg/sqlcache/informer/factory/informer_factory_test.go +++ b/pkg/sqlcache/informer/factory/informer_factory_test.go @@ -5,9 +5,9 @@ import ( "testing" "time" - "github.com/rancher/lasso/pkg/cache/sql/informer" + "github.com/rancher/steve/pkg/sqlcache/informer" - sqlStore "github.com/rancher/lasso/pkg/cache/sql/store" + sqlStore "github.com/rancher/steve/pkg/sqlcache/store" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "k8s.io/apimachinery/pkg/runtime/schema" @@ -15,8 +15,8 @@ import ( "k8s.io/client-go/tools/cache" ) -//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer/factory DBClient -//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db TXClient +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer/factory DBClient +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient //go:generate mockgen --build_flags=--mod=mod -package factory -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface //go:generate mockgen --build_flags=--mod=mod -package factory -destination ./k8s_cache_mocks_test.go k8s.io/client-go/tools/cache SharedIndexInformer diff --git a/pkg/sqlcache/informer/indexer.go b/pkg/sqlcache/informer/indexer.go index 14305339..7ed4451b 100644 --- a/pkg/sqlcache/informer/indexer.go +++ b/pkg/sqlcache/informer/indexer.go @@ -8,8 +8,8 @@ import ( "strings" "sync" - "github.com/rancher/lasso/pkg/cache/sql/db" - "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" "k8s.io/client-go/tools/cache" ) diff --git a/pkg/sqlcache/informer/indexer_test.go b/pkg/sqlcache/informer/indexer_test.go index a861efb6..4118118c 100644 --- a/pkg/sqlcache/informer/indexer_test.go +++ b/pkg/sqlcache/informer/indexer_test.go @@ -18,8 +18,8 @@ import ( "k8s.io/client-go/tools/cache" ) -//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer Store -//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db TXClient,Rows +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer Store +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows type testStoreObject struct { Id string diff --git a/pkg/sqlcache/informer/informer.go b/pkg/sqlcache/informer/informer.go index b893713a..a74c7029 100644 --- a/pkg/sqlcache/informer/informer.go +++ b/pkg/sqlcache/informer/informer.go @@ -8,8 +8,8 @@ import ( "context" "time" - "github.com/rancher/lasso/pkg/cache/sql/partition" - sqlStore "github.com/rancher/lasso/pkg/cache/sql/store" + "github.com/rancher/steve/pkg/sqlcache/partition" + sqlStore "github.com/rancher/steve/pkg/sqlcache/store" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" diff --git a/pkg/sqlcache/informer/informer_mocks_test.go b/pkg/sqlcache/informer/informer_mocks_test.go index ae6bf4b6..9eff0612 100644 --- a/pkg/sqlcache/informer/informer_mocks_test.go +++ b/pkg/sqlcache/informer/informer_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/informer (interfaces: ByOptionsLister) +// Source: github.com/rancher/steve/pkg/sqlcache/informer (interfaces: ByOptionsLister) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer ByOptionsLister +// mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister // // Package informer is a generated GoMock package. @@ -13,7 +13,7 @@ import ( context "context" reflect "reflect" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + partition "github.com/rancher/steve/pkg/sqlcache/partition" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) diff --git a/pkg/sqlcache/informer/informer_test.go b/pkg/sqlcache/informer/informer_test.go index 5199bdc8..5337ee8e 100644 --- a/pkg/sqlcache/informer/informer_test.go +++ b/pkg/sqlcache/informer/informer_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -18,9 +18,9 @@ import ( "k8s.io/client-go/tools/cache" ) -//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer ByOptionsLister +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister //go:generate mockgen --build_flags=--mod=mod -package informer -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface -//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/store DBClient +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient func TestNewInformer(t *testing.T) { type testCase struct { diff --git a/pkg/sqlcache/informer/listoption_indexer.go b/pkg/sqlcache/informer/listoption_indexer.go index 3290cf70..86bc62f8 100644 --- a/pkg/sqlcache/informer/listoption_indexer.go +++ b/pkg/sqlcache/informer/listoption_indexer.go @@ -15,8 +15,8 @@ import ( "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/client-go/tools/cache" - "github.com/rancher/lasso/pkg/cache/sql/db" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/partition" ) // ListOptionIndexer extends Indexer by allowing queries based on ListOption diff --git a/pkg/sqlcache/informer/listoption_indexer_test.go b/pkg/sqlcache/informer/listoption_indexer_test.go index fe7a038f..5352cebd 100644 --- a/pkg/sqlcache/informer/listoption_indexer_test.go +++ b/pkg/sqlcache/informer/listoption_indexer_test.go @@ -15,7 +15,7 @@ import ( "strings" "testing" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" diff --git a/pkg/sqlcache/informer/sql_mocks_test.go b/pkg/sqlcache/informer/sql_mocks_test.go index c172320f..c269b01b 100644 --- a/pkg/sqlcache/informer/sql_mocks_test.go +++ b/pkg/sqlcache/informer/sql_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/informer (interfaces: Store) +// Source: github.com/rancher/steve/pkg/sqlcache/informer (interfaces: Store) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer Store +// mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer Store // // Package informer is a generated GoMock package. @@ -14,8 +14,8 @@ import ( sql "database/sql" reflect "reflect" - db "github.com/rancher/lasso/pkg/cache/sql/db" - transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + db "github.com/rancher/steve/pkg/sqlcache/db" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) diff --git a/pkg/sqlcache/informer/store_mocks_test.go b/pkg/sqlcache/informer/store_mocks_test.go index fcb1a90d..c1c7d426 100644 --- a/pkg/sqlcache/informer/store_mocks_test.go +++ b/pkg/sqlcache/informer/store_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/store (interfaces: DBClient) +// Source: github.com/rancher/steve/pkg/sqlcache/store (interfaces: DBClient) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/store DBClient +// mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient // // Package informer is a generated GoMock package. @@ -14,8 +14,8 @@ import ( sql "database/sql" reflect "reflect" - db "github.com/rancher/lasso/pkg/cache/sql/db" - transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + db "github.com/rancher/steve/pkg/sqlcache/db" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) diff --git a/pkg/sqlcache/informer/tx_mocks_test.go b/pkg/sqlcache/informer/tx_mocks_test.go index e482df9b..9383411d 100644 --- a/pkg/sqlcache/informer/tx_mocks_test.go +++ b/pkg/sqlcache/informer/tx_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/db/transaction (interfaces: Stmt) +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package informer -destination ./pkg/cache/sql/informer/tx_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt +// mockgen --build_flags=--mod=mod -package informer -destination ./pkg/cache/sql/informer/tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt // // Package informer is a generated GoMock package. diff --git a/pkg/sqlcache/integration_test.go b/pkg/sqlcache/integration_test.go index 3c752967..58715918 100644 --- a/pkg/sqlcache/integration_test.go +++ b/pkg/sqlcache/integration_test.go @@ -18,9 +18,9 @@ import ( "k8s.io/client-go/tools/cache" "sigs.k8s.io/controller-runtime/pkg/envtest" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/informer/factory" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/informer/factory" + "github.com/rancher/steve/pkg/sqlcache/partition" ) const testNamespace = "sql-test" diff --git a/pkg/sqlcache/store/db_mocks_test.go b/pkg/sqlcache/store/db_mocks_test.go index 489a1c1c..75f70b6e 100644 --- a/pkg/sqlcache/store/db_mocks_test.go +++ b/pkg/sqlcache/store/db_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/db (interfaces: TXClient,Rows) +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient,Rows) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db TXClient,Rows +// mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows // // Package store is a generated GoMock package. @@ -13,7 +13,7 @@ import ( sql "database/sql" reflect "reflect" - transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) diff --git a/pkg/sqlcache/store/store.go b/pkg/sqlcache/store/store.go index a89d57bf..a4af1b50 100644 --- a/pkg/sqlcache/store/store.go +++ b/pkg/sqlcache/store/store.go @@ -9,8 +9,8 @@ import ( "fmt" "reflect" - "github.com/rancher/lasso/pkg/cache/sql/db" - "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" "github.com/rancher/lasso/pkg/log" "k8s.io/client-go/tools/cache" _ "modernc.org/sqlite" diff --git a/pkg/sqlcache/store/store_mocks_test.go b/pkg/sqlcache/store/store_mocks_test.go index 1ae793c8..d30df82b 100644 --- a/pkg/sqlcache/store/store_mocks_test.go +++ b/pkg/sqlcache/store/store_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/store (interfaces: DBClient) +// Source: github.com/rancher/steve/pkg/sqlcache/store (interfaces: DBClient) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/store DBClient +// mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient // // Package store is a generated GoMock package. @@ -14,8 +14,8 @@ import ( sql "database/sql" reflect "reflect" - db "github.com/rancher/lasso/pkg/cache/sql/db" - transaction "github.com/rancher/lasso/pkg/cache/sql/db/transaction" + db "github.com/rancher/steve/pkg/sqlcache/db" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) diff --git a/pkg/sqlcache/store/store_test.go b/pkg/sqlcache/store/store_test.go index dba11367..1d4e2613 100644 --- a/pkg/sqlcache/store/store_test.go +++ b/pkg/sqlcache/store/store_test.go @@ -7,9 +7,9 @@ Adapted from client-go, Copyright 2014 The Kubernetes Authors. package store // Mocks for this test are generated with the following command. -//go:generate mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/store DBClient -//go:generate mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db TXClient,Rows -//go:generate mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt import ( "context" @@ -18,7 +18,7 @@ import ( "reflect" "testing" - "github.com/rancher/lasso/pkg/cache/sql/db" + "github.com/rancher/steve/pkg/sqlcache/db" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" diff --git a/pkg/sqlcache/store/tx_mocks_test.go b/pkg/sqlcache/store/tx_mocks_test.go index 6fd228ff..0c05ab7f 100644 --- a/pkg/sqlcache/store/tx_mocks_test.go +++ b/pkg/sqlcache/store/tx_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/db/transaction (interfaces: Stmt) +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt +// mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt // // Package store is a generated GoMock package. diff --git a/pkg/stores/sqlpartition/listprocessor/processor.go b/pkg/stores/sqlpartition/listprocessor/processor.go index 0ea15788..d9f5e02e 100644 --- a/pkg/stores/sqlpartition/listprocessor/processor.go +++ b/pkg/stores/sqlpartition/listprocessor/processor.go @@ -10,8 +10,8 @@ import ( "github.com/rancher/apiserver/pkg/apierror" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/wrangler/v3/pkg/schemas/validation" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) diff --git a/pkg/stores/sqlpartition/listprocessor/processor_test.go b/pkg/stores/sqlpartition/listprocessor/processor_test.go index 08b6ea6a..80cfb0b0 100644 --- a/pkg/stores/sqlpartition/listprocessor/processor_test.go +++ b/pkg/stores/sqlpartition/listprocessor/processor_test.go @@ -8,8 +8,8 @@ import ( "testing" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" diff --git a/pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go b/pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go index 693261a7..06598043 100644 --- a/pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go +++ b/pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go @@ -13,8 +13,8 @@ import ( context "context" reflect "reflect" - informer "github.com/rancher/lasso/pkg/cache/sql/informer" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + informer "github.com/rancher/steve/pkg/sqlcache/informer" + partition "github.com/rancher/steve/pkg/sqlcache/partition" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) diff --git a/pkg/stores/sqlpartition/partition_mocks_test.go b/pkg/stores/sqlpartition/partition_mocks_test.go index ea72b8c1..687b9006 100644 --- a/pkg/stores/sqlpartition/partition_mocks_test.go +++ b/pkg/stores/sqlpartition/partition_mocks_test.go @@ -13,7 +13,7 @@ import ( reflect "reflect" types "github.com/rancher/apiserver/pkg/types" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + partition "github.com/rancher/steve/pkg/sqlcache/partition" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" watch "k8s.io/apimachinery/pkg/watch" diff --git a/pkg/stores/sqlpartition/partitioner.go b/pkg/stores/sqlpartition/partitioner.go index ffee38df..6a225b5e 100644 --- a/pkg/stores/sqlpartition/partitioner.go +++ b/pkg/stores/sqlpartition/partitioner.go @@ -5,7 +5,7 @@ import ( "sort" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/steve/pkg/attributes" "github.com/rancher/wrangler/v3/pkg/kv" diff --git a/pkg/stores/sqlpartition/partitioner_test.go b/pkg/stores/sqlpartition/partitioner_test.go index caba1897..cdc93ff1 100644 --- a/pkg/stores/sqlpartition/partitioner_test.go +++ b/pkg/stores/sqlpartition/partitioner_test.go @@ -6,7 +6,7 @@ import ( "go.uber.org/mock/gomock" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/wrangler/v3/pkg/schemas" "github.com/stretchr/testify/assert" diff --git a/pkg/stores/sqlpartition/store.go b/pkg/stores/sqlpartition/store.go index 145d5e72..30cc8e69 100644 --- a/pkg/stores/sqlpartition/store.go +++ b/pkg/stores/sqlpartition/store.go @@ -7,7 +7,7 @@ import ( "context" "github.com/rancher/apiserver/pkg/types" - lassopartition "github.com/rancher/lasso/pkg/cache/sql/partition" + lassopartition "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/steve/pkg/stores/partition" ) diff --git a/pkg/stores/sqlpartition/store_test.go b/pkg/stores/sqlpartition/store_test.go index d1f11be7..d98a8188 100644 --- a/pkg/stores/sqlpartition/store_test.go +++ b/pkg/stores/sqlpartition/store_test.go @@ -15,7 +15,7 @@ import ( "go.uber.org/mock/gomock" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/steve/pkg/stores/sqlproxy" "github.com/rancher/wrangler/v3/pkg/generic" diff --git a/pkg/stores/sqlproxy/proxy_mocks_test.go b/pkg/stores/sqlproxy/proxy_mocks_test.go index a5559d47..45ccd418 100644 --- a/pkg/stores/sqlproxy/proxy_mocks_test.go +++ b/pkg/stores/sqlproxy/proxy_mocks_test.go @@ -14,9 +14,9 @@ import ( reflect "reflect" types "github.com/rancher/apiserver/pkg/types" - informer "github.com/rancher/lasso/pkg/cache/sql/informer" - factory "github.com/rancher/lasso/pkg/cache/sql/informer/factory" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + informer "github.com/rancher/steve/pkg/sqlcache/informer" + factory "github.com/rancher/steve/pkg/sqlcache/informer/factory" + partition "github.com/rancher/steve/pkg/sqlcache/partition" summary "github.com/rancher/wrangler/v3/pkg/summary" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" diff --git a/pkg/stores/sqlproxy/proxy_store.go b/pkg/stores/sqlproxy/proxy_store.go index f6d1e34b..7d57774c 100644 --- a/pkg/stores/sqlproxy/proxy_store.go +++ b/pkg/stores/sqlproxy/proxy_store.go @@ -32,9 +32,9 @@ import ( "github.com/rancher/apiserver/pkg/apierror" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/informer/factory" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/informer/factory" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/wrangler/v3/pkg/data" "github.com/rancher/wrangler/v3/pkg/schemas" "github.com/rancher/wrangler/v3/pkg/schemas/validation" diff --git a/pkg/stores/sqlproxy/proxy_store_test.go b/pkg/stores/sqlproxy/proxy_store_test.go index bdbbbaab..8690eea3 100644 --- a/pkg/stores/sqlproxy/proxy_store_test.go +++ b/pkg/stores/sqlproxy/proxy_store_test.go @@ -12,9 +12,9 @@ import ( "github.com/rancher/wrangler/v3/pkg/schemas/validation" apierrors "k8s.io/apimachinery/pkg/api/errors" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/informer/factory" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/informer/factory" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/attributes" "github.com/rancher/steve/pkg/resources/common" "github.com/rancher/steve/pkg/stores/sqlpartition/listprocessor" @@ -42,7 +42,7 @@ import ( ) //go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./proxy_mocks_test.go github.com/rancher/steve/pkg/stores/sqlproxy Cache,ClientGetter,CacheFactory,SchemaColumnSetter,RelationshipNotifier,TransformBuilder -//go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer ByOptionsLister +//go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister //go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface var c *watch.FakeWatcher diff --git a/pkg/stores/sqlproxy/sql_informer_mocks_test.go b/pkg/stores/sqlproxy/sql_informer_mocks_test.go index e8f5358c..125f2192 100644 --- a/pkg/stores/sqlproxy/sql_informer_mocks_test.go +++ b/pkg/stores/sqlproxy/sql_informer_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/informer (interfaces: ByOptionsLister) +// Source: github.com/rancher/steve/pkg/sqlcache/informer (interfaces: ByOptionsLister) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer ByOptionsLister +// mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister // // Package sqlproxy is a generated GoMock package. @@ -13,8 +13,8 @@ import ( context "context" reflect "reflect" - informer "github.com/rancher/lasso/pkg/cache/sql/informer" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + informer "github.com/rancher/steve/pkg/sqlcache/informer" + partition "github.com/rancher/steve/pkg/sqlcache/partition" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) From 5b24050cf20067e05f2e4002a807c5f88b00e9d5 Mon Sep 17 00:00:00 2001 From: Tom Lebreux Date: Wed, 15 Jan 2025 23:46:47 -0500 Subject: [PATCH 03/10] Fix filter.Match -> filter.Matches --- .../sqlpartition/listprocessor/processor.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pkg/stores/sqlpartition/listprocessor/processor.go b/pkg/stores/sqlpartition/listprocessor/processor.go index d9f5e02e..11457089 100644 --- a/pkg/stores/sqlpartition/listprocessor/processor.go +++ b/pkg/stores/sqlpartition/listprocessor/processor.go @@ -79,7 +79,7 @@ func ParseQuery(apiOp *types.APIRequest, namespaceCache Cache) (informer.ListOpt } usePartialMatch := !(strings.HasPrefix(filter[1], `'`) && strings.HasSuffix(filter[1], `'`)) value := strings.TrimSuffix(strings.TrimPrefix(filter[1], "'"), "'") - orFilter.Filters = append(orFilter.Filters, informer.Filter{Field: strings.Split(filter[0], "."), Match: value, Op: op, Partial: usePartialMatch}) + orFilter.Filters = append(orFilter.Filters, informer.Filter{Field: strings.Split(filter[0], "."), Matches: []string{value}, Op: op, Partial: usePartialMatch}) } filterOpts = append(filterOpts, orFilter) } @@ -170,14 +170,14 @@ func parseNamespaceOrProjectFilters(ctx context.Context, projOrNS string, op inf { Filters: []informer.Filter{ { - Field: []string{"metadata", "name"}, - Match: pn, - Op: informer.Eq, + Field: []string{"metadata", "name"}, + Matches: []string{pn}, + Op: informer.Eq, }, { - Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, - Match: pn, - Op: informer.Eq, + Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, + Matches: []string{pn}, + Op: informer.Eq, }, }, }, @@ -189,7 +189,7 @@ func parseNamespaceOrProjectFilters(ctx context.Context, projOrNS string, op inf for _, item := range uList.Items { filters = append(filters, informer.Filter{ Field: []string{"metadata", "namespace"}, - Match: item.GetName(), + Matches: []string{item.GetName()}, Op: op, Partial: false, }) From 42ee2c82b8220d7c14a7e28b1e59ef35231d9906 Mon Sep 17 00:00:00 2001 From: Tom Lebreux Date: Wed, 15 Jan 2025 23:51:00 -0500 Subject: [PATCH 04/10] go mod tidy --- go.mod | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 739f84a9..36c7be58 100644 --- a/go.mod +++ b/go.mod @@ -44,6 +44,8 @@ require ( k8s.io/kube-aggregator v0.31.1 k8s.io/kube-openapi v0.0.0-20240411171206-dc4e619f62f3 k8s.io/kubernetes v1.31.1 + k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 + modernc.org/sqlite v1.29.10 sigs.k8s.io/controller-runtime v0.19.0 ) @@ -137,12 +139,10 @@ require ( k8s.io/component-base v0.31.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kms v0.31.1 // indirect - k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 // indirect modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect modernc.org/libc v1.49.3 // indirect modernc.org/mathutil v1.6.0 // indirect modernc.org/memory v1.8.0 // indirect - modernc.org/sqlite v1.29.10 // indirect modernc.org/strutil v1.2.0 // indirect modernc.org/token v1.1.0 // indirect sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.30.3 // indirect From 447a991765d7e9741bc0fe8e3a768437058953d0 Mon Sep 17 00:00:00 2001 From: Tom Lebreux Date: Wed, 15 Jan 2025 23:53:29 -0500 Subject: [PATCH 05/10] Fix lint errors --- pkg/sqlcache/db/transaction/transaction.go | 2 +- pkg/sqlcache/store/store.go | 2 +- pkg/stores/sqlpartition/partitioner.go | 2 +- pkg/stores/sqlpartition/store.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/sqlcache/db/transaction/transaction.go b/pkg/sqlcache/db/transaction/transaction.go index 83609143..68b476ea 100644 --- a/pkg/sqlcache/db/transaction/transaction.go +++ b/pkg/sqlcache/db/transaction/transaction.go @@ -7,9 +7,9 @@ package transaction import ( "context" "database/sql" - "github.com/sirupsen/logrus" "github.com/pkg/errors" + "github.com/sirupsen/logrus" ) // Client provides a way to interact with the underlying sql transaction. diff --git a/pkg/sqlcache/store/store.go b/pkg/sqlcache/store/store.go index a4af1b50..2fa6cd85 100644 --- a/pkg/sqlcache/store/store.go +++ b/pkg/sqlcache/store/store.go @@ -9,9 +9,9 @@ import ( "fmt" "reflect" + "github.com/rancher/lasso/pkg/log" "github.com/rancher/steve/pkg/sqlcache/db" "github.com/rancher/steve/pkg/sqlcache/db/transaction" - "github.com/rancher/lasso/pkg/log" "k8s.io/client-go/tools/cache" _ "modernc.org/sqlite" ) diff --git a/pkg/stores/sqlpartition/partitioner.go b/pkg/stores/sqlpartition/partitioner.go index 6a225b5e..b3b74f9b 100644 --- a/pkg/stores/sqlpartition/partitioner.go +++ b/pkg/stores/sqlpartition/partitioner.go @@ -5,9 +5,9 @@ import ( "sort" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/steve/pkg/attributes" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/wrangler/v3/pkg/kv" "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" diff --git a/pkg/stores/sqlpartition/store.go b/pkg/stores/sqlpartition/store.go index 30cc8e69..f14953bb 100644 --- a/pkg/stores/sqlpartition/store.go +++ b/pkg/stores/sqlpartition/store.go @@ -7,8 +7,8 @@ import ( "context" "github.com/rancher/apiserver/pkg/types" - lassopartition "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/accesscontrol" + lassopartition "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/stores/partition" ) From dd69bdd2153133e0197f8649d6bf3a2cb66bca92 Mon Sep 17 00:00:00 2001 From: Tom Lebreux Date: Wed, 15 Jan 2025 23:57:05 -0500 Subject: [PATCH 06/10] Remove lasso SQL cache mentions --- pkg/server/server.go | 2 +- pkg/sqlcache/Readme.md | 4 ++-- pkg/sqlcache/db/client.go | 2 +- pkg/stores/sqlpartition/store.go | 4 ++-- pkg/stores/sqlproxy/proxy_store.go | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pkg/server/server.go b/pkg/server/server.go index 4a7829f0..a9711e77 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -82,7 +82,7 @@ type Options struct { AggregationSecretName string ClusterRegistry string ServerVersion string - // SQLCache enables the SQLite-based lasso caching mechanism + // SQLCache enables the SQLite-based caching mechanism SQLCache bool // ExtensionAPIServer enables an extension API server that will be served diff --git a/pkg/sqlcache/Readme.md b/pkg/sqlcache/Readme.md index a71e3840..99da82a2 100644 --- a/pkg/sqlcache/Readme.md +++ b/pkg/sqlcache/Readme.md @@ -105,7 +105,7 @@ that contains the functionality needed to conform to cache.Indexer. ### SQLite Driver There are multiple SQLite drivers that this package could have used. One of the most, if not the most, popular SQLite golang drivers is [mattn/go-sqlite3](https://github.com/mattn/go-sqlite3). This driver is not being used because it requires enabling -the cgo option when compiling and at the moment lasso's main consumer, rancher, does not compile with cgo. We did not want +the cgo option when compiling and at the moment steve's main consumer, rancher, does not compile with cgo. We did not want the SQL informer to be the sole driver in switching to using cgo. Instead, modernc's driver which is in pure golang. Side-by-side comparisons can be found indicating the cgo version is, as expected, more performant. If in the future it is deemed worthwhile then the driver can be easily switched by replacing the empty import in `pkg/cache/sql/store` from `_ "modernc.org/sqlite"` to `_ "github.com/mattn/go-sqlite3"`. @@ -117,7 +117,7 @@ connections attached to a sql.Connection. `database/sql` manages this connection application only need one sql.Connection, although sometimes application use two: one for writes, the other for reads. To read more about the `sql` package's connection pooling read [Managing connections](https://go.dev/doc/database/manage-connections). -The use of connection pooling and the fact that lasso potentially has many go routines accessing the same connection pool, +The use of connection pooling and the fact that steve potentially has many go routines accessing the same connection pool, means we have to be careful with writes. Exclusively using sql transaction to write helps ensure safety. To read more about sql transactions read SQLite's [Transaction docs](https://www.sqlite.org/lang_transaction.html). diff --git a/pkg/sqlcache/db/client.go b/pkg/sqlcache/db/client.go index be5f5b16..03c061a2 100644 --- a/pkg/sqlcache/db/client.go +++ b/pkg/sqlcache/db/client.go @@ -21,7 +21,7 @@ import ( ) const ( - // InformerObjectCacheDBPath is where SQLite's object database file will be stored relative to process running lasso + // InformerObjectCacheDBPath is where SQLite's object database file will be stored relative to process running steve InformerObjectCacheDBPath = "informer_object_cache.db" informerObjectCachePerms fs.FileMode = 0o600 diff --git a/pkg/stores/sqlpartition/store.go b/pkg/stores/sqlpartition/store.go index f14953bb..f4ebb325 100644 --- a/pkg/stores/sqlpartition/store.go +++ b/pkg/stores/sqlpartition/store.go @@ -8,13 +8,13 @@ import ( "github.com/rancher/apiserver/pkg/types" "github.com/rancher/steve/pkg/accesscontrol" - lassopartition "github.com/rancher/steve/pkg/sqlcache/partition" + cachepartition "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/stores/partition" ) // Partitioner is an interface for interacting with partitions. type Partitioner interface { - All(apiOp *types.APIRequest, schema *types.APISchema, verb, id string) ([]lassopartition.Partition, error) + All(apiOp *types.APIRequest, schema *types.APISchema, verb, id string) ([]cachepartition.Partition, error) Store() UnstructuredStore } diff --git a/pkg/stores/sqlproxy/proxy_store.go b/pkg/stores/sqlproxy/proxy_store.go index 7d57774c..5977cc83 100644 --- a/pkg/stores/sqlproxy/proxy_store.go +++ b/pkg/stores/sqlproxy/proxy_store.go @@ -333,7 +333,7 @@ func gvkKey(group, version, kind string) string { return group + "_" + version + "_" + kind } -// getFieldsFromSchema converts object field names from types.APISchema's format into lasso's +// getFieldsFromSchema converts object field names from types.APISchema's format into steve's // cache.sql.informer's slice format (e.g. "metadata.resourceVersion" is ["metadata", "resourceVersion"]) func getFieldsFromSchema(schema *types.APISchema) [][]string { var fields [][]string From 1b23ab0d18da5a824bc12d79b5e21f118ddb5c8b Mon Sep 17 00:00:00 2001 From: Tom Lebreux Date: Thu, 16 Jan 2025 00:04:18 -0500 Subject: [PATCH 07/10] Fix more CI lint errors --- pkg/sqlcache/db/client.go | 1 + pkg/sqlcache/informer/factory/informer_factory.go | 2 +- pkg/sqlcache/informer/listoption_indexer.go | 4 ++-- pkg/sqlcache/store/store.go | 1 + pkg/stores/sqlproxy/proxy_store.go | 2 +- 5 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pkg/sqlcache/db/client.go b/pkg/sqlcache/db/client.go index 03c061a2..2ecc4b59 100644 --- a/pkg/sqlcache/db/client.go +++ b/pkg/sqlcache/db/client.go @@ -17,6 +17,7 @@ import ( "github.com/pkg/errors" "github.com/rancher/steve/pkg/sqlcache/db/transaction" + // needed for drivers _ "modernc.org/sqlite" ) diff --git a/pkg/sqlcache/informer/factory/informer_factory.go b/pkg/sqlcache/informer/factory/informer_factory.go index 61559771..a8f3d266 100644 --- a/pkg/sqlcache/informer/factory/informer_factory.go +++ b/pkg/sqlcache/informer/factory/informer_factory.go @@ -9,11 +9,11 @@ import ( "sync" "time" + "github.com/rancher/lasso/pkg/log" "github.com/rancher/steve/pkg/sqlcache/db" "github.com/rancher/steve/pkg/sqlcache/encryption" "github.com/rancher/steve/pkg/sqlcache/informer" sqlStore "github.com/rancher/steve/pkg/sqlcache/store" - "github.com/rancher/lasso/pkg/log" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/wait" diff --git a/pkg/sqlcache/informer/listoption_indexer.go b/pkg/sqlcache/informer/listoption_indexer.go index 86bc62f8..93ce5f08 100644 --- a/pkg/sqlcache/informer/listoption_indexer.go +++ b/pkg/sqlcache/informer/listoption_indexer.go @@ -42,7 +42,7 @@ var ( defaultIndexNamespaced = "metadata.namespace" subfieldRegex = regexp.MustCompile(`([a-zA-Z]+)|(\[[a-zA-Z./]+])|(\[[0-9]+])`) - InvalidColumnErr = errors.New("supplied column is invalid") + ErrInvalidColumn = errors.New("supplied column is invalid") ) const ( @@ -528,7 +528,7 @@ func (l *ListOptionIndexer) validateColumn(column string) error { return nil } } - return fmt.Errorf("column is invalid [%s]: %w", column, InvalidColumnErr) + return fmt.Errorf("column is invalid [%s]: %w", column, ErrInvalidColumn) } // buildORClause creates an SQLite compatible query that ORs conditions built from passed filters diff --git a/pkg/sqlcache/store/store.go b/pkg/sqlcache/store/store.go index 2fa6cd85..0368b60e 100644 --- a/pkg/sqlcache/store/store.go +++ b/pkg/sqlcache/store/store.go @@ -13,6 +13,7 @@ import ( "github.com/rancher/steve/pkg/sqlcache/db" "github.com/rancher/steve/pkg/sqlcache/db/transaction" "k8s.io/client-go/tools/cache" + // needed for drivers _ "modernc.org/sqlite" ) diff --git a/pkg/stores/sqlproxy/proxy_store.go b/pkg/stores/sqlproxy/proxy_store.go index 5977cc83..878405a0 100644 --- a/pkg/stores/sqlproxy/proxy_store.go +++ b/pkg/stores/sqlproxy/proxy_store.go @@ -757,7 +757,7 @@ func (s *Store) ListByPartitions(apiOp *types.APIRequest, schema *types.APISchem list, total, continueToken, err := inf.ListByOptions(apiOp.Context(), opts, partitions, apiOp.Namespace) if err != nil { - if errors.Is(err, informer.InvalidColumnErr) { + if errors.Is(err, informer.ErrInvalidColumn) { return nil, 0, "", apierror.NewAPIError(validation.InvalidBodyContent, err.Error()) } return nil, 0, "", err From 2e839830069d9585a92235ef888c7a78a24a057c Mon Sep 17 00:00:00 2001 From: Silvio Moioli Date: Thu, 16 Jan 2025 09:23:53 +0100 Subject: [PATCH 08/10] fix goimports Signed-off-by: Silvio Moioli --- pkg/sqlcache/db/client.go | 1 + pkg/sqlcache/store/store.go | 1 + 2 files changed, 2 insertions(+) diff --git a/pkg/sqlcache/db/client.go b/pkg/sqlcache/db/client.go index 2ecc4b59..dffbb6dd 100644 --- a/pkg/sqlcache/db/client.go +++ b/pkg/sqlcache/db/client.go @@ -17,6 +17,7 @@ import ( "github.com/pkg/errors" "github.com/rancher/steve/pkg/sqlcache/db/transaction" + // needed for drivers _ "modernc.org/sqlite" ) diff --git a/pkg/sqlcache/store/store.go b/pkg/sqlcache/store/store.go index 0368b60e..a228ee86 100644 --- a/pkg/sqlcache/store/store.go +++ b/pkg/sqlcache/store/store.go @@ -13,6 +13,7 @@ import ( "github.com/rancher/steve/pkg/sqlcache/db" "github.com/rancher/steve/pkg/sqlcache/db/transaction" "k8s.io/client-go/tools/cache" + // needed for drivers _ "modernc.org/sqlite" ) From 797f66d33a7f9acd20188f2067fe0d7d8324d4f4 Mon Sep 17 00:00:00 2001 From: Silvio Moioli Date: Thu, 16 Jan 2025 09:24:09 +0100 Subject: [PATCH 09/10] fix tests (Match -> Matches) Signed-off-by: Silvio Moioli --- .../listprocessor/processor_test.go | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/pkg/stores/sqlpartition/listprocessor/processor_test.go b/pkg/stores/sqlpartition/listprocessor/processor_test.go index 80cfb0b0..6c67bd18 100644 --- a/pkg/stores/sqlpartition/listprocessor/processor_test.go +++ b/pkg/stores/sqlpartition/listprocessor/processor_test.go @@ -60,7 +60,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"metadata", "namespace"}, - Match: "ns1", + Matches: []string{"ns1"}, Op: "", Partial: false, }, @@ -89,14 +89,14 @@ func TestParseQuery(t *testing.T) { { Filters: []informer.Filter{ { - Field: []string{"metadata", "name"}, - Match: "somethin", - Op: informer.Eq, + Field: []string{"metadata", "name"}, + Matches: []string{"somethin"}, + Op: informer.Eq, }, { - Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, - Match: "somethin", - Op: informer.Eq, + Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, + Matches: []string{"somethin"}, + Op: informer.Eq, }, }, }, @@ -120,7 +120,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"metadata", "namespace"}, - Match: "ns1", + Matches: []string{"ns1"}, Op: "", Partial: false, }, @@ -139,14 +139,14 @@ func TestParseQuery(t *testing.T) { { Filters: []informer.Filter{ { - Field: []string{"metadata", "name"}, - Match: "somethin", - Op: informer.Eq, + Field: []string{"metadata", "name"}, + Matches: []string{"somethin"}, + Op: informer.Eq, }, { - Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, - Match: "somethin", - Op: informer.Eq, + Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, + Matches: []string{"somethin"}, + Op: informer.Eq, }, }, }, @@ -170,7 +170,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"metadata", "namespace"}, - Match: "ns1", + Matches: []string{"ns1"}, Op: "", Partial: false, }, @@ -192,14 +192,14 @@ func TestParseQuery(t *testing.T) { { Filters: []informer.Filter{ { - Field: []string{"metadata", "name"}, - Match: "somethin", - Op: informer.Eq, + Field: []string{"metadata", "name"}, + Matches: []string{"somethin"}, + Op: informer.Eq, }, { - Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, - Match: "somethin", - Op: informer.Eq, + Field: []string{"metadata", "labels[field.cattle.io/projectId]"}, + Matches: []string{"somethin"}, + Op: informer.Eq, }, }, }, @@ -222,7 +222,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"a"}, - Match: "c", + Matches: []string{"c"}, Op: "", Partial: true, }, @@ -251,7 +251,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"a"}, - Match: "c", + Matches: []string{"c"}, Op: "", Partial: false, }, @@ -280,7 +280,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"a"}, - Match: "c", + Matches: []string{"c"}, Op: "", Partial: true, }, @@ -290,7 +290,7 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"b"}, - Match: "d", + Matches: []string{"d"}, Op: "", Partial: true, }, @@ -320,13 +320,13 @@ func TestParseQuery(t *testing.T) { Filters: []informer.Filter{ { Field: []string{"a"}, - Match: "c", + Matches: []string{"c"}, Op: "", Partial: true, }, { Field: []string{"b"}, - Match: "d", + Matches: []string{"d"}, Op: "", Partial: true, }, From 7e8774226c8a1c2ebdcbde711483e6e8c0dc26a0 Mon Sep 17 00:00:00 2001 From: Tom Lebreux Date: Thu, 16 Jan 2025 13:09:09 -0500 Subject: [PATCH 10/10] Fix Sort order --- .../sqlpartition/listprocessor/processor.go | 8 ++++---- .../listprocessor/processor_test.go | 19 ++++++++++++------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/pkg/stores/sqlpartition/listprocessor/processor.go b/pkg/stores/sqlpartition/listprocessor/processor.go index 11457089..ee2f49b3 100644 --- a/pkg/stores/sqlpartition/listprocessor/processor.go +++ b/pkg/stores/sqlpartition/listprocessor/processor.go @@ -91,20 +91,20 @@ func ParseQuery(apiOp *types.APIRequest, namespaceCache Cache) (informer.ListOpt sortParts := strings.SplitN(sortKeys, ",", 2) primaryField := sortParts[0] if primaryField != "" && primaryField[0] == '-' { - sortOpts.PrimaryOrder = informer.DESC + sortOpts.Orders = append(sortOpts.Orders, informer.DESC) primaryField = primaryField[1:] } if primaryField != "" { - sortOpts.PrimaryField = strings.Split(primaryField, ".") + sortOpts.Fields = append(sortOpts.Fields, strings.Split(primaryField, ".")) } if len(sortParts) > 1 { secondaryField := sortParts[1] if secondaryField != "" && secondaryField[0] == '-' { - sortOpts.SecondaryOrder = informer.DESC + sortOpts.Orders = append(sortOpts.Orders, informer.DESC) secondaryField = secondaryField[1:] } if secondaryField != "" { - sortOpts.SecondaryField = strings.Split(secondaryField, ".") + sortOpts.Fields = append(sortOpts.Fields, strings.Split(secondaryField, ".")) } } } diff --git a/pkg/stores/sqlpartition/listprocessor/processor_test.go b/pkg/stores/sqlpartition/listprocessor/processor_test.go index 6c67bd18..0aa3207d 100644 --- a/pkg/stores/sqlpartition/listprocessor/processor_test.go +++ b/pkg/stores/sqlpartition/listprocessor/processor_test.go @@ -352,7 +352,9 @@ func TestParseQuery(t *testing.T) { expectedLO: informer.ListOptions{ ChunkSize: defaultLimit, Sort: informer.Sort{ - PrimaryField: []string{"metadata", "name"}, + Fields: [][]string{ + {"metadata", "name"}, + }, }, Filters: make([]informer.OrFilter, 0), Pagination: informer.Pagination{ @@ -374,8 +376,8 @@ func TestParseQuery(t *testing.T) { expectedLO: informer.ListOptions{ ChunkSize: defaultLimit, Sort: informer.Sort{ - PrimaryField: []string{"metadata", "name"}, - PrimaryOrder: informer.DESC, + Fields: [][]string{{"metadata", "name"}}, + Orders: []informer.SortOrder{informer.DESC}, }, Filters: make([]informer.OrFilter, 0), Pagination: informer.Pagination{ @@ -397,10 +399,13 @@ func TestParseQuery(t *testing.T) { expectedLO: informer.ListOptions{ ChunkSize: defaultLimit, Sort: informer.Sort{ - PrimaryField: []string{"metadata", "name"}, - PrimaryOrder: informer.DESC, - SecondaryField: []string{"spec", "something"}, - SecondaryOrder: informer.ASC, + Fields: [][]string{ + {"metadata", "name"}, + {"spec", "something"}, + }, + Orders: []informer.SortOrder{ + informer.DESC, + }, }, Filters: make([]informer.OrFilter, 0), Pagination: informer.Pagination{