diff --git a/description/src/description.rs b/description/src/description.rs index b477f7c..02e8d4d 100644 --- a/description/src/description.rs +++ b/description/src/description.rs @@ -4,7 +4,7 @@ use scale_info::{ TypeDefCompact, TypeDefPrimitive, TypeDefSequence, TypeDefTuple, TypeDefVariant, Variant, }; -use crate::{description, transformer::Transformer}; +use crate::transformer::Transformer; use super::formatting::format_type_description; diff --git a/typegen/src/tests/mod.rs b/typegen/src/tests/mod.rs index 4b687e6..dcb0427 100644 --- a/typegen/src/tests/mod.rs +++ b/typegen/src/tests/mod.rs @@ -1012,7 +1012,103 @@ fn apply_user_defined_derives_for_specific_types() { } #[test] -fn opt_out_from_default_derives() { +fn apply_recursive_derives() { + use std::collections::BTreeMap; + + #[allow(unused)] + #[derive(TypeInfo)] + struct Human { + organ_status: BTreeMap, + profession: Profession, + organs: Vec, + } + + #[allow(unused)] + #[derive(TypeInfo)] + enum Organ { + Heart, + Stomach, + } + + #[allow(unused)] + #[derive(TypeInfo)] + struct Status { + damage: Compact, + } + + #[allow(unused)] + #[derive(TypeInfo)] + enum Profession { + Student { college: String }, + Programmer, + } + + let mut derives = DerivesRegistry::new(); + + derives.extend_for_type( + parse_quote!(scale_typegen::tests::Human), + vec![parse_quote!(Reflect)], + vec![parse_quote!(#[is_human])], + false, + ); + + derives.extend_for_type( + parse_quote!(scale_typegen::tests::Human), + vec![parse_quote!(Clone)], + vec![parse_quote!(#[is_nice])], + true, + ); + + let settings = TypeGeneratorSettings { + derives, + ..subxt_settings() + }; + let code = Testgen::new().with::().gen_tests_mod(settings); + + let expected_code = quote! { + pub mod tests { + use super::root; + #[derive(Clone, Reflect)] + #[is_human] + #[is_nice] + pub struct Human { + pub organ_status: ::subxt_path::utils::KeyedVec< + root::scale_typegen::tests::Organ, + root::scale_typegen::tests::Status + >, + pub profession: root::scale_typegen::tests::Profession, + pub organs: ::std::vec::Vec, + } + #[derive(Clone)] + #[is_nice] + pub enum Organ { + #[codec(index = 0)] + Heart, + #[codec(index = 1)] + Stomach, + } + #[derive(Clone)] + #[is_nice] + pub enum Profession { + #[codec(index = 0)] + Student { college: ::std::string::String , }, + #[codec(index = 1)] + Programmer, + } + #[derive(Clone)] + #[is_nice] + pub struct Status { + #[codec(compact)] + pub damage: ::core::primitive::u32, + } + } + }; + + assert_eq!(code.to_string(), expected_code.to_string()); +} + +#[test] +fn apply_derives() { #[allow(unused)] #[derive(TypeInfo)] struct A(B); @@ -1060,7 +1156,7 @@ fn opt_out_from_default_derives() { /// By default a BTreeMap would be replaced by a KeyedVec. /// This test demonstrates that it does not happen if we opt out of default type substitutes. #[test] -fn opt_out_from_default_substitutes() { +fn apply_custom_substitutes() { use std::collections::BTreeMap; #[allow(unused)] diff --git a/typegen/src/typegen/settings/derives.rs b/typegen/src/typegen/settings/derives.rs index adda25c..1f60b13 100644 --- a/typegen/src/typegen/settings/derives.rs +++ b/typegen/src/typegen/settings/derives.rs @@ -32,6 +32,9 @@ impl DerivesRegistry { } /// Insert derives to be applied to a specific generated type. + /// + /// The `recursive` flag can be set if child types should also receive the given derives/attributes. + /// Child types are all types that are mentioned as fields or type parameters of the type. pub fn extend_for_type( &mut self, ty: syn::TypePath, @@ -134,7 +137,9 @@ impl ToTokens for Derives { } } -/// This is like a DerivesRegistry, but the recursive type derives have been flattened out into the specific_type_derives. +/// This is like a DerivesRegistry, but the recursive type derives have been flattened out into specific_type_derives. +/// +/// Can be constructed properly using a DerivesRegistry and a PortableRegistry with `DerivesRegistry::flatten_recursive_derives()`. #[derive(Debug, Clone, Default)] pub struct FlatDerivesRegistry { default_derives: Derives, @@ -142,7 +147,7 @@ pub struct FlatDerivesRegistry { } impl FlatDerivesRegistry { - /// Resolve the derives for a specific type. + /// Resolve the derives for a specific type path. pub fn resolve(&self, ty: &syn::TypePath) -> Derives { let mut resolved_derives = self.default_derives.clone(); if let Some(specific) = self.specific_type_derives.get(ty) { @@ -151,6 +156,7 @@ impl FlatDerivesRegistry { resolved_derives } + /// Resolve the derives for a specific type. pub fn resolve_derives_for_type( &self, ty: &Type, @@ -160,6 +166,8 @@ impl FlatDerivesRegistry { } impl DerivesRegistry { + /// Flattens out the recursive derives into specific derives. + /// For this it needs to have a PortableRegistry that it can traverse recursively. pub fn flatten_recursive_derives( self, types: &PortableRegistry, @@ -181,9 +189,15 @@ impl DerivesRegistry { let mut syn_path_for_id: HashMap = types .types .iter() - .map(|t| { - let path = syn_type_path(&t.ty)?; - Ok((t.id, path)) + .filter_map(|t| { + if t.ty.path.is_empty() { + None + } else { + match syn_type_path(&t.ty) { + Ok(path) => Some(Ok((t.id, path))), + Err(err) => Some(Err(err)), + } + } }) .collect::>()?; @@ -192,8 +206,11 @@ impl DerivesRegistry { // Check for each type in the registry if it is the top level of for ty in types.types.iter() { - let path = syn_path_for_id.get(&ty.id).expect("inserted above; qed;"); - let Some(recursive_derives) = recursive_type_derives.remove(&path) else { + let Some(path) = syn_path_for_id.get(&ty.id) else { + // this is only the case for types with empty path (i.e. builtin types). + continue; + }; + let Some(recursive_derives) = recursive_type_derives.remove(path) else { continue; }; // The collected_type_ids contain the id of the type itself and all ids of its fields: @@ -211,13 +228,12 @@ impl DerivesRegistry { // Merge all the recursively obtained derives with the existing derives for the types. for (id, derived_to_add) in add_derives_for_id { - let path = syn_path_for_id - .remove(&id) - .expect("syn_path_for_id contains all type ids; qed;"); - specific_type_derives - .entry(path) - .or_default() - .extend_from(derived_to_add); + if let Some(path) = syn_path_for_id.remove(&id) { + specific_type_derives + .entry(path) + .or_default() + .extend_from(derived_to_add); + } } Ok(FlatDerivesRegistry {