Skip to content

Commit

Permalink
Update msg_server.go
Browse files Browse the repository at this point in the history
  • Loading branch information
crStiv authored Jan 11, 2025
1 parent 3cd1a78 commit 5e8b04e
Showing 1 changed file with 234 additions and 2 deletions.
236 changes: 234 additions & 2 deletions x/ccv/provider/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

errorsmod "cosmossdk.io/errors"
"cosmossdk.io/math"

cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec"
sdk "github.com/cosmos/cosmos-sdk/types"
Expand All @@ -19,6 +20,28 @@ import (
ccvtypes "github.com/cosmos/interchain-security/v6/x/ccv/types"
)

// validateDeprecatedChainId validates that the chain ID is not provided (deprecated field)
func validateDeprecatedChainId(chainId string) error {
if chainId != "" {
return fmt.Errorf("chain ID is deprecated, use consumer ID instead")
}
return nil
}

// validateProviderAddress validates that the provider address matches the signer
func validateProviderAddress(addr, signer string) error {
if addr == "" {
return fmt.Errorf("empty provider address")
}
if signer == "" {
return fmt.Errorf("empty signer address")
}
if addr != signer {
return fmt.Errorf("provider address %s does not match signer %s", addr, signer)
}
return nil
}

type msgServer struct {
*Keeper
}
Expand Down Expand Up @@ -52,6 +75,26 @@ func (k msgServer) UpdateParams(goCtx context.Context, msg *types.MsgUpdateParam
func (k msgServer) AssignConsumerKey(goCtx context.Context, msg *types.MsgAssignConsumerKey) (*types.MsgAssignConsumerKeyResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := validateDeprecatedChainId(msg.ChainId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgAssignConsumerKey, "ChainId: %s", err.Error())
}

if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgAssignConsumerKey, "ConsumerId: %s", err.Error())
}

if err := validateProviderAddress(msg.ProviderAddr, msg.Signer); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgAssignConsumerKey, "ProviderAddr: %s", err.Error())
}

if msg.ConsumerKey == "" {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgAssignConsumerKey, "ConsumerKey cannot be empty")
}
if _, _, err := types.ParseConsumerKeyFromJson(msg.ConsumerKey); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgAssignConsumerKey, "ConsumerKey: %s", err.Error())
}

providerValidatorAddr, err := sdk.ValAddressFromBech32(msg.ProviderAddr)
if err != nil {
return nil, err
Expand Down Expand Up @@ -109,6 +152,34 @@ func (k msgServer) ChangeRewardDenoms(goCtx context.Context, msg *types.MsgChang
return nil, errorsmod.Wrapf(types.ErrUnauthorized, "expected %s, got %s", k.GetAuthority(), msg.Authority)
}

// Validate basic message properties
emptyDenomsToAdd := len(msg.DenomsToAdd) == 0
emptyDenomsToRemove := len(msg.DenomsToRemove) == 0
// Return error if both sets are empty or nil
if emptyDenomsToAdd && emptyDenomsToRemove {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgChangeRewardDenoms, "both DenomsToAdd and DenomsToRemove are empty")
}

denomMap := map[string]struct{}{}
for _, denom := range msg.DenomsToAdd {
// validate the denom
if !sdk.NewCoin(denom, math.NewInt(1)).IsValid() {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgChangeRewardDenoms, "DenomsToAdd: invalid denom(%s)", denom)
}
denomMap[denom] = struct{}{}
}
for _, denom := range msg.DenomsToRemove {
// validate the denom
if !sdk.NewCoin(denom, math.NewInt(1)).IsValid() {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgChangeRewardDenoms, "DenomsToRemove: invalid denom(%s)", denom)
}
// denom cannot be in both sets
if _, found := denomMap[denom]; found {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgChangeRewardDenoms,
"denom(%s) cannot be both added and removed", denom)
}
}

eventAttributes := k.Keeper.ChangeRewardDenoms(ctx, msg.DenomsToAdd, msg.DenomsToRemove)

ctx.EventManager().EmitEvent(
Expand All @@ -123,6 +194,16 @@ func (k msgServer) ChangeRewardDenoms(goCtx context.Context, msg *types.MsgChang

func (k msgServer) SubmitConsumerMisbehaviour(goCtx context.Context, msg *types.MsgSubmitConsumerMisbehaviour) (*types.MsgSubmitConsumerMisbehaviourResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSubmitConsumerMisbehaviour, "ConsumerId: %s", err.Error())
}

if err := msg.Misbehaviour.ValidateBasic(); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSubmitConsumerMisbehaviour, "Misbehaviour: %s", err.Error())
}

if err := k.Keeper.HandleConsumerMisbehaviour(ctx, msg.ConsumerId, *msg.Misbehaviour); err != nil {
return nil, err
}
Expand All @@ -147,6 +228,23 @@ func (k msgServer) SubmitConsumerMisbehaviour(goCtx context.Context, msg *types.
func (k msgServer) SubmitConsumerDoubleVoting(goCtx context.Context, msg *types.MsgSubmitConsumerDoubleVoting) (*types.MsgSubmitConsumerDoubleVotingResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if dve, err := tmtypes.DuplicateVoteEvidenceFromProto(msg.DuplicateVoteEvidence); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSubmitConsumerDoubleVoting, "DuplicateVoteEvidence: %s", err.Error())
} else {
if err = dve.ValidateBasic(); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSubmitConsumerDoubleVoting, "DuplicateVoteEvidence: %s", err.Error())
}
}

if err := types.ValidateHeaderForConsumerDoubleVoting(msg.InfractionBlockHeader); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSubmitConsumerDoubleVoting, "ValidateTendermintHeader: %s", err.Error())
}

if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSubmitConsumerDoubleVoting, "ConsumerId: %s", err.Error())
}

evidence, err := tmtypes.DuplicateVoteEvidenceFromProto(msg.DuplicateVoteEvidence)
if err != nil {
return nil, err
Expand Down Expand Up @@ -198,6 +296,25 @@ func (k msgServer) SubmitConsumerDoubleVoting(goCtx context.Context, msg *types.
func (k msgServer) OptIn(goCtx context.Context, msg *types.MsgOptIn) (*types.MsgOptInResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := validateDeprecatedChainId(msg.ChainId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptIn, "ChainId: %s", err.Error())
}

if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptIn, "ConsumerId: %s", err.Error())
}

if err := validateProviderAddress(msg.ProviderAddr, msg.Signer); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptIn, "ProviderAddr: %s", err.Error())
}

if msg.ConsumerKey != "" {
if _, _, err := types.ParseConsumerKeyFromJson(msg.ConsumerKey); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptIn, "ConsumerKey: %s", err.Error())
}
}

valAddress, err := sdk.ValAddressFromBech32(msg.ProviderAddr)
if err != nil {
return nil, err
Expand Down Expand Up @@ -250,6 +367,19 @@ func (k msgServer) OptIn(goCtx context.Context, msg *types.MsgOptIn) (*types.Msg
func (k msgServer) OptOut(goCtx context.Context, msg *types.MsgOptOut) (*types.MsgOptOutResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := validateDeprecatedChainId(msg.ChainId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptOut, "ChainId: %s", err.Error())
}

if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptOut, "ConsumerId: %s", err.Error())
}

if err := validateProviderAddress(msg.ProviderAddr, msg.Signer); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptOut, "ProviderAddr: %s", err.Error())
}

valAddress, err := sdk.ValAddressFromBech32(msg.ProviderAddr)
if err != nil {
return nil, err
Expand Down Expand Up @@ -300,13 +430,30 @@ func (k msgServer) OptOut(goCtx context.Context, msg *types.MsgOptOut) (*types.M
func (k msgServer) SetConsumerCommissionRate(goCtx context.Context, msg *types.MsgSetConsumerCommissionRate) (*types.MsgSetConsumerCommissionRateResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

providerValidatorAddr, err := sdk.ValAddressFromBech32(msg.ProviderAddr)
// Validate basic message properties
if err := validateDeprecatedChainId(msg.ChainId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSetConsumerCommissionRate, "ChainId: %s", err.Error())
}

if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSetConsumerCommissionRate, "ConsumerId: %s", err.Error())
}

if err := validateProviderAddress(msg.ProviderAddr, msg.Signer); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSetConsumerCommissionRate, "ProviderAddr: %s", err.Error())
}

if !msg.Commission.IsPositive() || msg.Commission.GT(math.LegacyOneDec()) {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSetConsumerCommissionRate, "commission must be between 0 and 1")
}

valAddress, err := sdk.ValAddressFromBech32(msg.ProviderAddr)
if err != nil {
return nil, err
}

// validator must already be registered
validator, err := k.stakingKeeper.GetValidator(ctx, providerValidatorAddr)
validator, err := k.stakingKeeper.GetValidator(ctx, valAddress)
if err != nil {
return nil, stakingtypes.ErrNoValidatorFound
}
Expand Down Expand Up @@ -350,6 +497,36 @@ func (k msgServer) SetConsumerCommissionRate(goCtx context.Context, msg *types.M
// CreateConsumer creates a consumer chain
func (k msgServer) CreateConsumer(goCtx context.Context, msg *types.MsgCreateConsumer) (*types.MsgCreateConsumerResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := types.ValidateChainId("ChainId", msg.ChainId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgCreateConsumer, "ChainId: %s", err.Error())
}

if err := types.ValidateConsumerMetadata(msg.Metadata); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgCreateConsumer, "Metadata: %s", err.Error())
}

if err := types.ValidateInitializationParameters(*msg.InitializationParameters); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgCreateConsumer, "InitializationParameters: %s", err.Error())
}

if err := types.ValidatePowerShapingParameters(*msg.PowerShapingParameters); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgCreateConsumer, "PowerShapingParameters: %s", err.Error())
}

if err := types.ValidateAllowlistedRewardDenoms(*msg.AllowlistedRewardDenoms); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgCreateConsumer, "AllowlistedRewardDenoms: %s", err.Error())
}

if err := types.ValidateInfractionParameters(*msg.InfractionParameters); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgCreateConsumer, "InfractionParameters: %s", err.Error())
}

if k.GetAuthority() != msg.Authority {
return nil, errorsmod.Wrapf(types.ErrUnauthorized, "expected %s, got %s", k.GetAuthority(), msg.Authority)
}

resp := types.MsgCreateConsumerResponse{}

// initialize an empty slice to store event attributes
Expand Down Expand Up @@ -470,6 +647,52 @@ func (k msgServer) CreateConsumer(goCtx context.Context, msg *types.MsgCreateCon
// UpdateConsumer updates the metadata, power-shaping or initialization parameters of a consumer chain
func (k msgServer) UpdateConsumer(goCtx context.Context, msg *types.MsgUpdateConsumer) (*types.MsgUpdateConsumerResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "ConsumerId: %s", err.Error())
}

if msg.Metadata != nil {
if err := types.ValidateConsumerMetadata(*msg.Metadata); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "Metadata: %s", err.Error())
}
}

if msg.InitializationParameters != nil {
if err := types.ValidateInitializationParameters(*msg.InitializationParameters); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "InitializationParameters: %s", err.Error())
}
}

if msg.PowerShapingParameters != nil {
if err := types.ValidatePowerShapingParameters(*msg.PowerShapingParameters); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "PowerShapingParameters: %s", err.Error())
}
}

if msg.AllowlistedRewardDenoms != nil {
if err := types.ValidateAllowlistedRewardDenoms(*msg.AllowlistedRewardDenoms); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "AllowlistedRewardDenoms: %s", err.Error())
}
}

if msg.InfractionParameters != nil {
if err := types.ValidateInfractionParameters(*msg.InfractionParameters); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "InfractionParameters: %s", err.Error())
}
}

if msg.NewChainId != "" {
if err := types.ValidateChainId("NewChainId", msg.NewChainId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "NewChainId: %s", err.Error())
}
}

if k.GetAuthority() != msg.Authority {
return nil, errorsmod.Wrapf(types.ErrUnauthorized, "expected %s, got %s", k.GetAuthority(), msg.Authority)
}

resp := types.MsgUpdateConsumerResponse{}

// initialize an empty slice to store event attributes
Expand Down Expand Up @@ -704,6 +927,15 @@ func (k msgServer) UpdateConsumer(goCtx context.Context, msg *types.MsgUpdateCon
func (k msgServer) RemoveConsumer(goCtx context.Context, msg *types.MsgRemoveConsumer) (*types.MsgRemoveConsumerResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgRemoveConsumer, "ConsumerId: %s", err.Error())
}

if k.GetAuthority() != msg.Authority {
return nil, errorsmod.Wrapf(types.ErrUnauthorized, "expected %s, got %s", k.GetAuthority(), msg.Authority)
}

resp := types.MsgRemoveConsumerResponse{}

consumerId := msg.ConsumerId
Expand Down

0 comments on commit 5e8b04e

Please sign in to comment.