Skip to content

Commit

Permalink
fix MSSQL database user bug (#168)
Browse files Browse the repository at this point in the history
* fix mssql user bug

Signed-off-by: arielsepton <arielsepton1@gmail.com>

* revert test file

Signed-off-by: arielsepton <arielsepton1@gmail.com>

* fix test

Signed-off-by: arielsepton <64063409+arielsepton@users.noreply.github.com>

* gofmt -s -w

Signed-off-by: arielsepton <64063409+arielsepton@users.noreply.github.com>

* remove redundent print and edd better error messages

Signed-off-by: arielsepton <64063409+arielsepton@users.noreply.github.com>

---------

Signed-off-by: arielsepton <arielsepton1@gmail.com>
Signed-off-by: arielsepton <64063409+arielsepton@users.noreply.github.com>
  • Loading branch information
arielsepton authored Jan 11, 2024
1 parent 61a1672 commit 06461d7
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 12 deletions.
58 changes: 49 additions & 9 deletions pkg/controller/mssql/user/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,15 @@ const (
errNoSecretRef = "ProviderConfig does not reference a credentials Secret"
errGetSecret = "cannot get credentials Secret"

errNotUser = "managed resource is not a User custom resource"
errSelectUser = "cannot select user"
errCreateUser = "cannot create user"
errDropUser = "cannot drop user"
errNotUser = "managed resource is not a User custom resource"
errSelectUser = "cannot select user"
errCreateUser = "cannot create user %s"
errCreateLogin = "cannot create login %s"
errDropUser = "error dropping user %s"
errDropLogin = "error dropping login %s"
errCannotGetLogins = "cannot get current logins"
errCannotKillLoginSession = "error killing session %d for login %s"

errUpdateUser = "cannot update user"
errGetPasswordSecretFailed = "cannot get password secret"

Expand Down Expand Up @@ -173,11 +178,19 @@ func (c *external) Create(ctx context.Context, mg resource.Managed) (managed.Ext
return managed.ExternalCreation{}, err
}
}
query := fmt.Sprintf("CREATE USER %s WITH PASSWORD=%s", mssql.QuoteIdentifier(meta.GetExternalName(cr)), mssql.QuoteValue(pw))

loginQuery := fmt.Sprintf("CREATE LOGIN %s WITH PASSWORD=%s", mssql.QuoteIdentifier(meta.GetExternalName(cr)), mssql.QuoteValue(pw))
if err := c.db.Exec(ctx, xsql.Query{
String: loginQuery,
}); err != nil {
return managed.ExternalCreation{}, errors.Wrapf(err, errCreateLogin, meta.GetExternalName(cr))
}

userQuery := fmt.Sprintf("CREATE USER %s FOR LOGIN %s", mssql.QuoteIdentifier(meta.GetExternalName(cr)), mssql.QuoteIdentifier(meta.GetExternalName(cr)))
if err := c.db.Exec(ctx, xsql.Query{
String: query,
String: userQuery,
}); err != nil {
return managed.ExternalCreation{}, errors.Wrap(err, errCreateUser)
return managed.ExternalCreation{}, errors.Wrapf(err, errCreateUser, meta.GetExternalName(cr))
}

return managed.ExternalCreation{
Expand All @@ -197,7 +210,7 @@ func (c *external) Update(ctx context.Context, mg resource.Managed) (managed.Ext
}

if changed {
query := fmt.Sprintf("ALTER USER %s WITH PASSWORD=%s", mssql.QuoteIdentifier(meta.GetExternalName(cr)), mssql.QuoteValue(pw))
query := fmt.Sprintf("ALTER LOGIN %s WITH PASSWORD=%s", mssql.QuoteIdentifier(meta.GetExternalName(cr)), mssql.QuoteValue(pw))
if err := c.db.Exec(ctx, xsql.Query{
String: query,
}); err != nil {
Expand All @@ -217,10 +230,37 @@ func (c *external) Delete(ctx context.Context, mg resource.Managed) error {
return errors.New(errNotUser)
}

query := fmt.Sprintf("SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = %s", mssql.QuoteValue(meta.GetExternalName(cr)))
rows, err := c.db.Query(ctx, xsql.Query{String: query})
if err != nil {
return errors.Wrap(err, errCannotGetLogins)
}
defer rows.Close() //nolint:errcheck

for rows.Next() {
var sessionID int
if err := rows.Scan(&sessionID); err != nil {
return errors.Wrap(err, errCannotGetLogins)
}
if err := c.db.Exec(ctx, xsql.Query{String: fmt.Sprintf("KILL %d", sessionID)}); err != nil {
return errors.Wrapf(err, errCannotKillLoginSession, sessionID, meta.GetExternalName(cr))
}
}
if err := rows.Err(); err != nil {
return errors.Wrap(err, errCannotGetLogins)
}

if err := c.db.Exec(ctx, xsql.Query{
String: fmt.Sprintf("DROP USER IF EXISTS %s", mssql.QuoteIdentifier(meta.GetExternalName(cr))),
}); err != nil {
return errors.Wrap(err, errDropUser)
return errors.Wrapf(err, errDropUser, meta.GetExternalName(cr))
}

if err := c.db.Exec(ctx, xsql.Query{
String: fmt.Sprintf("DROP LOGIN %s", mssql.QuoteIdentifier(meta.GetExternalName(cr))),
}); err != nil {
return errors.Wrapf(err, errDropLogin, meta.GetExternalName(cr))
}

return nil
}
25 changes: 22 additions & 3 deletions pkg/controller/mssql/user/reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"database/sql"
"testing"

"github.com/DATA-DOG/go-sqlmock"
"github.com/crossplane-contrib/provider-sql/apis/mssql/v1alpha1"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
Expand All @@ -42,6 +43,7 @@ type mockDB struct {
MockExec func(ctx context.Context, q xsql.Query) error
MockExecTx func(ctx context.Context, ql []xsql.Query) error
MockScan func(ctx context.Context, q xsql.Query, dest ...interface{}) error
MockQuery func(ctx context.Context, q xsql.Query) (*sql.Rows, error)
}

func (m mockDB) Exec(ctx context.Context, q xsql.Query) error {
Expand All @@ -54,7 +56,7 @@ func (m mockDB) Scan(ctx context.Context, q xsql.Query, dest ...interface{}) err
return m.MockScan(ctx, q, dest...)
}
func (m mockDB) Query(ctx context.Context, q xsql.Query) (*sql.Rows, error) {
return &sql.Rows{}, nil
return m.MockQuery(ctx, q)
}
func (m mockDB) GetConnectionDetails(username, password string) managed.ConnectionDetails {
return managed.ConnectionDetails{
Expand All @@ -65,6 +67,17 @@ func (m mockDB) GetConnectionDetails(username, password string) managed.Connecti
}
}

func mockRowsToSQLRows(mockRows *sqlmock.Rows) *sql.Rows {
db, mock, _ := sqlmock.New()
mock.ExpectQuery("select").WillReturnRows(mockRows)
rows, err := db.Query("select")
if err != nil {
println("%v", err)
return nil
}
return rows
}

func TestConnect(t *testing.T) {
errBoom := errors.New("boom")

Expand Down Expand Up @@ -371,7 +384,7 @@ func TestCreate(t *testing.T) {
mg: &v1alpha1.User{},
},
want: want{
err: errors.Wrap(errBoom, errCreateUser),
err: errors.Wrapf(errBoom, errCreateLogin, ""),
},
},
"Success": {
Expand Down Expand Up @@ -722,6 +735,9 @@ func TestDelete(t *testing.T) {
reason: "Errors dropping a user should be returned",
fields: fields{
db: &mockDB{
MockQuery: func(ctx context.Context, q xsql.Query) (*sql.Rows, error) {
return mockRowsToSQLRows(sqlmock.NewRows([]string{})), nil
},
MockExec: func(ctx context.Context, q xsql.Query) error {
return errBoom
},
Expand All @@ -730,12 +746,15 @@ func TestDelete(t *testing.T) {
args: args{
mg: &v1alpha1.User{},
},
want: errors.Wrap(errBoom, errDropUser),
want: errors.Wrapf(errBoom, errDropUser, ""),
},
"Success": {
reason: "No error should be returned",
fields: fields{
db: &mockDB{
MockQuery: func(ctx context.Context, q xsql.Query) (*sql.Rows, error) {
return mockRowsToSQLRows(sqlmock.NewRows([]string{})), nil
},
MockExec: func(ctx context.Context, q xsql.Query) error {
return nil
},
Expand Down

0 comments on commit 06461d7

Please sign in to comment.