Skip to content

Commit

Permalink
fix(encryption): generated private key may already exist (#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
JeremyPansier authored Oct 29, 2022
1 parent 1dfb34e commit 1aa7d6d
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 77 deletions.
8 changes: 0 additions & 8 deletions src/node/encryption/private_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@ type PrivateKey struct {
*ecdsa.PrivateKey
}

func NewPrivateKey() (*PrivateKey, error) {
privateKey, err := crypto.GenerateKey()
if err != nil {
return nil, err
}
return &PrivateKey{privateKey}, err
}

func DecodePrivateKey(privateKeyString string) (*PrivateKey, error) {
bytes, err := hexutil.Decode(privateKeyString)
if err != nil {
Expand Down
20 changes: 14 additions & 6 deletions src/node/encryption/wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ type Wallet struct {
address string
}

func NewWallet() (*Wallet, error) {
return DecodeWallet("", "", "", "")
func NewEmptyWallet() *Wallet {
return &Wallet{nil, nil, ""}
}

func DecodeWallet(mnemonicString string, derivationPath string, password string, privateKeyString string) (*Wallet, error) {
Expand All @@ -26,7 +26,7 @@ func DecodeWallet(mnemonicString string, derivationPath string, password string,
} else if privateKeyString != "" {
privateKey, err = DecodePrivateKey(privateKeyString)
} else {
privateKey, err = NewPrivateKey()
return NewEmptyWallet(), nil
}
if err != nil {
return nil, fmt.Errorf("failed to create private key: %w", err)
Expand All @@ -37,14 +37,22 @@ func DecodeWallet(mnemonicString string, derivationPath string, password string,
}

func (wallet *Wallet) MarshalJSON() ([]byte, error) {
var privateKey string
if wallet.privateKey != nil {
privateKey = wallet.privateKey.String()
}
var publicKey string
if wallet.publicKey != nil {
publicKey = wallet.publicKey.String()
}
return json.Marshal(struct {
PrivateKey string `json:"private_key"`
PublicKey string `json:"public_key"`
Address string `json:"address"`
}{
PrivateKey: wallet.privateKey.String(),
PublicKey: wallet.publicKey.String(),
Address: wallet.Address(),
PrivateKey: privateKey,
PublicKey: publicKey,
Address: wallet.address,
})
}

Expand Down
4 changes: 2 additions & 2 deletions src/node/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ import (
const validationIntervalInSeconds = 60

func main() {
mnemonic := flag.String("mnemonic", environment.NewVariable("MNEMONIC").GetStringValue(""), "The mnemonic (optional)")
mnemonic := flag.String("mnemonic", environment.NewVariable("MNEMONIC").GetStringValue(""), "The mnemonic (required if the private key is not provided)")
derivationPath := flag.String("derivation-path", environment.NewVariable("DERIVATION_PATH").GetStringValue("m/44'/60'/0'/0/0"), "The derivation path (unused if the mnemonic is omitted)")
password := flag.String("password", environment.NewVariable("PASSWORD").GetStringValue(""), "The mnemonic password (unused if the mnemonic is omitted)")
privateKey := flag.String("private-key", environment.NewVariable("PRIVATE_KEY").GetStringValue(""), "The private key (will be generated if not provided)")
privateKey := flag.String("private-key", environment.NewVariable("PRIVATE_KEY").GetStringValue(""), "The private key (required if the mnemonic is not provided, unused if the mnemonic is provided)")
port := flag.Uint64("port", environment.NewVariable("PORT").GetUint64Value(network.DefaultPort), "The TCP port number for the protocol host node")
configurationPath := flag.String("configuration-path", environment.NewVariable("CONFIGURATION_PATH").GetStringValue("config"), "The configuration files path")
logLevel := flag.String("log-level", environment.NewVariable("LOG_LEVEL").GetStringValue("info"), "The log level")
Expand Down
4 changes: 2 additions & 2 deletions src/ui/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import (
)

func main() {
mnemonic := flag.String("mnemonic", environment.NewVariable("MNEMONIC").GetStringValue(""), "The mnemonic (optional)")
mnemonic := flag.String("mnemonic", environment.NewVariable("MNEMONIC").GetStringValue(""), "The mnemonic (required if the private key is not provided)")
derivationPath := flag.String("derivation-path", environment.NewVariable("DERIVATION_PATH").GetStringValue("m/44'/60'/0'/0/0"), "The derivation path (unused if the mnemonic is omitted)")
password := flag.String("password", environment.NewVariable("PASSWORD").GetStringValue(""), "The mnemonic password (unused if the mnemonic is omitted)")
privateKey := flag.String("private-key", environment.NewVariable("PRIVATE_KEY").GetStringValue(""), "The private key (will be generated if not provided)")
privateKey := flag.String("private-key", environment.NewVariable("PRIVATE_KEY").GetStringValue(""), "The private key (required if the mnemonic is not provided, unused if the mnemonic is provided)")
port := flag.Uint64("port", environment.NewVariable("PORT").GetUint64Value(server.DefaultPort), "The TCP port number for the UI server")
hostIp := flag.String("host-ip", environment.NewVariable("HOST_IP").GetStringValue(""), "The blockchain host IP address")
hostPort := flag.Uint64("host-port", environment.NewVariable("HOST_PORT").GetUint64Value(network.DefaultPort), "The TCP port number for the protocol host node")
Expand Down
7 changes: 7 additions & 0 deletions test/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package test

const (
Mnemonic1 = "artist silver basket insane canvas top drill social reflect park fruit bless"
Mnemonic2 = "screen wrap color drop lady keep dwarf horror recipe gap ride garage"
DerivationPath = "m/44'/60'/0'/0/0"
)
46 changes: 0 additions & 46 deletions test/node/encryption/address_test.go

This file was deleted.

21 changes: 21 additions & 0 deletions test/node/encryption/mnemonic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package encryption

import (
"fmt"
"github.com/my-cloud/ruthenium/src/node/encryption"
"github.com/my-cloud/ruthenium/test"
"testing"
)

func Test_PrivateKeyFromMnemonic(t *testing.T) {
// Arrange
mnemonic := encryption.NewMnemonic(test.Mnemonic1)

// Act
privateKey, _ := mnemonic.PrivateKey(test.DerivationPath, "")

// Assert
expectedPrivateKey := "0x48913790c2bebc48417491f96a7e07ec94c76ccd0fe1562dc1749479d9715afd"
actualPrivateKey := privateKey.String()
test.Assert(t, actualPrivateKey == expectedPrivateKey, fmt.Sprintf("Wrong private key. Expected: %s - Actual: %s", expectedPrivateKey, actualPrivateKey))
}
21 changes: 21 additions & 0 deletions test/node/encryption/private_key_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package encryption

import (
"fmt"
"github.com/my-cloud/ruthenium/src/node/encryption"
"github.com/my-cloud/ruthenium/test"
"testing"
)

func Test_PublicKeyFromPrivateKey(t *testing.T) {
// Arrange
privateKey, _ := encryption.DecodePrivateKey("0x48913790c2bebc48417491f96a7e07ec94c76ccd0fe1562dc1749479d9715afd")

// Act
publicKey := encryption.NewPublicKey(privateKey)

// Assert
expectedPublicKey := "0x046bd857ce80ff5238d6561f3a775802453c570b6ea2cbf93a35a8a6542b2edbe5f625f9e3fbd2a5df62adebc27391332a265fb94340fb11b69cf569605a5df782"
actualPublicKey := publicKey.String()
test.Assert(t, actualPublicKey == expectedPublicKey, fmt.Sprintf("Wrong public key. Expected: %s - Actual: %s", expectedPublicKey, actualPublicKey))
}
20 changes: 20 additions & 0 deletions test/node/encryption/public_key_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package encryption

import (
"fmt"
"github.com/my-cloud/ruthenium/src/node/encryption"
"github.com/my-cloud/ruthenium/test"
"testing"
)

func Test_AddressFromPublicKey(t *testing.T) {
// Arrange
publicKey, _ := encryption.DecodePublicKey("0x046bd857ce80ff5238d6561f3a775802453c570b6ea2cbf93a35a8a6542b2edbe5f625f9e3fbd2a5df62adebc27391332a265fb94340fb11b69cf569605a5df782")

// Act
address := publicKey.Address()

// Assert
expectedAddress := "0x9C69443c3Ec0D660e257934ffc1754EB9aD039CB"
test.Assert(t, address == expectedAddress, fmt.Sprintf("Wrong address. Expected: %s - Actual: %s", expectedAddress, address))
}
53 changes: 53 additions & 0 deletions test/node/encryption/wallet_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package encryption

import (
"fmt"
"github.com/my-cloud/ruthenium/src/node/encryption"
"github.com/my-cloud/ruthenium/test"
"testing"
)

func Test_DecodeWallet_PrivateKeyProvided_ReturnsWalletForPrivateKey(t *testing.T) {
// Arrange
expectedPrivateKey := "0x48913790c2bebc48417491f96a7e07ec94c76ccd0fe1562dc1749479d9715afd"

// Act
wallet, _ := encryption.DecodeWallet("", "", "", expectedPrivateKey)

// Assert
actualPrivateKey := wallet.PrivateKey().String()
test.Assert(t, actualPrivateKey == expectedPrivateKey, fmt.Sprintf("Wrong private key. Expected: %s - Actual: %s", expectedPrivateKey, actualPrivateKey))
}

func Test_DecodeWallet_BothPrivateKeyAndMnemonicAreEmpty_ReturnsEmptyWallet(t *testing.T) {
// Act
wallet, _ := encryption.DecodeWallet("", "", "", "")

// Assert
test.Assert(t, wallet.PrivateKey() == nil, "Private key is not nil whereas it should be.")
}

func Test_MarshalJSON_ValidPrivateKey_ReturnsMarshaledJsonWithoutError(t *testing.T) {
// Arrange
privateKey := "0x48913790c2bebc48417491f96a7e07ec94c76ccd0fe1562dc1749479d9715afd"
wallet, _ := encryption.DecodeWallet("", "", "", privateKey)

// Act
marshaledWallet, err := wallet.MarshalJSON()

// Assert
test.Assert(t, marshaledWallet != nil, "Marshaled wallet is nil.")
test.Assert(t, err == nil, "Marshal wallet returned an error.")
}

func Test_MarshalJSON_EmptyWallet_ReturnsMarshaledJsonWithoutError(t *testing.T) {
// Arrange
wallet := encryption.NewEmptyWallet()

// Act
marshaledWallet, err := wallet.MarshalJSON()

// Assert
test.Assert(t, marshaledWallet != nil, "Marshaled wallet is nil.")
test.Assert(t, err == nil, "Marshal wallet returned an error.")
}
26 changes: 13 additions & 13 deletions test/node/protocol/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func Test_AddTransaction_TransactionTimestampIsAfterNow_TransactionNotAdded(t *testing.T) {
// Arrange
validatorWallet, _ := encryption.NewWallet()
validatorWallet, _ := encryption.DecodeWallet(test.Mnemonic1, test.DerivationPath, "", "")
validatorWalletAddress := validatorWallet.Address()
registryMock := new(RegistryMock)
registryMock.IsRegisteredFunc = func(string) (bool, error) { return true, nil }
Expand Down Expand Up @@ -44,7 +44,7 @@ func Test_AddTransaction_TransactionTimestampIsAfterNow_TransactionNotAdded(t *t

func Test_AddTransaction_TransactionTimestampIsOlderThan2Blocks_TransactionNotAdded(t *testing.T) {
// Arrange
validatorWallet, _ := encryption.NewWallet()
validatorWallet, _ := encryption.DecodeWallet(test.Mnemonic1, test.DerivationPath, "", "")
validatorWalletAddress := validatorWallet.Address()
registryMock := new(RegistryMock)
registryMock.IsRegisteredFunc = func(string) (bool, error) { return true, nil }
Expand Down Expand Up @@ -75,7 +75,7 @@ func Test_AddTransaction_TransactionTimestampIsOlderThan2Blocks_TransactionNotAd
}
func Test_AddTransaction_TransactionIsAlreadyInTheBlockchain_TransactionNotAdded(t *testing.T) {
// Arrange
validatorWallet, _ := encryption.NewWallet()
validatorWallet, _ := encryption.DecodeWallet(test.Mnemonic1, test.DerivationPath, "", "")
validatorWalletAddress := validatorWallet.Address()
registryMock := new(RegistryMock)
registryMock.IsRegisteredFunc = func(string) (bool, error) { return true, nil }
Expand Down Expand Up @@ -107,9 +107,9 @@ func Test_AddTransaction_TransactionIsAlreadyInTheBlockchain_TransactionNotAdded

func Test_AddTransaction_InvalidSignature_TransactionNotAdded(t *testing.T) {
// Arrange
validatorWallet, _ := encryption.NewWallet()
validatorWallet, _ := encryption.DecodeWallet(test.Mnemonic1, test.DerivationPath, "", "")
validatorWalletAddress := validatorWallet.Address()
walletA, _ := encryption.NewWallet()
walletA, _ := encryption.DecodeWallet(test.Mnemonic2, test.DerivationPath, "", "")
walletAAddress := walletA.Address()
registryMock := new(RegistryMock)
registryMock.IsRegisteredFunc = func(string) (bool, error) { return true, nil }
Expand Down Expand Up @@ -141,9 +141,9 @@ func Test_AddTransaction_InvalidSignature_TransactionNotAdded(t *testing.T) {

func Test_AddTransaction_ValidTransaction_TransactionAdded(t *testing.T) {
// Arrange
validatorWallet, _ := encryption.NewWallet()
validatorWallet, _ := encryption.DecodeWallet(test.Mnemonic1, test.DerivationPath, "", "")
validatorWalletAddress := validatorWallet.Address()
walletA, _ := encryption.NewWallet()
walletA, _ := encryption.DecodeWallet(test.Mnemonic2, test.DerivationPath, "", "")
walletAAddress := walletA.Address()
registryMock := new(RegistryMock)
registryMock.IsRegisteredFunc = func(string) (bool, error) { return true, nil }
Expand Down Expand Up @@ -175,9 +175,9 @@ func Test_AddTransaction_ValidTransaction_TransactionAdded(t *testing.T) {

func Test_Validate_InvalidSignature_TransactionNotValidated(t *testing.T) {
// Arrange
validatorWallet, _ := encryption.NewWallet()
validatorWallet, _ := encryption.DecodeWallet(test.Mnemonic1, test.DerivationPath, "", "")
validatorWalletAddress := validatorWallet.Address()
walletA, _ := encryption.NewWallet()
walletA, _ := encryption.DecodeWallet(test.Mnemonic2, test.DerivationPath, "", "")
walletAAddress := walletA.Address()
registryMock := new(RegistryMock)
registryMock.IsRegisteredFunc = func(string) (bool, error) { return true, nil }
Expand Down Expand Up @@ -207,7 +207,7 @@ func Test_Validate_InvalidSignature_TransactionNotValidated(t *testing.T) {

func Test_Validate_TransactionTimestampIsAfterNow_TransactionNotValidated(t *testing.T) {
// Arrange
validatorWallet, _ := encryption.NewWallet()
validatorWallet, _ := encryption.DecodeWallet(test.Mnemonic1, test.DerivationPath, "", "")
validatorWalletAddress := validatorWallet.Address()
registryMock := new(RegistryMock)
registryMock.IsRegisteredFunc = func(string) (bool, error) { return true, nil }
Expand Down Expand Up @@ -237,7 +237,7 @@ func Test_Validate_TransactionTimestampIsAfterNow_TransactionNotValidated(t *tes

func Test_Validate_TransactionTimestampIsOlderThan2Blocks_TransactionNotValidated(t *testing.T) {
// Arrange
validatorWallet, _ := encryption.NewWallet()
validatorWallet, _ := encryption.DecodeWallet(test.Mnemonic1, test.DerivationPath, "", "")
validatorWalletAddress := validatorWallet.Address()
registryMock := new(RegistryMock)
registryMock.IsRegisteredFunc = func(string) (bool, error) { return true, nil }
Expand Down Expand Up @@ -269,7 +269,7 @@ func Test_Validate_TransactionTimestampIsOlderThan2Blocks_TransactionNotValidate

func Test_Validate_TransactionIsAlreadyInTheBlockchain_TransactionNotValidated(t *testing.T) {
// Arrange
validatorWallet, _ := encryption.NewWallet()
validatorWallet, _ := encryption.DecodeWallet(test.Mnemonic1, test.DerivationPath, "", "")
validatorWalletAddress := validatorWallet.Address()
registryMock := new(RegistryMock)
registryMock.IsRegisteredFunc = func(string) (bool, error) { return true, nil }
Expand Down Expand Up @@ -301,7 +301,7 @@ func Test_Validate_TransactionIsAlreadyInTheBlockchain_TransactionNotValidated(t

func Test_Validate_ValidTransaction_TransactionValidated(t *testing.T) {
// Arrange
validatorWallet, _ := encryption.NewWallet()
validatorWallet, _ := encryption.DecodeWallet(test.Mnemonic1, test.DerivationPath, "", "")
validatorWalletAddress := validatorWallet.Address()
registryMock := new(RegistryMock)
registryMock.IsRegisteredFunc = func(string) (bool, error) { return true, nil }
Expand Down

0 comments on commit 1aa7d6d

Please sign in to comment.