Coverage for src/canonical_imports/_core.py: 99%

212 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-02-05 19:37 +0100

1import ast 

2import collections 

3import copy 

4import difflib 

5from collections import defaultdict 

6from dataclasses import dataclass 

7from pathlib import Path 

8from typing import Optional 

9 

10import asttokens 

11import click 

12import structlog 

13from rich import print as rprint 

14from rich.markdown import Markdown 

15 

16from ._utils import unparse 

17 

# Module-level structlog logger; used to report files that cannot be parsed or decoded.
log = structlog.get_logger()

19 

20 

def is_private(name: str):
    """Return True for a single-underscore name such as ``_helper``.

    Names starting with two underscores (``__x``, ``__init__``) and the bare
    ``_`` are not considered private.
    """
    if not name.startswith("_"):
        return False
    # A second leading underscore means a dunder / name-mangled identifier.
    return len(name) >= 2 and not name.startswith("__")

23 

24 

def is_module_private(module_name: str):
    """Return True if any dotted component of *module_name* is private.

    For example ``pkg._impl.sub`` is private because of ``_impl``, while
    ``pkg.__init__`` is not.
    """
    for component in module_name.split("."):
        if is_private(component):
            return True
    return False

27 

28 

@dataclass
class Import:
    """One name bound by a ``from <module> import <name> [as <asname>]``.

    ``module`` is always absolute: relative imports are resolved against the
    importing module in ``from_ast``. ``asname`` is the name actually bound in
    the importing module and equals ``name`` when there is no ``as`` clause.
    """

    module: str
    name: str
    asname: Optional[str]

    def is_private(self):
        """Return True if the imported name or any part of its module is private."""
        return is_private(self.name) or is_module_private(self.module)

    @classmethod
    def from_ast(cls, stmt: ast.ImportFrom, module: str):
        """Yield one instance of ``cls`` per alias in ``stmt``.

        ``module`` is the dotted name of the module containing ``stmt`` and is
        used to resolve relative imports: for a ``level`` of N, the last N
        parts of ``module`` are dropped. A package initializer therefore has to
        be passed as ``pkg.__init__`` for ``from . import x`` to resolve to
        ``pkg``.
        """
        # The source module is the same for every alias, so resolve it once.
        # A falsy level (0, or None on hand-built nodes) is an absolute import.
        parts = module.split(".")[: -stmt.level] if stmt.level else []
        if stmt.module:
            parts.append(stmt.module)
        source = ".".join(parts)

        for alias in stmt.names:
            yield cls(source, alias.name, alias.asname or alias.name)

    def is_inside(self, other):
        """Return True if this import's module is a submodule of ``other``'s module."""
        return self.module.startswith(other.module + ".")

52 

53 

class Module:
    """One Python source file taking part in import canonicalization.

    Construction does three things:
    - derive the dotted module name from the ``__init__.py`` chain above ``path``;
    - register the module with ``import_fixer``;
    - parse the file with ``asttokens``.

    On success, the followable top-level ``from ... import ...`` names are
    collected in ``self.imports``. On a syntax or decoding error ``self.atok``
    stays ``None`` and the file produces an empty change set.
    """

    # Maximum number of re-export hops followed for a single imported name.
    _MAX_CHAIN_LENGTH = 50

    def __init__(self, path, import_fixer):
        self.path = path
        self.imports = []
        self.import_fixer = import_fixer

        self.module_parts, package_folder = self._locate(path)
        self.module = ".".join(self.module_parts)

        import_fixer.register(self, package_folder)

        self.atok = None
        try:
            self.atok = asttokens.ASTTokens(self.path.read_text(), parse=True)
        except SyntaxError:
            log.error(f"could not parse {self.path}")
        except UnicodeDecodeError:
            log.error(f"could not decode {self.path}")
        else:
            self.imports = self._followable_imports()

    @staticmethod
    def _locate(path):
        """Return ``(module_parts, package_folder)`` for the file at ``path``.

        ``module_parts`` is the dotted module name split into parts. The file
        stem is always the last part, even for ``__init__.py``, so relative
        imports (which drop ``level`` trailing parts) resolve the same way for
        package initializers as for ordinary modules.

        ``package_folder`` is the outermost folder that still contains an
        ``__init__.py``, or the file's own folder when it is not in a package.
        """
        module_parts = [path.stem]
        folder = path.parent
        package_folder = folder
        while (folder / "__init__.py").exists():
            if folder.name in ("", ".."):
                # "." and ".." carry no usable package name, and "." is its own
                # parent, so the walk would never terminate. Continue from the
                # resolved absolute folder instead.
                folder = folder.resolve()
                if folder == folder.parent:  # reached the filesystem root
                    break
            package_folder = folder
            module_parts.insert(0, folder.name)
            folder = folder.parent
        return module_parts, package_folder

    def _followable_imports(self):
        """Return the top-level ``from`` imports whose bound name can be followed.

        A name is skipped when it is re-assigned anywhere in the module or
        bound by more than one import. Presumably this is because the
        module-level binding might then not be the imported object.
        """
        assigned_names = {
            node.id
            for node in ast.walk(self.tree)
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store)
        }
        import_counts = collections.Counter(
            alias.asname or alias.name
            for alias in ast.walk(self.tree)
            if isinstance(alias, ast.alias)
        )
        ambiguous = assigned_names | {
            name for name, count in import_counts.items() if count >= 2
        }
        return [
            imp
            for stmt in self.tree.body
            if isinstance(stmt, ast.ImportFrom)
            for imp in Import.from_ast(stmt, self.module)
            if imp.asname not in ambiguous
        ]

    def is_init(self):
        """Return True if this file is a package ``__init__.py``."""
        return self.path.name == "__init__.py"

    def relative_to_me(self, name):
        """Express the absolute module ``name`` relative to this module.

        Returns a dict with ``level`` and ``module`` keys, suitable as keyword
        arguments for ``ast.ImportFrom``. A module from a different top-level
        package is returned as an absolute import (``level`` 0).
        """
        assert name[0] != "."
        assert name != self.module

        self_parts = list(self.module_parts)
        name_parts = name.split(".")
        if name_parts[0] != self_parts[0]:
            return {"level": 0, "module": name}

        # Strip the shared dotted prefix. Each remaining part of this module's
        # own name is one "." of relative level.
        while name_parts and self_parts and name_parts[0] == self_parts[0]:
            name_parts.pop(0)
            self_parts.pop(0)

        return {"level": len(self_parts), "module": ".".join(name_parts)}

    @property
    def tree(self):
        """The parsed ``ast.Module`` (only valid when parsing succeeded)."""
        return self.atok.tree

    def __repr__(self):
        return f"Module({self.path!r})"  # pragma: no cover

    def lookup(self, name) -> "Optional[Import]":
        """Return the followable import that binds ``name`` here, or ``None``."""
        for import_ in self.imports:
            if import_.asname == name:
                return import_
        return None

    def _is_self(self, module_name):
        """Return True if the dotted ``module_name`` refers to this very file."""
        if module_name == self.module:
            return True
        # ``pkg/__init__.py`` is stored as ``pkg.__init__`` but imported as ``pkg``.
        return self.is_init() and module_name == ".".join(self.module_parts[:-1])

    def _follow_chain(self, first_import):
        """Follow ``first_import`` through the re-exports of its source modules.

        Returns the list of hops, starting with ``first_import``. Returns
        ``None`` when the chain revisits a ``(module, name)`` pair (circular
        re-exports) or exceeds ``_MAX_CHAIN_LENGTH`` hops.
        """
        chain = [first_import]
        seen = {(first_import.module, first_import.name)}
        current = first_import
        for _ in range(self._MAX_CHAIN_LENGTH):
            module = self.import_fixer.lookup_module(current.module)
            if module is None:
                return chain
            next_import = module.lookup(current.name)
            if next_import is None:
                return chain
            key = (next_import.module, next_import.name)
            if key in seen:
                # Each hop depends only on (module, name), so a revisit would
                # loop forever; give up on this name.
                return None
            seen.add(key)
            chain.append(next_import)
            current = next_import
        return None

    def _canonical_target(self, first_import):
        """Pick the import that the name bound by ``first_import`` should use.

        The deepest hop accepted by ``import_fixer.is_allowed`` wins. Returns
        ``None`` when the name cannot be followed or no hop qualifies. Hops
        that lead back into this module are never chosen, since they would
        turn into a self-import.
        """
        chain = self._follow_chain(first_import)
        if chain is None:
            return None
        for depth in range(len(chain) - 1, 0, -1):
            candidate = chain[depth]
            if self._is_self(candidate.module):
                continue
            if self.import_fixer.is_allowed(self.module, chain[: depth + 1]):
                return candidate
        return None

    def _replacement_for(self, stmt, new_imports):
        """Build the ``(start, end, text)`` edit that replaces ``stmt``.

        Names in ``new_imports`` (bound name -> target ``Import``) are imported
        from their new modules, with one ``ImportFrom`` per module. All other
        names stay in a copy of the original statement. Continuation lines are
        indented to the statement's column.
        """
        replacement = ast.Module(body=[], type_ignores=[])

        by_module = defaultdict(list)
        for bound_name, target in new_imports.items():
            by_module[target.module].append((bound_name, target))

        for module_name, targets in by_module.items():
            location = (
                self.relative_to_me(module_name)
                if stmt.level != 0
                else {"level": 0, "module": module_name}
            )
            replacement.body.append(
                ast.ImportFrom(
                    **location,
                    names=[
                        ast.alias(
                            name=target.name,
                            asname=bound_name if bound_name != target.name else None,
                        )
                        for bound_name, target in targets
                    ],
                )
            )

        remaining = copy.deepcopy(stmt)
        remaining.names = [
            alias
            for alias in stmt.names
            if (alias.asname or alias.name) not in new_imports
        ]
        if remaining.names:
            replacement.body.append(remaining)

        indent = " " * stmt.first_token.start[1]
        return (
            stmt.first_token.startpos,
            stmt.last_token.endpos,
            unparse(replacement).replace("\n", "\n" + indent),
        )

    def change_set(self):
        """Compute the edits that point this module's imports at their origin.

        Every ``from ... import ...`` statement in the file is inspected, at any
        nesting depth. Returns a ``ChangeSet``, which is empty when the file
        could not be parsed or nothing needs rewriting.
        """
        if self.atok is None:
            return ChangeSet(self.path, [])

        changes = []
        for stmt in ast.walk(self.tree):
            if not isinstance(stmt, ast.ImportFrom):
                continue
            new_imports = {}  # bound name -> Import it should come from
            for first_import in Import.from_ast(stmt, self.module):
                target = self._canonical_target(first_import)
                if target is not None:
                    new_imports[first_import.asname] = target
            if new_imports:
                changes.append(self._replacement_for(stmt, new_imports))

        return ChangeSet(self.path, changes)

232 

233 

class ChangeSet:
    """Pending text edits for one file, as ``(start, end, new_text)`` tuples.

    The edits are applied to the file's text as returned by
    ``Path.read_text()``.
    """

    def __init__(self, path, changes):
        self.path = path
        self.changes = changes

    def __bool__(self):
        """A change set is truthy when it holds at least one edit."""
        return bool(self.changes)

    def _old_and_new_code(self):
        """Return ``(old_code, new_code)`` with all edits applied to the file text."""
        old_code = self.path.read_text()
        return old_code, asttokens.util.replace(old_code, self.changes)

    @staticmethod
    def _diff_lines(old_code, new_code):
        """Return an iterator of unified-diff lines, none of them newline-terminated."""
        # splitlines() strips the newlines, so lineterm="" is needed to keep the
        # "---", "+++" and "@@" control lines newline-free like the content.
        return difflib.unified_diff(
            old_code.splitlines(),
            new_code.splitlines(),
            lineterm="",
        )

    def preview(self):
        """Return the pending edits as unified-diff lines (an iterator)."""
        old_code, new_code = self._old_and_new_code()
        return self._diff_lines(old_code, new_code)

    def fix(self):
        """Apply the edits and write the result back to ``self.path``."""
        _, new_code = self._old_and_new_code()
        self.path.write_text(new_code)

255 

256 

class ImportFixer:
    """Registry of known modules plus the policy for allowed rewrites.

    ``flags`` holds the rewrite kinds excluded by the user (the ``--no``
    choices ``"public-private"`` and ``"into-init"``).
    """

    def __init__(self, flags):
        self.module_cache = {}  # file path -> Module
        self.package_cache = {}  # top-level name -> extension-less lookup base path
        self.flags = flags

    def is_allowed(self, module, import_chain):
        """Return True if ``module`` may import from the last hop of ``import_chain``.

        Each enabled flag adds one rule:
        - ``"public-private"``: a public module must not import a private name,
          or a name from a private module.
        - ``"into-init"``: a chain that leaves a package ``__init__`` and then
          only goes into that package's own submodules is rejected.
        """
        if "public-private" in self.flags:
            if not is_module_private(module) and import_chain[-1].is_private():
                return False

        if "into-init" in self.flags:
            for i, init in enumerate(import_chain[:-1]):
                if self.is_init(init.module) and all(
                    imp.is_inside(init) for imp in import_chain[i + 1 :]
                ):
                    return False
        return True

    def register(self, module, package_folder):
        """Record ``module`` in the module cache and its top-level name in the package cache.

        ``package_folder`` is the top-level package directory for a module
        inside a package, or the containing directory for a plain module.
        """
        top_level = module.module_parts[0]
        if len(module.module_parts) == 1:
            # A plain module outside any package. Store the extension-less path
            # of the module itself, so ``lookup_module`` finds it through
            # ``with_suffix(".py")`` exactly like a package's submodules.
            # Storing the folder would probe "<folder>.py" instead, or raise
            # ValueError for ".".
            package_folder = package_folder / top_level
        self.package_cache[top_level] = package_folder
        self.module_cache[module.path] = module

    def lookup_file(self, filename) -> "Module":
        """Return the cached ``Module`` for ``filename``, parsing it on first use."""
        if filename not in self.module_cache:
            # Module.__init__ registers the new instance in ``module_cache``.
            return Module(filename, self)
        return self.module_cache[filename]

    def lookup_module(self, module: str) -> "Optional[Module]":
        """Find the ``Module`` for the dotted name ``module``.

        Returns ``None`` when the top-level name was never registered or no
        matching file exists. A package (``<name>/__init__.py``) takes
        precedence over a same-named ``<name>.py``, as in Python's import system.
        """
        top_level, *rest = module.split(".")
        if top_level not in self.package_cache:
            return None

        base = self.package_cache[top_level].joinpath(*rest)
        for candidate in (base / "__init__.py", base.with_suffix(".py")):
            if candidate.exists():
                return self.lookup_file(candidate)
        return None

    def is_init(self, module_name: str):
        """Return True if ``module_name`` resolves to a package ``__init__.py``."""
        module = self.lookup_module(module_name)
        return module is not None and module.is_init()

304 

305 

@click.command()
@click.option(
    "--no",
    help="Exclude specific imports",
    multiple=True,
    type=click.Choice(["public-private", "into-init"]),
)
@click.option("--write", "-w", is_flag=True, help="write changed imports")
@click.argument("paths", nargs=-1, type=click.Path(exists=True))
def main(no, paths, write):
    """`canonical-imports` follows your imports and finds out where the things
    you are importing are actually defined.

    PATHS: python files or directories which should be scanned for python files
    """
    import_fixer = ImportFixer(set(no))

    files = []
    for raw_path in paths:
        path = Path(raw_path)
        if path.is_dir():
            files.extend(path.rglob("*.py"))
        else:
            files.append(path)
    # A file reachable through several PATHS (e.g. "src" and "src/pkg/a.py")
    # must be processed once. With --write, a second pass would re-apply
    # offsets computed from the original text to the already-rewritten file.
    files = list(dict.fromkeys(files))

    source_files = [Module(file, import_fixer) for file in files]

    if write:
        for file in source_files:
            change_set = file.change_set()
            if change_set:
                print(f"fix: {file.path}")
                change_set.fix()
    else:
        text = []
        for file in source_files:
            change_set = file.change_set()
            if change_set:
                text.append(f"## {file.path}")
                text.append("```diff")
                # Skip the "---"/"+++" file header lines of the unified diff.
                text += list(change_set.preview())[2:]
                text.append("```")
                text.append("---")

        rprint(Markdown("\n".join(text)))