Skip to content

Commit

Permalink
refactor with bugfix to keep account number on shorthand account crea…
Browse files Browse the repository at this point in the history
…tion (#132)
  • Loading branch information
beer-1 authored Dec 12, 2024
1 parent 05f54b7 commit e4c92d4
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 71 deletions.
10 changes: 9 additions & 1 deletion x/evm/keeper/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func (k Keeper) convertToEVMAddress(ctx context.Context, addr sdk.AccAddress, is
return common.BytesToAddress(addr.Bytes()), nil
}

accountNumber := uint64(0)
shorthandAddr := common.BytesToAddress(addr.Bytes())
if found := k.accountKeeper.HasAccount(ctx, shorthandAddr.Bytes()); found {
account := k.accountKeeper.GetAccount(ctx, shorthandAddr.Bytes())
Expand All @@ -36,17 +37,24 @@ func (k Keeper) convertToEVMAddress(ctx context.Context, addr sdk.AccAddress, is

return common.Address{}, types.ErrAddressAlreadyExists.Wrapf("failed to create shorthand account of `%s`: `%s`", addr, shorthandAddr)
}

accountNumber = account.GetAccountNumber()
}

if isSigner {
// if account number is not set, get next account number
if accountNumber == 0 {
accountNumber = k.accountKeeper.NextAccountNumber(ctx)
}

// create shorthand account
shorthandAccount, err := types.NewShorthandAccountWithAddress(k.ac, addr)
if err != nil {
return common.Address{}, err
}

// register shorthand account
shorthandAccount.AccountNumber = k.accountKeeper.NextAccountNumber(ctx)
shorthandAccount.AccountNumber = accountNumber
k.accountKeeper.SetAccount(ctx, shorthandAccount)
}

Expand Down
36 changes: 36 additions & 0 deletions x/evm/keeper/address_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,39 @@ func Test_AllowLongCosmosAddress(t *testing.T) {
))
require.ErrorContains(t, err, types.ErrAddressAlreadyExists.Error())
}

func Test_AllowLongCosmosAddress_ConvertEmptyAccount(t *testing.T) {
ctx, input := createDefaultTestInput(t)
_, _, addr := keyPubAddr()
_, _, addr2 := keyPubAddr()
evmAddr := common.BytesToAddress(addr.Bytes())
evmAddr2 := common.BytesToAddress(addr2.Bytes())

addr3 := append([]byte{0}, addr2.Bytes()...)

erc20Keeper, err := keeper.NewERC20Keeper(&input.EVMKeeper)
require.NoError(t, err)

// deploy erc20 contract
fooContractAddr := deployERC20(t, ctx, input, evmAddr, "foo")
fooDenom, err := types.ContractAddrToDenom(ctx, &input.EVMKeeper, fooContractAddr)
require.NoError(t, err)
require.Equal(t, "evm/"+fooContractAddr.Hex()[2:], fooDenom)

// mint erc20
mintERC20(t, ctx, input, evmAddr, evmAddr, sdk.NewCoin(fooDenom, math.NewInt(100)), false)

// create empty account
mintERC20(t, ctx, input, evmAddr, evmAddr2, sdk.NewCoin(fooDenom, math.NewInt(100)), false)
expectedAccNum := input.AccountKeeper.GetAccount(ctx, addr2).GetAccountNumber()

// take the address ownership
err = erc20Keeper.SendCoins(ctx, addr3, addr, sdk.NewCoins(
sdk.NewCoin(fooDenom, math.NewInt(50)),
))
require.NoError(t, err)

// account number should be the same
accNum := input.AccountKeeper.GetAccount(ctx, addr2).GetAccountNumber()
require.Equal(t, expectedAccNum, accNum)
}
130 changes: 60 additions & 70 deletions x/evm/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ import (
"context"
"errors"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/core/tracing"
coretypes "github.com/ethereum/go-ethereum/core/types"
"github.com/holiman/uint256"

"cosmossdk.io/collections"
"cosmossdk.io/math"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
govtypes "github.com/cosmos/cosmos-sdk/x/gov/types"
Expand Down Expand Up @@ -39,44 +42,17 @@ func (ms *msgServerImpl) Create(ctx context.Context, msg *types.MsgCreate) (*typ
}

// argument validation
caller, err := ms.convertToEVMAddress(ctx, sender, true)
caller, codeBz, value, accessList, err := ms.validateArguments(ctx, sender, msg.Code, msg.Value, msg.AccessList, true)
if err != nil {
return nil, err
}
if len(msg.Code) == 0 {
return nil, sdkerrors.ErrInvalidRequest.Wrap("empty code bytes")
}
codeBz, err := hexutil.Decode(msg.Code)
if err != nil {
return nil, types.ErrInvalidHexString.Wrap(err.Error())
}
value, overflow := uint256.FromBig(msg.Value.BigInt())
if overflow {
return nil, types.ErrInvalidValue.Wrap("value is out of range")
}
accessList := types.ConvertCosmosAccessListToEth(msg.AccessList)

// check the sender is allowed publisher
params, err := ms.Params.Get(ctx)
err = ms.assertAllowedPublishers(ctx, msg.Sender)
if err != nil {
return nil, err
}

// assert deploy authorization
if len(params.AllowedPublishers) != 0 {
allowed := false
for _, publisher := range params.AllowedPublishers {
if msg.Sender == publisher {
allowed = true

break
}
}

if !allowed {
return nil, sdkerrors.ErrUnauthorized.Wrapf("`%s` is not allowed to deploy a contract", msg.Sender)
}
}

// deploy a contract
retBz, contractAddr, logs, err := ms.EVMCreate(ctx, caller, codeBz, value, accessList)
if err != nil {
Expand Down Expand Up @@ -104,44 +80,17 @@ func (ms *msgServerImpl) Create2(ctx context.Context, msg *types.MsgCreate2) (*t
}

// argument validation
caller, err := ms.convertToEVMAddress(ctx, sender, true)
caller, codeBz, value, accessList, err := ms.validateArguments(ctx, sender, msg.Code, msg.Value, msg.AccessList, true)
if err != nil {
return nil, err
}
if len(msg.Code) == 0 {
return nil, sdkerrors.ErrInvalidRequest.Wrap("empty code bytes")
}
codeBz, err := hexutil.Decode(msg.Code)
if err != nil {
return nil, types.ErrInvalidHexString.Wrap(err.Error())
}
value, overflow := uint256.FromBig(msg.Value.BigInt())
if overflow {
return nil, types.ErrInvalidValue.Wrap("value is out of range")
}
accessList := types.ConvertCosmosAccessListToEth(msg.AccessList)

// check the sender is allowed publisher
params, err := ms.Params.Get(ctx)
err = ms.assertAllowedPublishers(ctx, msg.Sender)
if err != nil {
return nil, err
}

// assert deploy authorization
if len(params.AllowedPublishers) != 0 {
allowed := false
for _, publisher := range params.AllowedPublishers {
if msg.Sender == publisher {
allowed = true

break
}
}

if !allowed {
return nil, sdkerrors.ErrUnauthorized.Wrapf("`%s` is not allowed to deploy a contract", msg.Sender)
}
}

// deploy a contract
retBz, contractAddr, logs, err := ms.EVMCreate2(ctx, caller, codeBz, value, msg.Salt, accessList)
if err != nil {
Expand Down Expand Up @@ -174,19 +123,10 @@ func (ms *msgServerImpl) Call(ctx context.Context, msg *types.MsgCall) (*types.M
}

// argument validation
caller, err := ms.convertToEVMAddress(ctx, sender, true)
caller, inputBz, value, accessList, err := ms.validateArguments(ctx, sender, msg.Input, msg.Value, msg.AccessList, false)
if err != nil {
return nil, err
}
inputBz, err := hexutil.Decode(msg.Input)
if err != nil {
return nil, types.ErrInvalidHexString.Wrap(err.Error())
}
value, overflow := uint256.FromBig(msg.Value.BigInt())
if overflow {
return nil, types.ErrInvalidValue.Wrap("value is out of range")
}
accessList := types.ConvertCosmosAccessListToEth(msg.AccessList)

retBz, logs, err := ms.EVMCall(ctx, caller, contractAddr, inputBz, value, accessList)
if err != nil {
Expand Down Expand Up @@ -291,3 +231,53 @@ func (k *msgServerImpl) handleSequenceIncremented(ctx context.Context, sender sd

return nil
}

// validateArguments validates the arguments of create, create2, and call messages.
func (ms *msgServerImpl) validateArguments(
ctx context.Context, sender []byte, data string,
value math.Int, accessList []types.AccessTuple, isCreate bool,
) (common.Address, []byte, *uint256.Int, coretypes.AccessList, error) {
caller, err := ms.convertToEVMAddress(ctx, sender, true)
if err != nil {
return common.Address{}, nil, nil, nil, err
}
if isCreate && len(data) == 0 {
return common.Address{}, nil, nil, nil, sdkerrors.ErrInvalidRequest.Wrap("empty code bytes")
}
dataBz, err := hexutil.Decode(data)
if err != nil {
return common.Address{}, nil, nil, nil, types.ErrInvalidHexString.Wrap(err.Error())
}
val, overflow := uint256.FromBig(value.BigInt())
if overflow {
return common.Address{}, nil, nil, nil, types.ErrInvalidValue.Wrap("value is out of range")
}

return caller, dataBz, val, types.ConvertCosmosAccessListToEth(accessList), nil
}

// assertAllowedPublishers asserts the sender is allowed to deploy a contract.
func (ms *msgServerImpl) assertAllowedPublishers(ctx context.Context, sender string) error {
params, err := ms.Params.Get(ctx)
if err != nil {
return err
}

// assert deploy authorization
if len(params.AllowedPublishers) != 0 {
allowed := false
for _, publisher := range params.AllowedPublishers {
if sender == publisher {
allowed = true

break
}
}

if !allowed {
return sdkerrors.ErrUnauthorized.Wrapf("`%s` is not allowed to deploy a contract", sender)
}
}

return nil
}

0 comments on commit e4c92d4

Please sign in to comment.