Coverage for src/canonical_imports/_core.py: 99%
212 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-02-05 19:37 +0100
1import ast
2import collections
3import copy
4import difflib
5from collections import defaultdict
6from dataclasses import dataclass
7from pathlib import Path
8from typing import Optional
10import asttokens
11import click
12import structlog
13from rich import print as rprint
14from rich.markdown import Markdown
16from ._utils import unparse
# Module-level structlog logger; Module uses it to report files that it
# cannot parse or decode.
log = structlog.get_logger()
def is_private(name: str):
    """Tell whether *name* is a single-underscore ("private") identifier.

    ``_x`` is private; plain names, a lone ``_`` and dunder-style names such
    as ``__init__`` are not.
    """
    if len(name) < 2:
        return False
    return name.startswith("_") and not name.startswith("__")
def is_module_private(module_name: str):
    """Tell whether any dotted component of *module_name* is private."""
    for component in module_name.split("."):
        if is_private(component):
            return True
    return False
@dataclass
class Import:
    """One name imported via ``from <module> import <name> [as <asname>]``.

    Instances built by :meth:`from_ast` always carry an absolute ``module``
    (relative levels are resolved) and a non-empty ``asname``.
    """

    module: str  # absolute dotted module the name is imported from
    name: str  # name looked up inside ``module``
    asname: Optional[str]  # name bound locally; from_ast falls back to ``name``

    def is_private(self):
        """True if the imported name or any component of its module is private."""
        return is_private(self.name) or is_module_private(self.module)

    @classmethod
    def from_ast(cls, stmt: ast.ImportFrom, module: str):
        """Yield one ``Import`` per alias of *stmt*.

        *module* is the dotted name of the importing file including its last
        component (``pkg.__init__`` for a package ``__init__.py``), so dropping
        ``stmt.level`` trailing components gives the package that a relative
        import refers to.
        """
        # The target module is identical for every alias: compute it once.
        parts = module.split(".")[: -stmt.level] if stmt.level else []
        if stmt.module:
            parts.append(stmt.module)
        target = ".".join(parts)

        for alias in stmt.names:
            yield cls(target, alias.name, alias.asname or alias.name)

    def is_inside(self, other):
        """True if this import's module lies strictly below ``other.module``."""
        return self.module.startswith(other.module + ".")
class Module:
    """A Python source file with its dotted module name and ``from`` imports.

    Creating a ``Module`` registers it with ``import_fixer`` and parses the
    file with ``asttokens``. Files that fail to parse or decode are logged and
    keep ``atok = None`` and an empty ``imports`` list.
    """

    def __init__(self, path, import_fixer):
        self.path = path
        self.imports = []  # top-level ``from`` imports that may be followed
        self.import_fixer = import_fixer

        m = [path.stem]
        p = path.parent
        package_folder = p

        # Walk upwards while the folder is a package, prepending its name.
        # "." and ".." carry no usable name, and ``Path(".").parent`` is "."
        # again (``Path("..").parent`` even goes *down* to "."), so walking
        # them lexically never terminates when running inside a package
        # directory. Resolve them to a real directory first.
        while True:
            if p.name in ("", ".."):
                p = p.resolve()
                if not p.name:  # filesystem root: nothing above it
                    break
            if not (p / "__init__.py").exists():
                break
            package_folder = p
            m.insert(0, p.name)
            p = p.parent

        self.module = ".".join(m)
        self.module_parts = m

        import_fixer.register(self, package_folder)

        self.atok = None

        try:
            self.atok = asttokens.ASTTokens(self.path.read_text(), parse=True)
        except SyntaxError:
            log.error(f"could not parse {self.path}")
        except UnicodeDecodeError:
            log.error(f"could not decode {self.path}")
        else:
            # Names rebound anywhere in the file (assignment/loop targets, ...)
            # or imported more than once are ambiguous, so imports binding
            # them are not treated as re-exports.
            assigned_names = {
                name.id
                for name in ast.walk(self.tree)
                if isinstance(name, ast.Name) and isinstance(name.ctx, ast.Store)
            }

            imported_names = collections.Counter(
                stmt.asname or stmt.name
                for stmt in ast.walk(self.tree)
                if isinstance(stmt, ast.alias)
            )

            double_imports = {
                name for name, count in imported_names.items() if count >= 2
            }

            for stmt in self.tree.body:
                if isinstance(stmt, ast.ImportFrom):
                    for i in Import.from_ast(stmt, self.module):
                        if (
                            i.asname not in assigned_names
                            and i.asname not in double_imports
                        ):
                            self.imports.append(i)

    def is_init(self):
        """True if this file is a package ``__init__.py``."""
        return self.path.name == "__init__.py"

    def relative_to_me(self, name):
        """Return ``ImportFrom`` keyword arguments importing *name* relatively.

        Names outside this module's top-level package get ``level`` 0.
        """
        assert name[0] != "."
        assert name != self.module

        self_parts = list(self.module_parts)
        name_parts = name.split(".")
        if name_parts[0] != self_parts[0]:
            return {"level": 0, "module": name}

        # Strip the common package prefix; every remaining component of this
        # module's name is one level to climb.
        while name_parts and self_parts and name_parts[0] == self_parts[0]:
            name_parts.pop(0)
            self_parts.pop(0)

        return {"level": len(self_parts), "module": ".".join(name_parts)}

    @property
    def tree(self):
        """The parsed ``ast`` tree (only valid when ``atok`` is not None)."""
        return self.atok.tree

    def __repr__(self):
        return f"Module({self.path!r})"  # pragma: no cover

    def lookup(self, name) -> "Optional[Import]":
        """Return the followable import that binds *name*, or None."""
        for import_ in self.imports:
            if import_.asname == name:
                return import_
        return None

    def change_set(self):
        """Compute rewrites pointing each ``from`` import at its origin.

        Every ``ImportFrom`` in the file (at any depth) is followed through the
        re-exports of the modules it names; the deepest hop accepted by
        ``import_fixer.is_allowed`` wins. Chains longer than 50 hops (e.g.
        import cycles) are left untouched.
        """
        changes = []

        if self.atok is None:
            return ChangeSet(self.path, [])

        for stmt in ast.walk(self.tree):
            if not isinstance(stmt, ast.ImportFrom):
                continue

            new_imports = {}  # local name -> Import it should now come from
            for first_import in Import.from_ast(stmt, self.module):
                last_import = first_import
                step = 0
                all_imports = [first_import]
                skip = False

                # Follow the chain while each module re-exports the name.
                while True:
                    step += 1
                    if step > 50:
                        skip = True
                        break
                    module = self.import_fixer.lookup_module(last_import.module)
                    if module is None:
                        break
                    new_import = module.lookup(last_import.name)

                    if new_import is not None:
                        last_import = new_import
                        all_imports.append(new_import)
                    else:
                        break

                if skip:
                    continue

                # Prefer the deepest hop that the configured rules allow.
                for i in reversed(range(1, len(all_imports))):
                    if self.import_fixer.is_allowed(
                        self.module, all_imports[: i + 1]
                    ):
                        new_imports[first_import.asname] = all_imports[i]
                        break

            m = ast.Module(body=[], type_ignores=[])

            by_module = defaultdict(list)
            for import_name, import_ in new_imports.items():
                by_module[import_.module].append((import_name, import_))

            for module, imports in by_module.items():
                is_relative = stmt.level != 0
                m.body.append(
                    ast.ImportFrom(
                        **(
                            self.relative_to_me(module)
                            if is_relative
                            else {"level": 0, "module": module}
                        ),
                        names=[
                            ast.alias(
                                name=import_.name,
                                asname=(
                                    import_name
                                    if import_name != import_.name
                                    else None
                                ),
                            )
                            for import_name, import_ in imports
                        ],
                    )
                )

            # Keep the names that could not be redirected in the original form.
            new_stmt = copy.deepcopy(stmt)

            new_stmt.names = [
                n for n in stmt.names if (n.asname or n.name) not in new_imports
            ]

            if new_stmt.names:
                m.body.append(new_stmt)

            if new_imports:
                # Re-indent continuation lines to the statement's own column.
                indent = " " * stmt.first_token.start[1]
                changes.append(
                    (
                        stmt.first_token.startpos,
                        stmt.last_token.endpos,
                        unparse(m).replace("\n", "\n" + indent),
                    )
                )

        return ChangeSet(self.path, changes)
class ChangeSet:
    """Pending text replacements for one source file.

    ``changes`` holds ``(start, end, new_text)`` tuples of character offsets
    into the file's text; they are applied with ``asttokens.util.replace``.
    """

    def __init__(self, path, changes):
        self.path = path
        self.changes = changes

    def __bool__(self):
        # Truthy only when there is something to rewrite.
        return bool(self.changes)

    def preview(self):
        """Return a unified diff (iterator of lines) of the pending changes.

        The inputs come from ``splitlines()`` and carry no line endings, so
        ``lineterm=""`` keeps the generated ``---``/``+++``/``@@`` lines
        newline-free as well. Otherwise they would end in a stray newline
        while the content lines would not.
        """
        old_code = self.path.read_text()
        new_code = asttokens.util.replace(old_code, self.changes)

        diff = difflib.unified_diff(
            old_code.splitlines(),
            new_code.splitlines(),
            lineterm="",
        )
        return diff

    def fix(self):
        """Apply the changes to the file on disk."""
        old_code = self.path.read_text()
        self.path.write_text(asttokens.util.replace(old_code, self.changes))
class ImportFixer:
    """Registry of parsed modules plus the rules for which re-exports to follow.

    ``flags`` holds the names of the rules to enforce (the ``--no`` values).
    """

    def __init__(self, flags):
        self.module_cache = {}  # path -> Module
        self.package_cache = {}  # top-level name -> folder (see register)
        self.flags = flags

    def is_allowed(self, module, import_chain):
        """Return whether *module* may import straight from ``import_chain[-1]``.

        "public-private" forbids a non-private module from importing a private
        name or module. "into-init" forbids skipping past an ``__init__``
        module of the chain when every later hop stays inside its package.
        """
        if "public-private" in self.flags:
            last_import = import_chain[-1]
            if not is_module_private(module) and last_import.is_private():
                return False

        if "into-init" in self.flags:
            if any(
                self.is_init(init.module)
                and all(imp.is_inside(init) for imp in import_chain[i + 1 :])
                for i, init in enumerate(import_chain[:-1])
            ):
                return False
        return True

    def register(self, module, package_folder):
        """Record *module* and the folder its top-level name lives in.

        For a module inside a package, *package_folder* is the top-level
        package directory itself. For a stand-alone module it is the
        directory that contains the file.
        """
        self.package_cache[module.module_parts[0]] = package_folder

        self.module_cache[module.path] = module

    def lookup_file(self, filename) -> "Module":
        """Return the cached Module for *filename*, parsing it on first use."""
        if filename not in self.module_cache:
            return Module(filename, self)
        return self.module_cache[filename]

    def lookup_module(self, module: str) -> "Optional[Module]":
        """Return the Module for dotted name *module*, or None if not found."""
        module_parts = module.split(".")
        package_name = module_parts[0]
        if package_name not in self.package_cache:
            return None

        folder = self.package_cache[package_name]
        if folder.name == package_name and (folder / "__init__.py").exists():
            # folder is the top-level package directory itself
            base = folder.joinpath(*module_parts[1:])
        else:
            # folder merely contains a stand-alone module (possibly "."), so
            # the top-level name must be joined too. Using the folder itself
            # would probe the wrong file, and ``Path(".").with_suffix`` raises.
            base = folder.joinpath(*module_parts)

        for option in (base.with_suffix(".py"), base / "__init__.py"):
            if option.exists():
                return self.lookup_file(option)
        return None

    def is_init(self, module_name: str):
        """True if *module_name* resolves to a package ``__init__.py``."""
        module = self.lookup_module(module_name)
        return module is not None and module.is_init()
@click.command()
@click.option(
    "--no",
    help="Exclude specific imports",
    multiple=True,
    type=click.Choice(["public-private", "into-init"]),
)
@click.option("--write", "-w", is_flag=True, help="write changed imports")
@click.argument("paths", nargs=-1, type=click.Path(exists=True))
def main(no, paths, write):
    """`canonical-imports` follows your imports and finds out where the things
    you are importing are actually defined.

    PATHS: python files or directories that should be scanned for python files
    """
    import_fixer = ImportFixer(set(no))

    files = []
    for file in paths:
        file = Path(file)
        if file.is_dir():
            files.extend(file.rglob("*.py"))
        else:
            files.append(file)

    # A file reachable through several PATHS (e.g. a directory and a file
    # inside it) must be handled once. Every ChangeSet holds offsets into the
    # original text, so applying a second one to an already rewritten file
    # would corrupt it. dict.fromkeys keeps the first-seen order.
    files = list(dict.fromkeys(files))

    source_files = [Module(file, import_fixer) for file in files]

    if write:
        for file in source_files:
            change_set = file.change_set()
            if change_set:
                print(f"fix: {file.path}")
                change_set.fix()
    else:
        text = []
        for file in source_files:
            change_set = file.change_set()
            if change_set:
                text.append(f"## {file.path}")
                text.append("```diff")
                # drop the two file-header lines ("---" / "+++") of the diff
                text += list(change_set.preview())[2:]
                text.append("```")
                text.append("---")

        rprint(Markdown("\n".join(text)))