Skip to content

Commit

Permalink
recursive derives
Browse files Browse the repository at this point in the history
  • Loading branch information
tadeohepperle committed Nov 22, 2023
1 parent 874e795 commit ca78ad5
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 26 deletions.
2 changes: 1 addition & 1 deletion description/src/type_example/rust_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl<'a> CodeTransformer<'a> {
let has_unused_type_params = self
.state
.type_generator
.create_type_ir(ty)
.create_type_ir(ty, &Default::default()) // Note: derives not important here.
.map_err(|e| anyhow!("{e}"))?
.map(|e| e.type_params.has_unused_type_params())
.unwrap_or(false);
Expand Down
3 changes: 3 additions & 0 deletions typegen/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,7 @@ fn apply_user_defined_derives_for_specific_types() {
parse_quote!(scale_typegen::tests::B),
[parse_quote!(Hash)],
[parse_quote!(#[some_attribute])],
false,
);
settings.derives.extend_for_type(
parse_quote!(scale_typegen::tests::C),
Expand All @@ -978,6 +979,7 @@ fn apply_user_defined_derives_for_specific_types() {
parse_quote!(PartialOrd),
],
[],
false,
);
let code = Testgen::new().with::<A>().gen_tests_mod(settings);

Expand Down Expand Up @@ -1028,6 +1030,7 @@ fn opt_out_from_default_derives() {
parse_quote!(scale_typegen::tests::B),
vec![parse_quote!(Hash)],
vec![parse_quote!(#[some_other_attribute])],
false,
);

let settings = TypeGeneratorSettings {
Expand Down
25 changes: 14 additions & 11 deletions typegen/src/typegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{Derives, TypegenError};
use self::{
ir::module_ir::ModuleIR,
ir::type_ir::{CompositeFieldIR, CompositeIR, CompositeIRKind, EnumIR, TypeIR, TypeIRKind},
settings::TypeGeneratorSettings,
settings::{derives::FlatDerivesRegistry, TypeGeneratorSettings},
type_params::TypeParameters,
type_path::TypeParameter,
};
Expand Down Expand Up @@ -45,6 +45,12 @@ impl<'a> TypeGenerator<'a> {

/// Generate a module containing all types defined in the supplied type registry.
pub fn generate_types_mod(&self) -> Result<ModuleIR, TypegenError> {
let flat_derives_registry = self
.settings
.derives
.clone()
.flatten_recursive_derives(self.type_registry)?;

let mut root_mod = ModuleIR::new(
self.settings.types_mod_ident.clone(),
self.settings.types_mod_ident.clone(),
Expand All @@ -65,7 +71,7 @@ impl<'a> TypeGenerator<'a> {
}

// if the type is not a builtin type, insert it into the respective module
if let Some(type_ir) = self.create_type_ir(&ty.ty)? {
if let Some(type_ir) = self.create_type_ir(&ty.ty, &flat_derives_registry)? {
// Create the module this type should go into
let innermost_module = root_mod.get_or_insert_submodule(namespace);
innermost_module.types.insert(path.clone(), type_ir);
Expand All @@ -75,7 +81,11 @@ impl<'a> TypeGenerator<'a> {
Ok(root_mod)
}

pub fn create_type_ir(&self, ty: &Type<PortableForm>) -> Result<Option<TypeIR>, TypegenError> {
pub fn create_type_ir(
&self,
ty: &Type<PortableForm>,
flat_derives_registry: &FlatDerivesRegistry,
) -> Result<Option<TypeIR>, TypegenError> {
// if the type is some builtin, early return, we are only interested in generating structs and enums.
if !matches!(ty.type_def, TypeDef::Composite(_) | TypeDef::Variant(_)) {
return Ok(None);
Expand Down Expand Up @@ -124,7 +134,7 @@ impl<'a> TypeGenerator<'a> {
_ => unreachable!("Other variants early return before. qed."),
};

let mut derives = self.type_derives(ty)?;
let mut derives = flat_derives_registry.resolve_derives_for_type(ty)?;
if could_derive_as_compact {
self.add_as_compact_derive(&mut derives);
}
Expand Down Expand Up @@ -237,13 +247,6 @@ impl<'a> TypeGenerator<'a> {
self.settings.derives.default_derives()
}

pub fn type_derives(&self, ty: &Type<PortableForm>) -> Result<Derives, TypegenError> {
let joined_path = ty.path.segments.join("::");
let ty_path: syn::TypePath = syn::parse_str(&joined_path)?;
let derives = self.settings.derives.resolve(&ty_path);
Ok(derives)
}

/// Adds a AsCompact derive, if a path to AsCompact trait/derive macro set in settings.
fn add_as_compact_derive(&self, derives: &mut Derives) {
if let Some(compact_as_type_path) = &self.settings.compact_as_type_path {
Expand Down
168 changes: 155 additions & 13 deletions typegen/src/typegen/settings/derives.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::collections::{HashMap, HashSet};

use quote::ToTokens;
use scale_info::{form::PortableForm, PortableRegistry, Type};

use crate::{utils::syn_type_path, TypegenError};

/// A struct containing the derives that we'll be applying to types;
/// a combination of some common derives for all types, plus type
Expand All @@ -9,6 +12,7 @@ use quote::ToTokens;
pub struct DerivesRegistry {
default_derives: Derives,
specific_type_derives: HashMap<syn::TypePath, Derives>,
recursive_type_derives: HashMap<syn::TypePath, Derives>,
}

impl DerivesRegistry {
Expand All @@ -33,8 +37,13 @@ impl DerivesRegistry {
ty: syn::TypePath,
derives: impl IntoIterator<Item = syn::Path>,
attributes: impl IntoIterator<Item = syn::Attribute>,
recursive: bool,
) {
let type_derives = self.specific_type_derives.entry(ty).or_default();
let type_derives = if recursive {
self.recursive_type_derives.entry(ty).or_default()
} else {
self.specific_type_derives.entry(ty).or_default()
};
type_derives.derives.extend(derives);
type_derives.attributes.extend(attributes);
}
Expand All @@ -43,18 +52,6 @@ impl DerivesRegistry {
pub fn default_derives(&self) -> &Derives {
&self.default_derives
}

/// Resolve the derives for a generated type. Includes:
/// - The default derives for all types e.g. `scale::Encode, scale::Decode`
/// - Any user-defined derives for all types via `generated_type_derives`
/// - Any user-defined derives for this specific type
pub fn resolve(&self, ty: &syn::TypePath) -> Derives {
let mut resolved_derives = self.default_derives.clone();
if let Some(specific) = self.specific_type_derives.get(ty) {
resolved_derives.extend_from(specific.clone());
}
resolved_derives
}
}

/// A struct storing the set of derives and derive attributes that we'll apply
Expand Down Expand Up @@ -136,3 +133,148 @@ impl ToTokens for Derives {
}
}
}

/// This is like a DerivesRegistry, but the recursive type derives have been flattened out into the specific_type_derives.
#[derive(Debug, Clone, Default)]
pub struct FlatDerivesRegistry {
default_derives: Derives,
specific_type_derives: HashMap<syn::TypePath, Derives>,
}

impl FlatDerivesRegistry {
/// Resolve the derives for a specific type.
pub fn resolve(&self, ty: &syn::TypePath) -> Derives {
let mut resolved_derives = self.default_derives.clone();
if let Some(specific) = self.specific_type_derives.get(ty) {
resolved_derives.extend_from(specific.clone());
}
resolved_derives
}

pub fn resolve_derives_for_type(
&self,
ty: &Type<PortableForm>,
) -> Result<Derives, TypegenError> {
Ok(self.resolve(&syn_type_path(ty)?))
}
}

impl DerivesRegistry {
pub fn flatten_recursive_derives(
self,
types: &PortableRegistry,
) -> Result<FlatDerivesRegistry, TypegenError> {
let DerivesRegistry {
default_derives,
mut specific_type_derives,
mut recursive_type_derives,
} = self;

if recursive_type_derives.is_empty() {
return Ok(FlatDerivesRegistry {
default_derives,
specific_type_derives,
});
}

// Build a mapping of type ids to syn paths for all types in the registry:
let mut syn_path_for_id: HashMap<u32, syn::TypePath> = types
.types
.iter()
.map(|t| {
let path = syn_type_path(&t.ty)?;
Ok((t.id, path))
})
.collect::<Result<_, TypegenError>>()?;

// Create an empty map of derives that we are about to fill:
let mut add_derives_for_id: HashMap<u32, Derives> = HashMap::new();

// Check for each type in the registry if it is the top level of
for ty in types.types.iter() {
let path = syn_path_for_id.get(&ty.id).expect("inserted above; qed;");
let Some(recursive_derives) = recursive_type_derives.remove(&path) else {
continue;
};
// The collected_type_ids contain the id of the type itself and all ids of its fields:
let mut collected_type_ids: HashSet<u32> = HashSet::new();
collect_type_ids(ty.id, types, &mut collected_type_ids);

// We collect the derives for each type id in the add_derives_for_id HashMap.
for id in collected_type_ids {
add_derives_for_id
.entry(id)
.or_default()
.extend_from(recursive_derives.clone());
}
}

// Merge all the recursively obtained derives with the existing derives for the types.
for (id, derived_to_add) in add_derives_for_id {
let path = syn_path_for_id
.remove(&id)
.expect("syn_path_for_id contains all type ids; qed;");
specific_type_derives
.entry(path)
.or_default()
.extend_from(derived_to_add);
}

Ok(FlatDerivesRegistry {
default_derives,
specific_type_derives,
})
}
}

fn collect_type_ids(id: u32, types: &PortableRegistry, collected_types: &mut HashSet<u32>) {
// Recursion protection:
if collected_types.contains(&id) {
return;
}

// Add the type id itself as well:
collected_types.insert(id);
let ty = types
.resolve(id)
.expect("Should contain this id, if Registry not corrupted");

// Collect the types that are passed as type params (Question/Note: Is this necessary? Maybe not...)
for param in ty.type_params.iter() {
if let Some(id) = param.ty.map(|e| e.id) {
collect_type_ids(id, types, collected_types);
}
}

// Collect ids depending on the types structure:
match &ty.type_def {
scale_info::TypeDef::Composite(def) => {
for f in def.fields.iter() {
collect_type_ids(f.ty.id, types, collected_types);
}
}
scale_info::TypeDef::Variant(def) => {
for v in def.variants.iter() {
for f in v.fields.iter() {
collect_type_ids(f.ty.id, types, collected_types);
}
}
}
scale_info::TypeDef::Sequence(def) => {
collect_type_ids(def.type_param.id, types, collected_types);
}
scale_info::TypeDef::Array(def) => {
collect_type_ids(def.type_param.id, types, collected_types);
}
scale_info::TypeDef::Tuple(def) => {
for f in def.fields.iter() {
collect_type_ids(f.id, types, collected_types);
}
}
scale_info::TypeDef::Primitive(_) => {}
scale_info::TypeDef::Compact(def) => {
collect_type_ids(def.type_param.id, types, collected_types);
}
scale_info::TypeDef::BitSequence(_) => {}
}
}
1 change: 0 additions & 1 deletion typegen/src/typegen/settings/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use derives::DerivesRegistry;

use proc_macro2::Ident;
use substitutes::TypeSubstitutes;
use syn::parse_quote;
Expand Down
8 changes: 8 additions & 0 deletions typegen/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@ use scale_info::{form::PortableForm, Field, PortableRegistry, Type, TypeDef, Typ
use smallvec::{smallvec, SmallVec};
use std::collections::HashMap;

use crate::TypegenError;

pub fn syn_type_path(ty: &Type<PortableForm>) -> Result<syn::TypePath, TypegenError> {
let joined_path = ty.path.segments.join("::");
let ty_path: syn::TypePath = syn::parse_str(&joined_path)?;
Ok(ty_path)
}

pub fn ensure_unique_type_paths(types: &mut PortableRegistry) {
let mut types_with_same_type_path = HashMap::<&[String], SmallVec<[u32; 2]>>::new();

Expand Down

0 comments on commit ca78ad5

Please sign in to comment.