From 06461d7eda85416e521070335ce493f5094ea91c Mon Sep 17 00:00:00 2001 From: arielsepton <64063409+arielsepton@users.noreply.github.com> Date: Thu, 11 Jan 2024 12:11:37 +0200 Subject: [PATCH] fix MSSQL database user bug (#168) * fix mssql user bug Signed-off-by: arielsepton * revert test file Signed-off-by: arielsepton * fix test Signed-off-by: arielsepton <64063409+arielsepton@users.noreply.github.com> * gofmt -s -w Signed-off-by: arielsepton <64063409+arielsepton@users.noreply.github.com> * remove redundent print and edd better error messages Signed-off-by: arielsepton <64063409+arielsepton@users.noreply.github.com> --------- Signed-off-by: arielsepton Signed-off-by: arielsepton <64063409+arielsepton@users.noreply.github.com> --- pkg/controller/mssql/user/reconciler.go | 58 +++++++++++++++++--- pkg/controller/mssql/user/reconciler_test.go | 25 ++++++++- 2 files changed, 71 insertions(+), 12 deletions(-) diff --git a/pkg/controller/mssql/user/reconciler.go b/pkg/controller/mssql/user/reconciler.go index b6f66cdc..ed8fa35f 100644 --- a/pkg/controller/mssql/user/reconciler.go +++ b/pkg/controller/mssql/user/reconciler.go @@ -47,10 +47,15 @@ const ( errNoSecretRef = "ProviderConfig does not reference a credentials Secret" errGetSecret = "cannot get credentials Secret" - errNotUser = "managed resource is not a User custom resource" - errSelectUser = "cannot select user" - errCreateUser = "cannot create user" - errDropUser = "cannot drop user" + errNotUser = "managed resource is not a User custom resource" + errSelectUser = "cannot select user" + errCreateUser = "cannot create user %s" + errCreateLogin = "cannot create login %s" + errDropUser = "error dropping user %s" + errDropLogin = "error dropping login %s" + errCannotGetLogins = "cannot get current logins" + errCannotKillLoginSession = "error killing session %d for login %s" + errUpdateUser = "cannot update user" errGetPasswordSecretFailed = "cannot get password secret" @@ -173,11 +178,19 @@ func (c *external) Create(ctx context.Context, mg resource.Managed) (managed.Ext return managed.ExternalCreation{}, err } } - query := fmt.Sprintf("CREATE USER %s WITH PASSWORD=%s", mssql.QuoteIdentifier(meta.GetExternalName(cr)), mssql.QuoteValue(pw)) + + loginQuery := fmt.Sprintf("CREATE LOGIN %s WITH PASSWORD=%s", mssql.QuoteIdentifier(meta.GetExternalName(cr)), mssql.QuoteValue(pw)) + if err := c.db.Exec(ctx, xsql.Query{ + String: loginQuery, + }); err != nil { + return managed.ExternalCreation{}, errors.Wrapf(err, errCreateLogin, meta.GetExternalName(cr)) + } + + userQuery := fmt.Sprintf("CREATE USER %s FOR LOGIN %s", mssql.QuoteIdentifier(meta.GetExternalName(cr)), mssql.QuoteIdentifier(meta.GetExternalName(cr))) if err := c.db.Exec(ctx, xsql.Query{ - String: query, + String: userQuery, }); err != nil { - return managed.ExternalCreation{}, errors.Wrap(err, errCreateUser) + return managed.ExternalCreation{}, errors.Wrapf(err, errCreateUser, meta.GetExternalName(cr)) } return managed.ExternalCreation{ @@ -197,7 +210,7 @@ func (c *external) Update(ctx context.Context, mg resource.Managed) (managed.Ext } if changed { - query := fmt.Sprintf("ALTER USER %s WITH PASSWORD=%s", mssql.QuoteIdentifier(meta.GetExternalName(cr)), mssql.QuoteValue(pw)) + query := fmt.Sprintf("ALTER LOGIN %s WITH PASSWORD=%s", mssql.QuoteIdentifier(meta.GetExternalName(cr)), mssql.QuoteValue(pw)) if err := c.db.Exec(ctx, xsql.Query{ String: query, }); err != nil { @@ -217,10 +230,37 @@ func (c *external) Delete(ctx context.Context, mg resource.Managed) error { return errors.New(errNotUser) } + query := fmt.Sprintf("SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = %s", mssql.QuoteValue(meta.GetExternalName(cr))) + rows, err := c.db.Query(ctx, xsql.Query{String: query}) + if err != nil { + return errors.Wrap(err, errCannotGetLogins) + } + defer rows.Close() //nolint:errcheck + + for rows.Next() { + var sessionID int + if err := rows.Scan(&sessionID); err != nil { + return errors.Wrap(err, errCannotGetLogins) + } + if err := c.db.Exec(ctx, xsql.Query{String: fmt.Sprintf("KILL %d", sessionID)}); err != nil { + return errors.Wrapf(err, errCannotKillLoginSession, sessionID, meta.GetExternalName(cr)) + } + } + if err := rows.Err(); err != nil { + return errors.Wrap(err, errCannotGetLogins) + } + if err := c.db.Exec(ctx, xsql.Query{ String: fmt.Sprintf("DROP USER IF EXISTS %s", mssql.QuoteIdentifier(meta.GetExternalName(cr))), }); err != nil { - return errors.Wrap(err, errDropUser) + return errors.Wrapf(err, errDropUser, meta.GetExternalName(cr)) } + + if err := c.db.Exec(ctx, xsql.Query{ + String: fmt.Sprintf("DROP LOGIN %s", mssql.QuoteIdentifier(meta.GetExternalName(cr))), + }); err != nil { + return errors.Wrapf(err, errDropLogin, meta.GetExternalName(cr)) + } + return nil } diff --git a/pkg/controller/mssql/user/reconciler_test.go b/pkg/controller/mssql/user/reconciler_test.go index 1fe3ff4c..a0a3a74c 100644 --- a/pkg/controller/mssql/user/reconciler_test.go +++ b/pkg/controller/mssql/user/reconciler_test.go @@ -21,6 +21,7 @@ import ( "database/sql" "testing" + "github.com/DATA-DOG/go-sqlmock" "github.com/crossplane-contrib/provider-sql/apis/mssql/v1alpha1" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -42,6 +43,7 @@ type mockDB struct { MockExec func(ctx context.Context, q xsql.Query) error MockExecTx func(ctx context.Context, ql []xsql.Query) error MockScan func(ctx context.Context, q xsql.Query, dest ...interface{}) error + MockQuery func(ctx context.Context, q xsql.Query) (*sql.Rows, error) } func (m mockDB) Exec(ctx context.Context, q xsql.Query) error { @@ -54,7 +56,7 @@ func (m mockDB) Scan(ctx context.Context, q xsql.Query, dest ...interface{}) err return m.MockScan(ctx, q, dest...) } func (m mockDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) { - return &sql.Rows{}, nil + return m.MockQuery(ctx, q) } func (m mockDB) GetConnectionDetails(username, password string) managed.ConnectionDetails { return managed.ConnectionDetails{ @@ -65,6 +67,17 @@ func (m mockDB) GetConnectionDetails(username, password string) managed.Connecti } } +func mockRowsToSQLRows(mockRows *sqlmock.Rows) *sql.Rows { + db, mock, _ := sqlmock.New() + mock.ExpectQuery("select").WillReturnRows(mockRows) + rows, err := db.Query("select") + if err != nil { + println("%v", err) + return nil + } + return rows +} + func TestConnect(t *testing.T) { errBoom := errors.New("boom") @@ -371,7 +384,7 @@ func TestCreate(t *testing.T) { mg: &v1alpha1.User{}, }, want: want{ - err: errors.Wrap(errBoom, errCreateUser), + err: errors.Wrapf(errBoom, errCreateLogin, ""), }, }, "Success": { @@ -722,6 +735,9 @@ func TestDelete(t *testing.T) { reason: "Errors dropping a user should be returned", fields: fields{ db: &mockDB{ + MockQuery: func(ctx context.Context, q xsql.Query) (*sql.Rows, error) { + return mockRowsToSQLRows(sqlmock.NewRows([]string{})), nil + }, MockExec: func(ctx context.Context, q xsql.Query) error { return errBoom }, @@ -730,12 +746,15 @@ func TestDelete(t *testing.T) { args: args{ mg: &v1alpha1.User{}, }, - want: errors.Wrap(errBoom, errDropUser), + want: errors.Wrapf(errBoom, errDropUser, ""), }, "Success": { reason: "No error should be returned", fields: fields{ db: &mockDB{ + MockQuery: func(ctx context.Context, q xsql.Query) (*sql.Rows, error) { + return mockRowsToSQLRows(sqlmock.NewRows([]string{})), nil + }, MockExec: func(ctx context.Context, q xsql.Query) error { return nil },