From bfdbc9724f9dcbe8cca6d4b29260705865530ea3 Mon Sep 17 00:00:00 2001 From: Matthias Bussonnier Date: Tue, 12 Mar 2024 15:33:12 +0100 Subject: [PATCH] Custom implementation of import sorting. This tries to implement some of the rules of import sorting as requested in #13 using custom logic. This was originally fixed in #263 using isort, but reverted and re-requested as of #287 --- bin/tidy-imports | 51 +++++--- lib/python/pyflyby/_import_sorting.py | 165 ++++++++++++++++++++++++++ lib/python/pyflyby/_imports2s.py | 96 ++++----------- lib/python/pyflyby/_importstmt.py | 33 ++++++ tests/tests_sorts.py | 77 ++++++++++-- 5 files changed, 320 insertions(+), 102 deletions(-) create mode 100644 lib/python/pyflyby/_import_sorting.py diff --git a/bin/tidy-imports b/bin/tidy-imports index c07106e5..da2671f1 100755 --- a/bin/tidy-imports +++ b/bin/tidy-imports @@ -21,13 +21,17 @@ Only top-level import statements are touched. # License: MIT http://opensource.org/licenses/MIT +from __future__ import print_function + import os from pyflyby._cmdline import hfmt, parse_args, process_actions -from pyflyby._imports2s import (canonicalize_imports, - fix_unused_and_missing_imports, - replace_star_imports, - transform_imports) +from pyflyby._import_sorting import sort_imports +from pyflyby._imports2s import (canonicalize_imports, + fix_unused_and_missing_imports, + replace_star_imports, + transform_imports) +from pyflyby._parse import PythonBlock import toml TOML_AVAIL = True @@ -104,6 +108,10 @@ def _addopts(parser): help=hfmt(''' (Default) Replace imports with canonical equivalent imports, according to database.''')) + parser.add_option('--experimental-sort-imports', + default=False, action='store_true', + help=hfmt(''' + experimental import sorting''')) parser.add_option('--no-canonicalize', dest='canonicalize', default=True, action='store_false', help=hfmt(''' @@ -128,7 +136,7 @@ def _addopts(parser): help=hfmt(''' Equivalent to --no-add-missing --no-add-mandatory.''')) -def 
main(): +def main() -> None: config_text = _get_pyproj_toml_config() if config_text: @@ -139,28 +147,43 @@ def main(): def _add_opts_and_defaults(parser): _addopts(parser) parser.set_defaults(**default_config) + options, args = parse_args( _add_opts_and_defaults, import_format_params=True, modify_action_params=True, defaults=default_config) - def modify(x): + def modify(block:PythonBlock) -> PythonBlock: if options.transformations: - x = transform_imports(x, options.transformations, + block = transform_imports(block, options.transformations, params=options.params) if options.replace_star_imports: - x = replace_star_imports(x, params=options.params) - x = fix_unused_and_missing_imports( - x, params=options.params, + block = replace_star_imports(block, params=options.params) + block = fix_unused_and_missing_imports( + block, params=options.params, add_missing=options.add_missing, remove_unused=options.remove_unused, add_mandatory=options.add_mandatory, ) # TODO: disable sorting until we figure out #287 # https://github.com/deshaw/pyflyby/issues/287 - # from pyflyby._imports2s import sort_imports - # x = sort_imports(x) + + # here we get a (single?) PythonBlock, we can access each statement with + # >>> block.statements + # and each statement can have a: + # is_import + # or + # is_comment or blank + + # TODO: we do Python(str(...)) in order to unparse-reparse and get proper ast node numbers. 
+ if options.experimental_sort_imports: + sorted_imports = PythonBlock(str(sort_imports(block))) + else: + sorted_imports = block if options.canonicalize: - x = canonicalize_imports(x, params=options.params) - return x + cannonical_imports = canonicalize_imports(sorted_imports, params=options.params) + else: + cannonical_imports = sorted_imports + return cannonical_imports + process_actions(args, options.actions, modify) diff --git a/lib/python/pyflyby/_import_sorting.py b/lib/python/pyflyby/_import_sorting.py new file mode 100644 index 00000000..89a656a1 --- /dev/null +++ b/lib/python/pyflyby/_import_sorting.py @@ -0,0 +1,165 @@ +""" +This module contains utility functions to sort imports in a Python Block. +""" +from __future__ import annotations + +from collections import Counter +from dataclasses import dataclass +from itertools import groupby +from pyflyby._importstmt import ImportStatement, PythonStatement +from pyflyby._parse import PythonBlock +from typing import List, Optional, Tuple, Union + + +# @dataclass +# class ImportSection: +# """ +# This represents an Import +# """ +# sections: List[ImportGroup] + + +@dataclass +class ImportGroup: + """ + Typically at the top of a file, the first import will be part + of an ImportGroup, and subsequent imports + will be part of other import groups. + + Import sorting will affect only the imports in the same group, + as we want import sorting to be idempotent. If import sorting + were migrating imports between groups, then imports could be + moved to sections where comments would not be relevant. + """ + + imports: List[ImportStatement] + + @classmethod + def from_statements(cls, statements: List[PythonStatement]) -> ImportGroup: + return ImportGroup([ImportStatement(s) for s in statements]) + + def sorted(self) -> ImportGroup: + """ + Return an ImportGroup with imports sorted lexicographically. 
+ """ + return ImportGroup(sorted(self.imports, key=lambda x: x._cmp())) + + def sorted_subgroups(self) -> List[Tuple[bool, List[ImportStatement]]]: + """ + Return a list of subgroups keyed by root module, each flagged with whether that module + is imported more than once in this group. + + From issue #13, we want to sort imports from the same package together when + there is more than one import, and separate them with a blank line. + + We also group all imports from a package that appear only once together. + + For this we need to know whether the module is imported more than once (the boolean). + + Returns + ------- + bool: + True when more than one import in this group shares the root module + List[ImportStatement] + The import statements of that subgroup. + """ + c = Counter(imp.module[0] for imp in self.imports) + return [ + (c[k] > 1, list(v)) for k, v in groupby(self.imports, lambda x: x.module[0]) + ] + + +def split_import_groups( + statements: Tuple[PythonStatement], +) -> List[Union[ImportGroup, PythonStatement]]: + """ + Given a list of statements, split it into import groups. + + One of the questions is how to treat splits with comments. + + - Do blank lines after groups with comments start a new block? + - Does the comment line create a complete new block? + + In particular because we want this to be idempotent, + we can't move imports between groups. + """ + + # these are the import groups (and other statements) we'll return + groups: List[PythonStatement | ImportGroup] = [] + + current_group: List[PythonStatement] = [] + statemt_iterator = iter(statements) + for statement in statemt_iterator: + if statement.is_blank and current_group: + pass + # currently do nothing with whitespace while in import groups. + elif statement.is_import: + # push on top of current_group. 
+ current_group.append(statement) + else: + if current_group: + groups.append(ImportGroup.from_statements(current_group)) + current_group = [] + groups.append(statement) + + # we should break off and populate the rest + # We can't do anything if we encounter any non-import statement, + # as we do not know if it can be a conditional. + # technically I guess we could find another import block, and reorder these. + if current_group: + groups.append(ImportGroup.from_statements(current_group)) + + # this is an iterator, not an iterable, we exhaust it to reify the rest. + + # first group may be empty if the first line is a comment. + # We filter and sort relevant statements. + groups = [g for g in groups if groups] + sorted_groups: List[PythonStatement | ImportGroup] = [ + g.sorted() if isinstance(g, ImportGroup) else g for g in groups + ] + return sorted_groups + + +def regroup(groups: List[ImportGroup | PythonStatement]) -> PythonBlock: + """ + Given import groups and a list of statements, return a Python block with sorted imports + """ + res: str = "" + in_single = False + for group in groups: + if isinstance(group, ImportGroup): + # the subgroup here will be responsible for groups reordering. 
+ for mult, subgroup in group.sorted_subgroups(): + if mult: + if in_single: + res += "\n" + if not res.endswith("\n"): + res += "\n" + for x in subgroup: + res += x.pretty_print(import_column=30).rstrip() + "\n" + if not res.endswith("\n\n"): + res += "\n" + in_single = False + else: + assert len(subgroup) == 1 + in_single = True + for sub in subgroup: + res += str(sub).rstrip() + "\n" + + if not res.endswith("\n\n"): + res += "\n" + else: + if in_single and not res.endswith("\n\n"): + res += "\n" + in_single = False + res += str(PythonBlock(group)) + + return PythonBlock.concatenate([PythonBlock(res)]) + + +def sort_imports(block: PythonBlock) -> PythonBlock: + assert isinstance(block, PythonBlock) + # we ignore below that block.statement can be a List[Unknown] + gs = split_import_groups(block.statements) # type: ignore + # TODO: math the other fileds like filename.... + return regroup(gs) diff --git a/lib/python/pyflyby/_imports2s.py b/lib/python/pyflyby/_imports2s.py index 1d691f0e..29b11d0e 100644 --- a/lib/python/pyflyby/_imports2s.py +++ b/lib/python/pyflyby/_imports2s.py @@ -14,6 +14,8 @@ from pyflyby._util import ImportPathCtx, Inf, NullCtx, memoize import re +from typing import Union + class SourceToSourceTransformationBase(object): @@ -48,7 +50,7 @@ def preprocess(self): def pretty_print(self, params=None): raise NotImplementedError - def output(self, params=None): + def output(self, params=None) -> PythonBlock: """ Pretty-print and return as a `PythonBlock`. 
@@ -305,12 +307,14 @@ def ImportPathForRelativeImportsCtx(codeblock): return ImportPathCtx(str(codeblock.filename.dir)) -def fix_unused_and_missing_imports(codeblock, - add_missing=True, - remove_unused="AUTOMATIC", - add_mandatory=True, - db=None, - params=None): +def fix_unused_and_missing_imports( + codeblock: Union[PythonBlock, str], + add_missing=True, + remove_unused="AUTOMATIC", + add_mandatory=True, + db=None, + params=None, +) -> PythonBlock: r""" Check for unused and missing imports, and fix them automatically. @@ -338,11 +342,11 @@ def fix_unused_and_missing_imports(codeblock, `PythonBlock` """ if isinstance(codeblock, Filename): - codeblock = PythonBlock(filename=codeblock) + _codeblock = PythonBlock(codeblock) if not isinstance(codeblock, PythonBlock): - codeblock = PythonBlock(codeblock) + _codeblock = PythonBlock(codeblock) if remove_unused == "AUTOMATIC": - fn = codeblock.filename + fn = _codeblock.filename remove_unused = not (fn and (fn.base == "__init__.py" or ".pyflyby" in str(fn).split("/"))) @@ -351,18 +355,19 @@ def fix_unused_and_missing_imports(codeblock, else: raise ValueError("Invalid remove_unused=%r" % (remove_unused,)) params = ImportFormatParams(params) - db = ImportDB.interpret_arg(db, target_filename=codeblock.filename) + db = ImportDB.interpret_arg(db, target_filename=_codeblock.filename) # Do a first pass reformatting the imports to get rid of repeated or # shadowed imports, e.g. 
L1 here: # import foo # L1 # import foo # L2 # foo # L3 - codeblock = reformat_import_statements(codeblock, params=params) + _codeblock = reformat_import_statements(_codeblock, params=params) - filename = codeblock.filename - transformer = SourceToSourceFileImportsTransformation(codeblock) + filename = _codeblock.filename + transformer = SourceToSourceFileImportsTransformation(_codeblock) missing_imports, unused_imports = scan_for_import_issues( - codeblock, find_unused_imports=remove_unused, parse_docstrings=True) + _codeblock, find_unused_imports=remove_unused, parse_docstrings=True + ) logger.debug("missing_imports = %r", missing_imports) logger.debug("unused_imports = %r", unused_imports) if remove_unused and unused_imports: @@ -554,67 +559,6 @@ def replace_star_imports(codeblock, params=None): return transformer.output(params=params) -def sort_imports(codeblock): - """ - Sort imports for better grouping. - :param codeblock: - :return: codeblock - """ - import isort - sorted_imports = isort.code( - str(codeblock), - # To sort all the import in lexicographic order - force_sort_within_sections=True, - # This is done below - lines_between_sections=0, - lines_after_imports=1 - ) - # Step 1: Split the input string into a list of lines - lines = sorted_imports.split('\n') - - # Step 2: Identify groups of imports and keep track of their line numbers - pkg_lines = defaultdict(list) - line_pkg_dict = {} - for i, line in enumerate(lines): - match = re.match(r'(from (\w+)|import (\w+))', line) - if match: - current_pkg = match.groups()[1:3] - current_pkg = current_pkg[0] if current_pkg[0] is not None else current_pkg[1] - pkg_lines[current_pkg].append(i) - line_pkg_dict[i] = current_pkg - - # Step 3: Create the output list of lines with blank lines around groups with more than one import - output_lines = [] - - def next_line(index): - if index + 1 < len(lines): - return lines[index + 1] - else: - return "" - - for i, line in enumerate(lines): - if ( - i > 0 - and 
line_pkg_dict.get(i) != line_pkg_dict.get(i - 1) - and len(pkg_lines[line_pkg_dict.get(i)]) > 1 - and next_line(i).startswith(("import", "from")) - and output_lines[-1] != '' - ): - output_lines.append("") - output_lines.append(line) - if ( - i < len(lines) - 1 - and line_pkg_dict.get(i) != line_pkg_dict.get(i + 1) - and len(pkg_lines[line_pkg_dict.get(i)]) > 1 - and next_line(i).startswith(("import", "from")) - ): - output_lines.append("") - - # Step 4: Join the lines to create the output string - sorted_output_str = '\n'.join(output_lines) - return PythonBlock(sorted_output_str) - - def transform_imports(codeblock, transformations, params=None): """ Transform imports as specified by ``transformations``. diff --git a/lib/python/pyflyby/_importstmt.py b/lib/python/pyflyby/_importstmt.py index d2040f41..0f83357b 100644 --- a/lib/python/pyflyby/_importstmt.py +++ b/lib/python/pyflyby/_importstmt.py @@ -492,6 +492,39 @@ def imports(self): Import.from_split((self.fromname, alias[0], alias[1])) for alias in self.aliases) + @property + def module(self) -> Tuple[str, ...]: + """ + + return the import module as a list of string (which would be joined by + dot in the original import form. + + This is useful for sorting purposes + + Note that this may contain some empty string in particular with relative + imports + """ + if self.fromname: + return tuple(self.fromname.split('.')) + + + assert len(self.aliases) == 1, self.aliases + + return tuple(self.aliases[0][0].split('.')) + + + def _cmp(self): + """ + Comparison function for sorting. + + We want to sort: + - by the root module + - whether it is an "import ... as", or "from ... 
import as" import + - then lexicographically + + """ + return (self.module[0], 0 if self.fromname is not None else 1, self.fromname) + @cached_attribute def flags(self): """ diff --git a/tests/tests_sorts.py b/tests/tests_sorts.py index 162fd52a..1dde4738 100644 --- a/tests/tests_sorts.py +++ b/tests/tests_sorts.py @@ -1,5 +1,6 @@ from pyflyby._parse import PythonBlock -from pyflyby._imports2s import sort_imports, fix_unused_and_missing_imports +from pyflyby._imports2s import fix_unused_and_missing_imports +from pyflyby._import_sorting import sort_imports from textwrap import dedent from lib.python.pyflyby._importstmt import ImportFormatParams import pytest @@ -12,6 +13,7 @@ import json import simplejson + from pkg2 import same, line from pkg1.mod1 import foo from pkg1.mod2 import bar from pkg2 import baz @@ -21,30 +23,59 @@ from pkg1.mod3 import quux from pkg2 import baar import zz - """) + """).strip() -# The logic requested in issue 13 whas to keep blocks, but order +# The logic requested in issue 13 was to keep blocks, but order # lexicographically. we do seem to be splitting stdlib from installed packages, # but not adding a blank line and then sort lexicographically if an package has -# several submodules importted we put it in a block, otherwise we keep +# several submodules imported we put it in a block, otherwise we keep # everything together. +# note that sorting does not concatenate same imports together, but cannonicalize +# in tidy import will. 
+ expected1 = dedent(""" - import json - import os import external + import json import numpy + import os - from pkg1.mod1 import foo, foo2 - from pkg1.mod2 import bar - from pkg1.mod3 import quux + from pkg1.mod1 import foo + from pkg1.mod1 import foo2 + from pkg1.mod2 import bar + from pkg1.mod3 import quux + + from pkg2 import same, line + from pkg2 import baz + from pkg2 import baar - from pkg2 import baar, baz import simplejson import sympy import yy import zz - """) + """).strip()+'\n\n' + + +code2 = dedent( + """ + '''module docstring''' + import os + + pass + """ +).strip() + +expected2 = code2 + + +code3 = dedent(""" + import os + + if True: + "ok" + """).strip() + +expected3 = code3 # stable should not change stable_1 = dedent( @@ -60,7 +91,29 @@ ) -@pytest.mark.parametrize("code, expected", [(code1, expected1)]) +code4 = dedent( + """ +from __future__ import print_function + +from pyflyby._cmdline import filename_args, hfmt, parse_args +from pyflyby._importclns import ImportSet +from pyflyby._importdb import ImportDB + +import re +import sys + +def main(): + def addopts(parser): + pass""" +).strip() + +expected4 = code4 + + +@pytest.mark.parametrize( + "code, expected", + [(code1, expected1), (code2, expected2), (code3, expected3), (code4, expected4)], +) def test_sort_1(code, expected): assert str(sort_imports(PythonBlock(code))) == expected