diff --git a/description/src/type_example/rust_value.rs b/description/src/type_example/rust_value.rs index feaafd0..018d264 100644 --- a/description/src/type_example/rust_value.rs +++ b/description/src/type_example/rust_value.rs @@ -58,7 +58,7 @@ impl<'a> CodeTransformer<'a> { let has_unused_type_params = self .state .type_generator - .create_type_ir(ty) + .create_type_ir(ty, &Default::default()) // Note: derives not important here. .map_err(|e| anyhow!("{e}"))? .map(|e| e.type_params.has_unused_type_params()) .unwrap_or(false); diff --git a/typegen/src/tests/mod.rs b/typegen/src/tests/mod.rs index 61f02c6..4b687e6 100644 --- a/typegen/src/tests/mod.rs +++ b/typegen/src/tests/mod.rs @@ -969,6 +969,7 @@ fn apply_user_defined_derives_for_specific_types() { parse_quote!(scale_typegen::tests::B), [parse_quote!(Hash)], [parse_quote!(#[some_attribute])], + false, ); settings.derives.extend_for_type( parse_quote!(scale_typegen::tests::C), @@ -978,6 +979,7 @@ fn apply_user_defined_derives_for_specific_types() { parse_quote!(PartialOrd), ], [], + false, ); let code = Testgen::new().with::().gen_tests_mod(settings); @@ -1028,6 +1030,7 @@ fn opt_out_from_default_derives() { parse_quote!(scale_typegen::tests::B), vec![parse_quote!(Hash)], vec![parse_quote!(#[some_other_attribute])], + false, ); let settings = TypeGeneratorSettings { diff --git a/typegen/src/typegen/mod.rs b/typegen/src/typegen/mod.rs index 04f7bba..a3c90e6 100644 --- a/typegen/src/typegen/mod.rs +++ b/typegen/src/typegen/mod.rs @@ -3,7 +3,7 @@ use crate::{Derives, TypegenError}; use self::{ ir::module_ir::ModuleIR, ir::type_ir::{CompositeFieldIR, CompositeIR, CompositeIRKind, EnumIR, TypeIR, TypeIRKind}, - settings::TypeGeneratorSettings, + settings::{derives::FlatDerivesRegistry, TypeGeneratorSettings}, type_params::TypeParameters, type_path::TypeParameter, }; @@ -45,6 +45,12 @@ impl<'a> TypeGenerator<'a> { /// Generate a module containing all types defined in the supplied type registry. pub fn generate_types_mod(&self) -> Result { + let flat_derives_registry = self + .settings + .derives + .clone() + .flatten_recursive_derives(self.type_registry)?; + let mut root_mod = ModuleIR::new( self.settings.types_mod_ident.clone(), self.settings.types_mod_ident.clone(), @@ -65,7 +71,7 @@ impl<'a> TypeGenerator<'a> { } // if the type is not a builtin type, insert it into the respective module - if let Some(type_ir) = self.create_type_ir(&ty.ty)? { + if let Some(type_ir) = self.create_type_ir(&ty.ty, &flat_derives_registry)? { // Create the module this type should go into let innermost_module = root_mod.get_or_insert_submodule(namespace); innermost_module.types.insert(path.clone(), type_ir); @@ -75,7 +81,11 @@ impl<'a> TypeGenerator<'a> { Ok(root_mod) } - pub fn create_type_ir(&self, ty: &Type) -> Result, TypegenError> { + pub fn create_type_ir( + &self, + ty: &Type, + flat_derives_registry: &FlatDerivesRegistry, + ) -> Result, TypegenError> { // if the type is some builtin, early return, we are only interested in generating structs and enums. if !matches!(ty.type_def, TypeDef::Composite(_) | TypeDef::Variant(_)) { return Ok(None); @@ -124,7 +134,7 @@ impl<'a> TypeGenerator<'a> { _ => unreachable!("Other variants early return before. qed."), }; - let mut derives = self.type_derives(ty)?; + let mut derives = flat_derives_registry.resolve_derives_for_type(ty)?; if could_derive_as_compact { self.add_as_compact_derive(&mut derives); } @@ -237,13 +247,6 @@ impl<'a> TypeGenerator<'a> { self.settings.derives.default_derives() } - pub fn type_derives(&self, ty: &Type) -> Result { - let joined_path = ty.path.segments.join("::"); - let ty_path: syn::TypePath = syn::parse_str(&joined_path)?; - let derives = self.settings.derives.resolve(&ty_path); - Ok(derives) - } - /// Adds a AsCompact derive, if a path to AsCompact trait/derive macro set in settings. fn add_as_compact_derive(&self, derives: &mut Derives) { if let Some(compact_as_type_path) = &self.settings.compact_as_type_path { diff --git a/typegen/src/typegen/settings/derives.rs b/typegen/src/typegen/settings/derives.rs index 81337d7..adda25c 100644 --- a/typegen/src/typegen/settings/derives.rs +++ b/typegen/src/typegen/settings/derives.rs @@ -1,6 +1,9 @@ use std::collections::{HashMap, HashSet}; use quote::ToTokens; +use scale_info::{form::PortableForm, PortableRegistry, Type}; + +use crate::{utils::syn_type_path, TypegenError}; /// A struct containing the derives that we'll be applying to types; /// a combination of some common derives for all types, plus type @@ -9,6 +12,7 @@ use quote::ToTokens; pub struct DerivesRegistry { default_derives: Derives, specific_type_derives: HashMap, + recursive_type_derives: HashMap, } impl DerivesRegistry { @@ -33,8 +37,13 @@ impl DerivesRegistry { ty: syn::TypePath, derives: impl IntoIterator, attributes: impl IntoIterator, + recursive: bool, ) { - let type_derives = self.specific_type_derives.entry(ty).or_default(); + let type_derives = if recursive { + self.recursive_type_derives.entry(ty).or_default() + } else { + self.specific_type_derives.entry(ty).or_default() + }; type_derives.derives.extend(derives); type_derives.attributes.extend(attributes); } @@ -43,18 +52,6 @@ impl DerivesRegistry { pub fn default_derives(&self) -> &Derives { &self.default_derives } - - /// Resolve the derives for a generated type. Includes: - /// - The default derives for all types e.g. `scale::Encode, scale::Decode` - /// - Any user-defined derives for all types via `generated_type_derives` - /// - Any user-defined derives for this specific type - pub fn resolve(&self, ty: &syn::TypePath) -> Derives { - let mut resolved_derives = self.default_derives.clone(); - if let Some(specific) = self.specific_type_derives.get(ty) { - resolved_derives.extend_from(specific.clone()); - } - resolved_derives - } } /// A struct storing the set of derives and derive attributes that we'll apply @@ -136,3 +133,148 @@ impl ToTokens for Derives { } } } + +/// This is like a DerivesRegistry, but the recursive type derives have been flattened out into the specific_type_derives. +#[derive(Debug, Clone, Default)] +pub struct FlatDerivesRegistry { + default_derives: Derives, + specific_type_derives: HashMap, +} + +impl FlatDerivesRegistry { + /// Resolve the derives for a specific type. + pub fn resolve(&self, ty: &syn::TypePath) -> Derives { + let mut resolved_derives = self.default_derives.clone(); + if let Some(specific) = self.specific_type_derives.get(ty) { + resolved_derives.extend_from(specific.clone()); + } + resolved_derives + } + + pub fn resolve_derives_for_type( + &self, + ty: &Type, + ) -> Result { + Ok(self.resolve(&syn_type_path(ty)?)) + } +} + +impl DerivesRegistry { + pub fn flatten_recursive_derives( + self, + types: &PortableRegistry, + ) -> Result { + let DerivesRegistry { + default_derives, + mut specific_type_derives, + mut recursive_type_derives, + } = self; + + if recursive_type_derives.is_empty() { + return Ok(FlatDerivesRegistry { + default_derives, + specific_type_derives, + }); + } + + // Build a mapping of type ids to syn paths for all types in the registry: + let mut syn_path_for_id: HashMap = types + .types + .iter() + .map(|t| { + let path = syn_type_path(&t.ty)?; + Ok((t.id, path)) + }) + .collect::>()?; + + // Create an empty map of derives that we are about to fill: + let mut add_derives_for_id: HashMap = HashMap::new(); + + // Check for each type in the registry if it is the top level of + for ty in types.types.iter() { + let path = syn_path_for_id.get(&ty.id).expect("inserted above; qed;"); + let Some(recursive_derives) = recursive_type_derives.remove(&path) else { + continue; + }; + // The collected_type_ids contain the id of the type itself and all ids of its fields: + let mut collected_type_ids: HashSet = HashSet::new(); + collect_type_ids(ty.id, types, &mut collected_type_ids); + + // We collect the derives for each type id in the add_derives_for_id HashMap. + for id in collected_type_ids { + add_derives_for_id + .entry(id) + .or_default() + .extend_from(recursive_derives.clone()); + } + } + + // Merge all the recursively obtained derives with the existing derives for the types. + for (id, derived_to_add) in add_derives_for_id { + let path = syn_path_for_id + .remove(&id) + .expect("syn_path_for_id contains all type ids; qed;"); + specific_type_derives + .entry(path) + .or_default() + .extend_from(derived_to_add); + } + + Ok(FlatDerivesRegistry { + default_derives, + specific_type_derives, + }) + } +} + +fn collect_type_ids(id: u32, types: &PortableRegistry, collected_types: &mut HashSet) { + // Recursion protection: + if collected_types.contains(&id) { + return; + } + + // Add the type id itself as well: + collected_types.insert(id); + let ty = types + .resolve(id) + .expect("Should contain this id, if Registry not corrupted"); + + // Collect the types that are passed as type params (Question/Note: Is this necessary? Maybe not...) + for param in ty.type_params.iter() { + if let Some(id) = param.ty.map(|e| e.id) { + collect_type_ids(id, types, collected_types); + } + } + + // Collect ids depending on the types structure: + match &ty.type_def { + scale_info::TypeDef::Composite(def) => { + for f in def.fields.iter() { + collect_type_ids(f.ty.id, types, collected_types); + } + } + scale_info::TypeDef::Variant(def) => { + for v in def.variants.iter() { + for f in v.fields.iter() { + collect_type_ids(f.ty.id, types, collected_types); + } + } + } + scale_info::TypeDef::Sequence(def) => { + collect_type_ids(def.type_param.id, types, collected_types); + } + scale_info::TypeDef::Array(def) => { + collect_type_ids(def.type_param.id, types, collected_types); + } + scale_info::TypeDef::Tuple(def) => { + for f in def.fields.iter() { + collect_type_ids(f.id, types, collected_types); + } + } + scale_info::TypeDef::Primitive(_) => {} + scale_info::TypeDef::Compact(def) => { + collect_type_ids(def.type_param.id, types, collected_types); + } + scale_info::TypeDef::BitSequence(_) => {} + } +} diff --git a/typegen/src/typegen/settings/mod.rs b/typegen/src/typegen/settings/mod.rs index 22e5e09..df6c45a 100644 --- a/typegen/src/typegen/settings/mod.rs +++ b/typegen/src/typegen/settings/mod.rs @@ -1,5 +1,4 @@ use derives::DerivesRegistry; - use proc_macro2::Ident; use substitutes::TypeSubstitutes; use syn::parse_quote; diff --git a/typegen/src/utils.rs b/typegen/src/utils.rs index d724e4c..185ecc5 100644 --- a/typegen/src/utils.rs +++ b/typegen/src/utils.rs @@ -2,6 +2,14 @@ use scale_info::{form::PortableForm, Field, PortableRegistry, Type, TypeDef, Typ use smallvec::{smallvec, SmallVec}; use std::collections::HashMap; +use crate::TypegenError; + +pub fn syn_type_path(ty: &Type) -> Result { + let joined_path = ty.path.segments.join("::"); + let ty_path: syn::TypePath = syn::parse_str(&joined_path)?; + Ok(ty_path) +} + pub fn ensure_unique_type_paths(types: &mut PortableRegistry) { let mut types_with_same_type_path = HashMap::<&[String], SmallVec<[u32; 2]>>::new();