Skip to content

Commit

Permalink
Custom implementation of import sorting.
Browse files Browse the repository at this point in the history
This tries to implement some of the rule of import sorting as requested
in #13 using custom logic, This was originally fixed in #263 using isort,
but reverted and  re-requested as of #287
  • Loading branch information
Carreau committed Apr 9, 2024
1 parent 82a3fe6 commit bfd1c00
Show file tree
Hide file tree
Showing 5 changed files with 313 additions and 100 deletions.
45 changes: 32 additions & 13 deletions bin/tidy-imports
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@ Only top-level import statements are touched.
# License: MIT http://opensource.org/licenses/MIT


from __future__ import print_function

import os

from pyflyby._cmdline import hfmt, parse_args, process_actions
from pyflyby._imports2s import (canonicalize_imports,
fix_unused_and_missing_imports,
replace_star_imports,
transform_imports, sort_imports)
from pyflyby._import_sorting import sort_imports
from pyflyby._imports2s import (canonicalize_imports,
fix_unused_and_missing_imports,
replace_star_imports,
transform_imports)
from pyflyby._parse import PythonBlock

import toml
TOML_AVAIL = True
Expand Down Expand Up @@ -104,6 +108,10 @@ def _addopts(parser):
help=hfmt('''
(Default) Replace imports with canonical
equivalent imports, according to database.'''))
parser.add_option('--experimental-sort-imports',
default=False, action='store_true',
help=hfmt('''
experimental import sorting'''))
parser.add_option('--no-canonicalize', dest='canonicalize',
default=True, action='store_false',
help=hfmt('''
Expand All @@ -128,7 +136,7 @@ def _addopts(parser):
help=hfmt('''
Equivalent to --no-add-missing
--no-add-mandatory.'''))
def main():
def main() -> None:

config_text = _get_pyproj_toml_config()
if config_text:
Expand All @@ -139,29 +147,40 @@ def main():
def _add_opts_and_defaults(parser):
_addopts(parser)
parser.set_defaults(**default_config)

options, args = parse_args(
_add_opts_and_defaults, import_format_params=True, modify_action_params=True, defaults=default_config)

def modify(x):
def modify(block:PythonBlock) -> PythonBlock:
if options.transformations:
x = transform_imports(x, options.transformations,
block = transform_imports(block, options.transformations,
params=options.params)
if options.replace_star_imports:
x = replace_star_imports(x, params=options.params)
x = fix_unused_and_missing_imports(
x, params=options.params,
block = replace_star_imports(block, params=options.params)
block = fix_unused_and_missing_imports(
block, params=options.params,
add_missing=options.add_missing,
remove_unused=options.remove_unused,
add_mandatory=options.add_mandatory,
)
# TODO: disable sorting until we figure out #287
# https://github.com/deshaw/pyflyby/issues/287
# sorted_imports = sort_imports(x)
sorted_imports = x
# here we get a (single?) PythonBlock, we can access each statement with
# >>> block.statements
# and each statement can have a:
# is_import
# or
# is_comment or blank

# TODO: we do Python(str(...)) in order to unparse-reparse and get proper ast node numbers.
if options.experimental_sort_imports:
sorted_imports = PythonBlock(str(sort_imports(block)))
else:
sorted_imports = block
if options.canonicalize:
cannonical_imports = canonicalize_imports(sorted_imports, params=options.params)
else:
cannonical_imports = sort_imports
cannonical_imports = sorted_imports
return cannonical_imports
process_actions(args, options.actions, modify)

Expand Down
165 changes: 165 additions & 0 deletions lib/python/pyflyby/_import_sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""
This module contain utility functions to sort imports in a Python Block.
"""
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass
from itertools import groupby
from pyflyby._importstmt import ImportStatement, PythonStatement
from pyflyby._parse import PythonBlock
from typing import List, Optional, Tuple, Union


# @dataclass
# class ImportSection:
# """
# This represent an Import
# """
# sections: List[ImportGroup]


@dataclass
class ImportGroup:
"""
Typically at the top of a file, the first import will be part
of an ImportGroup, and subsequent imports
will be other import groups which.
Import sorting will affect only the imports in the same group,
as we want import sorting to be indempotent. If import sorting
were migrating imports between groups, then imports could be
moved to sections were comments would nto be relevant.
"""

imports: List[ImportStatement]

@classmethod
def from_statements(cls, statements: List[PythonStatement]) -> ImportGroup:
return ImportGroup([ImportStatement(s) for s in statements])

def sorted(self) -> ImportGroup:
"""
return an ImportGroup with import sorted lexicographically.
"""
return ImportGroup(sorted(self.imports, key=lambda x: x._cmp()))

def sorted_subgroups(self) -> List[Tuple[bool, List[ImportStatement]]]:
"""
Return a list of subgroup keyed by module and whether they are the sole import
from this module.
From issue #13, we will want to sort import from the same package together when
there is more than one import and separat with a blank line
We also group all imports from a package that appear only once together.
For this we need both to know if the import is from a single import (the boolean).
Returns
-------
bool:
wether from a single import
List[ImportStatement]
The actual import.
"""
c = Counter(imp.module[0] for imp in self.imports)
return [
(c[k] > 1, list(v)) for k, v in groupby(self.imports, lambda x: x.module[0])
]


def split_import_groups(
statements: Tuple[PythonStatement],
) -> List[Union[ImportGroup, PythonStatement]]:
"""
Given a list of statements split into import groups.
One of the question is how to treat split with comments.
- Do blank lines after groups with comments start a new block
- Does the comment line create a complete new block.
In particular because we want this to be indempotent,
we can't move imports between groups.
"""

# these are the import groups we'll
groups: List[PythonStatement | ImportGroup] = []

current_group: List[PythonStatement] = []
statemt_iterator = iter(statements)
for statement in statemt_iterator:
if statement.is_blank and current_group:
pass
# currently do nothing with whitespace while in import groups.
elif statement.is_import:
# push on top of current_comment.
current_group.append(statement)
else:
if current_group:
groups.append(ImportGroup.from_statements(current_group))
current_group = []
groups.append(statement)

# we should break of and populate rest
# We can't do anything if we encounter any non-import statement,
# as we do no know if it can be a conditional.
# technically I guess we coudl find another import block, and reorder these.
if current_group:
groups.append(ImportGroup.from_statements(current_group))

# this is an iterator, not an iterable, we exaust it to reify the rest.

# first group may be empty if the first line is a comment.
# We filter and sort relevant statements.
groups = [g for g in groups if groups]
sorted_groups: List[PythonStatement | ImportGroup] = [
g.sorted() if isinstance(g, ImportGroup) else g for g in groups
]
return sorted_groups


def regroup(groups: List[ImportGroup | PythonStatement]) -> PythonBlock:
"""
given import groups and list of statement, return an Python block with sorted import
"""
res: str = ""
in_single = False
for group in groups:
if isinstance(group, ImportGroup):
# the subgroup here will be responsible for groups reordering.
for mult, subgroup in group.sorted_subgroups():
if mult:
if in_single:
res += "\n"
if not res.endswith("\n"):
res += "\n"
for x in subgroup:
res += x.pretty_print(import_column=30).rstrip() + "\n"
if not res.endswith("\n\n"):
res += "\n"
in_single = False
else:
assert len(subgroup) == 1
in_single = True
for sub in subgroup:
res += str(sub).rstrip() + "\n"

if not res.endswith("\n\n"):
res += "\n"
else:
if in_single and not res.endswith("\n\n"):
res += "\n"
in_single = False
res += str(PythonBlock(group))

return PythonBlock.concatenate([PythonBlock(res)])


def sort_imports(block: PythonBlock) -> PythonBlock:
assert isinstance(block, PythonBlock)
# we ignore below that block.statement can be a List[Unknown]
gs = split_import_groups(block.statements) # type: ignore
# TODO: math the other fileds like filename....
return regroup(gs)
93 changes: 18 additions & 75 deletions lib/python/pyflyby/_imports2s.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from pyflyby._util import ImportPathCtx, Inf, NullCtx, memoize
import re

from typing import Union


class SourceToSourceTransformationBase(object):

Expand All @@ -40,7 +42,7 @@ def preprocess(self):
def pretty_print(self, params=None):
raise NotImplementedError

def output(self, params=None):
def output(self, params=None) -> PythonBlock:
"""
Pretty-print and return as a `PythonBlock`.
Expand Down Expand Up @@ -293,12 +295,14 @@ def ImportPathForRelativeImportsCtx(codeblock):
return ImportPathCtx(str(codeblock.filename.dir))


def fix_unused_and_missing_imports(codeblock,
add_missing=True,
remove_unused="AUTOMATIC",
add_mandatory=True,
db=None,
params=None):
def fix_unused_and_missing_imports(
codeblock: Union[PythonBlock, str],
add_missing=True,
remove_unused="AUTOMATIC",
add_mandatory=True,
db=None,
params=None,
) -> PythonBlock:
r"""
Check for unused and missing imports, and fix them automatically.
Expand All @@ -325,9 +329,9 @@ def fix_unused_and_missing_imports(codeblock,
:rtype:
`PythonBlock`
"""
codeblock = PythonBlock(codeblock)
_codeblock:PythonBlock = PythonBlock(codeblock)
if remove_unused == "AUTOMATIC":
fn = codeblock.filename
fn = _codeblock.filename
remove_unused = not (fn and
(fn.base == "__init__.py"
or ".pyflyby" in str(fn).split("/")))
Expand All @@ -336,18 +340,18 @@ def fix_unused_and_missing_imports(codeblock,
else:
raise ValueError("Invalid remove_unused=%r" % (remove_unused,))
params = ImportFormatParams(params)
db = ImportDB.interpret_arg(db, target_filename=codeblock.filename)
db = ImportDB.interpret_arg(db, target_filename=_codeblock.filename)
# Do a first pass reformatting the imports to get rid of repeated or
# shadowed imports, e.g. L1 here:
# import foo # L1
# import foo # L2
# foo # L3
codeblock = reformat_import_statements(codeblock, params=params)
_codeblock = reformat_import_statements(_codeblock, params=params)

filename = codeblock.filename
transformer = SourceToSourceFileImportsTransformation(codeblock)
filename = _codeblock.filename
transformer = SourceToSourceFileImportsTransformation(_codeblock)
missing_imports, unused_imports = scan_for_import_issues(
codeblock, find_unused_imports=remove_unused, parse_docstrings=True)
_codeblock, find_unused_imports=remove_unused, parse_docstrings=True)
logger.debug("missing_imports = %r", missing_imports)
logger.debug("unused_imports = %r", unused_imports)
if remove_unused and unused_imports:
Expand Down Expand Up @@ -537,67 +541,6 @@ def replace_star_imports(codeblock, params=None):
return transformer.output(params=params)


def sort_imports(codeblock):
"""
Sort imports for better grouping.
:param codeblock:
:return: codeblock
"""
import isort
sorted_imports = isort.code(
str(codeblock),
# To sort all the import in lexicographic order
force_sort_within_sections=True,
# This is done below
lines_between_sections=0,
lines_after_imports=1
)
# Step 1: Split the input string into a list of lines
lines = sorted_imports.split('\n')

# Step 2: Identify groups of imports and keep track of their line numbers
pkg_lines = defaultdict(list)
line_pkg_dict = {}
for i, line in enumerate(lines):
match = re.match(r'(from (\w+)|import (\w+))', line)
if match:
current_pkg = match.groups()[1:3]
current_pkg = current_pkg[0] if current_pkg[0] is not None else current_pkg[1]
pkg_lines[current_pkg].append(i)
line_pkg_dict[i] = current_pkg

# Step 3: Create the output list of lines with blank lines around groups with more than one import
output_lines = []

def next_line(index):
if index + 1 < len(lines):
return lines[index + 1]
else:
return ""

for i, line in enumerate(lines):
if (
i > 0
and line_pkg_dict.get(i) != line_pkg_dict.get(i - 1)
and len(pkg_lines[line_pkg_dict.get(i)]) > 1
and next_line(i).startswith(("import", "from"))
and output_lines[-1] != ''
):
output_lines.append("")
output_lines.append(line)
if (
i < len(lines) - 1
and line_pkg_dict.get(i) != line_pkg_dict.get(i + 1)
and len(pkg_lines[line_pkg_dict.get(i)]) > 1
and next_line(i).startswith(("import", "from"))
):
output_lines.append("")

# Step 4: Join the lines to create the output string
sorted_output_str = '\n'.join(output_lines)
return PythonBlock(sorted_output_str)


def transform_imports(codeblock, transformations, params=None):
"""
Transform imports as specified by ``transformations``.
Expand Down
Loading

0 comments on commit bfd1c00

Please sign in to comment.