Coverage for src/pycse/colab.py: 13.17%

319 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-23 16:23 -0400

1"""Module for use in Google Colab.""" 

2 

3from datetime import datetime 

4import glob 

5import io 

6import os 

7import shlex 

8import shutil 

9import subprocess 

10import tempfile 

11from urllib.parse import urlparse 

12from socket import gethostname, gethostbyname 

13 

14from IPython import display 

15from IPython.core.magic import register_line_magic 

16from IPython.display import HTML, IFrame 

17from nbconvert import HTMLExporter, PDFExporter 

18import nbformat 

19 

20import requests 

21 

22try: 

23 from google.colab import drive 

24 from google.colab import files 

25 from googleapiclient.http import MediaIoBaseDownload 

26 

27 from google.colab import auth 

28 from googleapiclient.discovery import build 

29except ModuleNotFoundError: 

30 pass 

31 

# Module-level cache for the authenticated Drive v3 service (see gdrive()).
DRIVE = None


def gdrive():
    """Return the Google Drive v3 service, authenticating on first use.

    The service is built once and stored in the module-level DRIVE global;
    later calls return that cached object without re-authenticating.
    """
    global DRIVE
    if DRIVE is not None:
        return DRIVE

    auth.authenticate_user()
    DRIVE = build("drive", "v3")
    return DRIVE

42 

43 

44################################################################## 

45# Utilities 

46################################################################## 

47 

48 

def aptupdate():
    """Refresh the apt package index with `apt-get update`.

    Raises:
        Exception: if apt-get exits with a non-zero status; the message
            includes the captured stdout and stderr.
    """
    result = subprocess.run(["apt-get", "update"], capture_output=True)
    if result.returncode == 0:
        return
    out, err = result.stdout.decode(), result.stderr.decode()
    raise Exception(f"apt-get update failed.\n{out}\n{err}")

54 

55 

def aptinstall(apt_pkg):
    """Install an apt package and check for success.

    apt_pkg is the package name passed to `apt-get install`.

    The install is run non-interactively because a notebook has no terminal
    to answer prompts: `-y` answers apt's confirmation question (asked, for
    example, when extra dependencies are pulled in) and
    DEBIAN_FRONTEND=noninteractive stops debconf from asking package
    configuration questions.

    Raises:
        Exception: if apt-get exits with a non-zero status; the message
            includes the captured stdout and stderr.
    """
    print(f"Installing {apt_pkg}. Please be patient.")
    env = dict(os.environ, DEBIAN_FRONTEND="noninteractive")
    s = subprocess.run(
        ["apt-get", "install", "-y", apt_pkg],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=env,
    )
    if s.returncode != 0:
        raise Exception(f"{apt_pkg} installation failed.\n{s.stdout.decode()}\n{s.stderr.decode()}")

66 

67 

68# @register_line_magic 

69# def gimport(fid_or_url): 

70# '''Load the python code at fid or url. 

71# Also a line magic.''' 

72# with gopen(fid_or_url) as f: 

73# py = f.read() 

74# g = globals() 

75# exec(py, g) 

76 

77################################################################## 

78# Exporting functions 

79################################################################## 

80 

81 

def current_notebook():
    """Return (name, file id) of the notebook in the current Colab session.

    Queries the Jupyter sessions API served on port 9000 of this host and
    uses the first session in the response. The file id is the piece of the
    session's "path" between the first and second "=".
    """
    host_ip = gethostbyname(gethostname())
    sessions = requests.get(f"http://{host_ip}:9000/api/sessions").json()
    session = sessions[0]
    file_id = session["path"].split("=")[1]
    return session["name"], file_id

90 

91 

def notebook_string(fid):
    """Return the raw contents of the notebook stored in Drive file FID.

    The whole file is downloaded into memory through the Drive API and
    returned as bytes (the notebook JSON, undecoded).
    """
    media_request = gdrive().files().get_media(fileId=fid)
    buffer = io.BytesIO()
    downloader = MediaIoBaseDownload(buffer, media_request)

    # next_chunk() yields (progress, done); progress is not reported because
    # notebooks are small.
    finished = False
    while not finished:
        _, finished = downloader.next_chunk()

    return buffer.getvalue()

108 

109 

def pdf_from_html(pdf=None, verbose=False, plotly=False, javascript_delay=10000):
    """Export the current notebook as a PDF via an HTML export and wkhtmltopdf.

    pdf is the file name of the PDF to export. If None it is derived from the
    notebook name (e.g. "nb.ipynb" -> "nb.pdf").
    verbose prints progress and diagnostics, including the converter output
    when it exits with a non-zero status.
    plotly uses the plotlyhtmlexporter HTML exporter (pip-installed on first
    use) instead of nbconvert's HTMLExporter.
    javascript_delay is in ms, and is how long wkhtmltopdf waits to let
    javascript, especially MathJax, finish.

    The PDF is not saved in GDrive. It is written to a new temporary directory
    and handed to google.colab's files.download. xvfb and wkhtmltopdf are
    installed with apt if they are not already on PATH.
    """
    import sys

    if verbose:
        print("PDF via wkhtmltopdf")

    fname, fid = current_notebook()
    ipynb = notebook_string(fid)

    if plotly:
        try:
            from plotlyhtmlexporter import PlotlyHTMLExporter
        except ImportError:
            # Install into the interpreter running this code, not whichever
            # `pip` happens to be first on PATH.
            subprocess.run([sys.executable, "-m", "pip", "install", "plotlyhtmlexporter"])
            from plotlyhtmlexporter import PlotlyHTMLExporter

        exporter = PlotlyHTMLExporter()
    else:
        exporter = HTMLExporter()

    nb = nbformat.reads(ipynb, as_version=4)
    body, resources = exporter.from_notebook_node(nb)

    if verbose:
        print(f"args: pdf={pdf}, verbose={verbose}")

    # splitext (rather than str.replace) always gives distinct .html/.pdf
    # names, even when the notebook or PDF name lacks the expected extension.
    if pdf is None:
        stem = os.path.splitext(fname)[0]
        html = stem + ".html"
        pdf = stem + ".pdf"
    else:
        html = os.path.splitext(pdf)[0] + ".html"

    if verbose:
        print(f"using html = {html}")

    # mkdtemp creates the directory and leaves it in place. The previous
    # TemporaryDirectory().name idiom relied on the unreferenced object
    # deleting the directory and then re-creating it by hand.
    tmpdirname = tempfile.mkdtemp()

    ahtml = os.path.join(tmpdirname, html)
    apdf = os.path.join(tmpdirname, pdf)
    css = os.path.join(tmpdirname, "custom.css")

    with open(ahtml, "w", encoding="utf-8") as f:
        f.write(body)

    with open(css, "w", encoding="utf-8") as f:
        f.write("\n".join(resources["inlining"]["css"]))

    # Only touch apt when a converter is missing; refresh the package index
    # first since installing from a stale index may fail.
    needed = [
        pkg
        for tool, pkg in (("xvfb-run", "xvfb"), ("wkhtmltopdf", "wkhtmltopdf"))
        if not shutil.which(tool)
    ]
    if needed:
        aptupdate()
        for pkg in needed:
            aptinstall(pkg)

    if verbose:
        print(f"Running with delay: {javascript_delay}")

    # xvfb-run supplies a virtual X display for wkhtmltopdf.
    s = subprocess.run(
        [
            "xvfb-run",
            "wkhtmltopdf",
            "--enable-javascript",
            "--no-stop-slow-scripts",
            "--javascript-delay",
            str(javascript_delay),
            ahtml,
            apdf,
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    if verbose and s.returncode != 0:
        print(
            f"Conversion exited with non-zero status: {s.returncode}.\n"
            f"{s.stdout.decode()}\n"
            f"{s.stderr.decode()}"
        )

    if os.path.exists(apdf):
        files.download(apdf)
    else:
        print("no pdf found.")
        print(ahtml)
        print(apdf)

202 

203 

def pdf_from_latex(pdf=None, verbose=False):
    """Export the current notebook to PDF via LaTeX (nbconvert's PDFExporter).

    pdf is the file name of the PDF. If None it is derived from the notebook
    name (e.g. "nb.ipynb" -> "nb.pdf").
    verbose is not used right now.

    This is not fast: texlive-xetex is installed with apt when xelatex is not
    on PATH. The PDF is written to a new temporary directory and handed to
    google.colab's files.download.
    """
    print("PDF via LaTeX")
    if not shutil.which("xelatex"):
        # Refresh the package index first, as pdf_from_html does; installing
        # from a stale index may fail.
        aptupdate()
        aptinstall("texlive-xetex")

    fname, fid = current_notebook()
    ipynb = notebook_string(fid)

    exporter = PDFExporter()

    nb = nbformat.reads(ipynb, as_version=4)
    body, resources = exporter.from_notebook_node(nb)

    if pdf is None:
        # splitext also handles notebook names without a .ipynb extension.
        pdf = os.path.splitext(fname)[0] + ".pdf"

    # mkdtemp creates a fresh directory that stays in place, instead of the
    # TemporaryDirectory().name delete-then-recreate idiom.
    tmpdirname = tempfile.mkdtemp()
    apdf = os.path.join(tmpdirname, pdf)

    # "wb" truncates any existing file, so no separate unlink is needed.
    with open(apdf, "wb") as f:
        f.write(body)

    if os.path.exists(apdf):
        files.download(apdf)
    else:
        print(f"{apdf} not found")

242 

243 

def pdf(line=""):
    """Line magic to export a colab to PDF.

    Usage: %pdf [-l] [-d SECONDS] [-p] [-v] [name.pdf]

    -l  use LaTeX (pdf_from_latex); the default is html->PDF (pdf_from_html).
        LaTeX takes longer to install, and may not work if you use
        non-standard LaTeX code. I do not know how to add custom LaTeX
        packages to use arbitrary commands.
    -d  delay in seconds for the html to pdf conversion (default 10;
        fractional values are accepted). This is helpful when some equations
        are not rendering with html->PDF. The rendering is done by MathJax,
        and notebooks with a lot of equations take longer to render.
    -p  use the plotly HTML exporter.
    -v  print verbose diagnostics.

    You can have an optional last argument ending in .pdf for the filename of
    the pdf to save to.

    Known limitations:
    1. If your notebook name doesn't end with .ipynb this may not work.

    Raises:
        ValueError: if -d is not followed by a number.
    """
    args = shlex.split(line)

    pdf_name = args[-1] if args and args[-1].endswith(".pdf") else None
    verbose = "-v" in args

    if verbose:
        print(f"%pdf args = {args}")

    if "-l" in args:
        pdf_from_latex(pdf_name, verbose)
        return

    # wkhtmltopdf's --javascript-delay is in milliseconds.
    delay = 10000
    if "-d" in args:
        i = args.index("-d")
        try:
            seconds = float(args[i + 1])
        except (IndexError, ValueError):
            raise ValueError("-d must be followed by a delay in seconds, e.g. %pdf -d 20") from None
        delay = int(seconds * 1000)

    plotly = "-p" in args
    pdf_from_html(pdf_name, verbose, plotly, delay)

288 

289 

# Register pdf as the %pdf line magic when running under IPython.
# Registering is an error when there is no IPython shell (e.g. in CI or a
# plain Python import), and the module must still import there, so that
# failure is ignored. `except Exception` (not a bare except) lets
# KeyboardInterrupt and SystemExit propagate.
try:
    pdf = register_line_magic(pdf)
except Exception:
    pass

296 

297 

298################################################################## 

299# File utilities 

300################################################################## 

301 

302 

def fid_from_url(url):
    """Return the file ID for a file on GDrive from its URL.

    Recognised forms:
      https://drive.google.com/file/d/<ID>/view?usp=sharing    (sharing link)
      https://drive.google.com/uc?id=<ID>                       (download link)
      https://drive.google.com/open?id=<ID>                     (open link)
      https://colab.research.google.com/drive/<ID>#scrollTo=...  (Colab)
      https://docs.google.com/document/d/<ID>/edit              (Docs)
      https://docs.google.com/spreadsheets/d/<ID>/edit#gid=0    (Sheets)
      https://docs.google.com/presentation/d/<ID>/edit          (Slides)
      https://drive.google.com/drive/u/0/folders/<ID>           (folder)

    Raises:
        Exception: if the URL is not one of the recognised forms or has no ID.
    """
    from urllib.parse import parse_qs

    u = urlparse(url)
    # Non-empty path segments, so trailing slashes and extra suffixes after
    # the ID (e.g. /edit, /view) do not leak into the result.
    parts = [p for p in u.path.split("/") if p]

    def segment_after(marker):
        """Return the path segment following MARKER, or None."""
        if marker in parts:
            i = parts.index(marker)
            if i + 1 < len(parts):
                return parts[i + 1]
        return None

    fid = None
    if u.netloc == "drive.google.com":
        if u.path.startswith("/file/d/"):
            fid = segment_after("d")
        elif u.path in ("/uc", "/open"):
            # The query may hold other parameters, e.g. export=download&id=...
            ids = parse_qs(u.query).get("id")
            fid = ids[0] if ids else None
        elif "folders" in parts:
            fid = segment_after("folders")
    elif u.netloc == "colab.research.google.com":
        fid = segment_after("drive")
    elif u.netloc == "docs.google.com" and parts and parts[0] in ("document", "spreadsheets", "presentation"):
        fid = segment_after("d")

    if not fid:
        raise Exception(f"Cannot parse {url} yet.")
    return fid

354 

355 

def gopen(fid_or_url_or_path, mode="r"):
    """Open a file on GDrive by its ID, sharing link or path.

    fid_or_url_or_path is one of:
      - a URL (anything starting with "http"), parsed with fid_from_url;
      - a path to an existing file, whose Drive ID is read with get_id (so it
        should be a file on the Drive mount);
      - otherwise, a Drive file ID used as-is.

    Returns a file-like object you can read from. The whole file is read into
    memory, so this may not be good for large files. For mode "r" the result
    is an io.TextIOWrapper decoding UTF-8; for mode "rb" it is an io.BytesIO.

    Raises:
        Exception: if mode is not "r" or "rb".
    """
    if mode not in ("r", "rb"):
        raise Exception('mode must be "r" or "rb"')

    if fid_or_url_or_path.startswith("http"):
        fid = fid_from_url(fid_or_url_or_path)
    elif os.path.isfile(fid_or_url_or_path):
        fid = get_id(fid_or_url_or_path)
    else:
        # Neither a URL nor an existing path: assume it is already a file ID.
        fid = fid_or_url_or_path

    drive_service = gdrive()
    request = drive_service.files().get_media(fileId=fid)
    downloaded = io.BytesIO()
    downloader = MediaIoBaseDownload(downloaded, request)
    done = False
    while not done:
        # next_chunk() returns (progress, done); progress is not reported.
        _, done = downloader.next_chunk()

    downloaded.seek(0)
    if mode == "r":
        # Decode explicitly as UTF-8 rather than the platform locale default.
        return io.TextIOWrapper(downloaded, encoding="utf-8")
    return downloaded

394 

395 

396# Path utilities 

397# This is tricky, paths are not deterministic in GDrive the way we are used to. 

398# There are also some differences between My Drive, Shared drives, and files

399# shared with you. 

400 

401 

def get_path(fid_or_url):
    """Return the path under the /gdrive mount for a Drive file ID or URL.

    The path is built by walking the file's parents up to the drive root with
    the Drive API: files in My Drive map to "/gdrive/My Drive/...", files in a
    shared drive to "/gdrive/Shared drives/<drive name>/...". Drive is mounted
    at /gdrive if that directory does not exist yet.

    Drive names do not always include the file extension, so if the built
    path does not exist, files whose path starts with it are checked and the
    one whose Drive ID matches is returned. Returns None if nothing matches.
    """
    fid = fid_from_url(fid_or_url) if fid_or_url.startswith("http") else fid_or_url

    drive_service = gdrive()
    x = (
        drive_service.files()
        .get(fileId=fid, supportsAllDrives=True, fields="parents,name")
        .execute()
    )

    dirs = [x["name"]]  # start with the document name, then add each parent

    while x.get("parents", None):
        if len(x["parents"]) > 1:
            print(f"Warning, multiple parents found {x['parents']}")

        x = (
            drive_service.files()
            .get(
                fileId=x["parents"][0],
                supportsAllDrives=True,
                fields="id,parents,name",
            )
            .execute()
        )

        if ("parents" not in x) and x["name"] == "Drive":
            # A parentless folder named "Drive" appears to be a shared drive's
            # root; its id is matched against the shared drive list to get
            # the drive's display name. pageSize=100 (the API maximum) instead
            # of the default page of 10 drives.
            drives = drive_service.drives().list(pageSize=100).execute().get("drives", [])
            for drv in drives:
                if drv["id"] == x["id"]:
                    dirs += [drv["name"], "Shared drives"]
        else:
            dirs += [x["name"]]

    if not os.path.isdir("/gdrive"):
        drive.mount("/gdrive")

    dirs += ["/gdrive"]
    dirs.reverse()
    p = os.path.sep.join(dirs)

    if os.path.exists(p):
        return p

    # The Drive name may lack the extension; look for files starting with p.
    # glob.escape makes "[", "*" and "?" in Drive names match literally.
    for f in glob.glob(glob.escape(p) + "*"):
        if get_id(f) == fid:
            return f
    return None

461 

462 

def get_id(path):
    """Given a path on the /gdrive mount, return its Google Drive ID.

    Files: the ID is read with `xattr -p user.drive.id` (xattr is installed
    with apt if missing). The command output is returned as-is, so if xattr
    fails, its error text is what comes back.

    Directories: the path is resolved one component at a time with the Drive
    API. It must look like "/gdrive/My Drive/..." or
    "/gdrive/Shared drives/<drive name>/...". The top of My Drive is returned
    as the Drive API alias "root"; the top of a shared drive as that shared
    drive's ID.

    Raises:
        Exception: if PATH is neither a file nor a directory, is not under
            My Drive or a known shared drive, or a path component is not found.
    """
    if not shutil.which("xattr"):
        aptinstall("xattr")

    path = os.path.abspath(path)

    if os.path.isfile(path):
        # shlex.quote keeps names containing spaces or quotes intact.
        return subprocess.getoutput(f"xattr -p 'user.drive.id' {shlex.quote(path)}")

    if not os.path.isdir(path):
        raise Exception(f"{path} does not seem to be a file or directory")

    drive_service = gdrive()

    # Drop the empty leading component and the mount point (e.g. "gdrive").
    parts = path.split("/")[2:]

    if parts[:1] == ["My Drive"]:
        drive_id = "root"
        names = parts[1:]
    elif parts[:1] == ["Shared drives"] and len(parts) > 1:
        # pageSize=100 (the API maximum) instead of the default page of 10.
        drives = drive_service.drives().list(pageSize=100).execute().get("drives", [])
        drive_id = next((drv["id"] for drv in drives if drv["name"] == parts[1]), None)
        if drive_id is None:
            raise Exception(f"No shared drive named {parts[1]!r} was found for {path}")
        names = parts[2:]
    else:
        raise Exception(f"{path} is not under My Drive or Shared drives")

    folder_id = drive_id
    for name in names:
        # Drive query strings escape backslashes and single quotes with "\".
        quoted = name.replace("\\", "\\\\").replace("'", "\\'")
        args = dict(q=f"'{folder_id}' in parents and name = '{quoted}' and trashed = false")
        if drive_id != "root":
            args.update(
                corpora="drive",
                supportsAllDrives=True,
                includeItemsFromAllDrives=True,
                driveId=drive_id,
            )

        matches = drive_service.files().list(**args).execute().get("files", [])
        if not matches:
            # Stop at the first missing component instead of silently carrying
            # on from the previous folder.
            raise Exception(f"Something went wrong with {path}: {name!r} not found")
        folder_id = matches[0]["id"]

    return folder_id

523 

524 

def get_link(path):
    """Return a clickable HTML link to the Drive web view of PATH.

    The link text is PATH itself. If Drive reports no webViewLink, the href is
    the placeholder text "No web link found".
    """
    from html import escape

    fid = get_id(os.path.abspath(path))
    drive_service = gdrive()
    x = (
        drive_service.files()
        .get(fileId=fid, supportsAllDrives=True, fields="webViewLink")
        .execute()
    )
    url = x.get("webViewLink", "No web link found")
    # Quote the attribute and escape both parts so "&" in URLs and "<", "&"
    # in file names render correctly instead of breaking the markup.
    return HTML(f'<a href="{escape(url)}" target="_blank">{escape(str(path))}</a>')

536 

537 

def gchdir(path=None):
    """Change the working directory to PATH.

    When PATH is None, use the directory that contains the current notebook on
    the Drive mount (located with current_notebook and get_path). Relative
    paths are resolved against the current working directory.
    """
    if path is None:
        _, notebook_fid = current_notebook()
        path = os.path.dirname(get_path(notebook_fid))

    target = path if os.path.isabs(path) else os.path.abspath(path)
    os.chdir(target)

550 

551 

def gdownload(*FILES, **kwargs):
    """Download files to your computer. Each arg can be a path or glob pattern.

    If exactly one regular file matches, it is downloaded directly. Otherwise
    all matches are zipped (directories recursively) and the zip is
    downloaded. You can specify a zip file name as a kwarg:

        gdownload("*", zip="test.zip")

    Without zip=, the archive is named from the current local time, e.g.
    "10-23-2025T16-23-00.zip". The zip file is left in the working directory;
    a keep= kwarg is accepted but currently has no effect.

    Raises:
        FileNotFoundError: if no file matches any of the patterns.
    """
    fd = [match for pattern in FILES for match in glob.glob(pattern)]

    if not fd:
        # Fail early with a clear message instead of zipping nothing and then
        # trying to download a zip that was never created.
        raise FileNotFoundError(f"No files match {FILES!r}")

    if len(fd) == 1 and os.path.isfile(fd[0]):
        files.download(fd[0])
        return

    if "zip" in kwargs:
        zipfile = kwargs["zip"]
    else:
        zipfile = datetime.now().strftime("%m-%d-%YT%H-%M-%S.zip")

    # Start from a fresh archive; zip would otherwise update an existing one.
    if os.path.exists(zipfile):
        os.unlink(zipfile)

    s = subprocess.run(
        ["zip", "-r", zipfile, *fd],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if s.returncode != 0:
        print(f"zip did not fully succeed:\n{s.stdout.decode()}\n{s.stderr.decode()}\n")
    files.download(zipfile)

588 

589 

590# if not kwargs.get('keep', False): 

591# os.unlink(zip) 

592 

593 

594################################################################## 

595# Get to a shell 

596################################################################## 

def gconsole():
    """Open a shell in colab using teleconsole.

    Downloads and extracts the teleconsole v0.4.0 linux binary, moves it to
    /usr/local/bin, appends a prompt, `cd /content` and a PATH setting to
    /root/.bashrc, starts teleconsole, and returns an IFrame for the session
    URL it prints.

    Adapted from
    https://github.com/airesearch-in-th/kora/blob/master/kora/console.py

    Raises:
        RuntimeError: if teleconsole's startup output does not contain a URL.
    """
    url = (
        "https://github.com/gravitational/teleconsole/releases/download"
        "/0.4.0/teleconsole-v0.4.0-linux-amd64.tar.gz"
    )
    os.system(f"curl -L {url} | tar xz")  # download & extract
    os.system("mv teleconsole /usr/local/bin/")  # in PATH

    # Set PS1, directory
    with open("/root/.bashrc", "a") as f:
        f.write(r'PS1="\e[1;36m\w\e[m# "\n')
        f.write("cd /content \n")
        f.write(
            "PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:"
            "/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/tools/node/bin:"
            "/tools/google-cloud-sdk/bin:/opt/bin \n"
        )

    process = subprocess.Popen(
        "teleconsole",
        shell=True,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    # The session URL is expected as the last word of the sixth output line.
    lines = [process.stdout.readline() for _ in range(6)]
    words = lines[-1].decode().split()
    if not words:
        output = b"".join(lines).decode(errors="replace")
        raise RuntimeError(f"Could not find the teleconsole URL in its output:\n{output}")

    url = words[-1]
    print("Console URL:", url)
    return IFrame(url, width=800, height=600)

633 

634 

635################################################################## 

636# Fancy outputs 

637################################################################## 

638 

639 

def gsuite(fid_or_url, width=1200, height=1000):
    """Return an IFrame that renders a Drive/G Suite item in a colab.

    fid_or_url is either a URL (anything starting with "http"), used as-is,
    or a Drive file ID whose webViewLink is looked up with the Drive API.
    A plain link to the item is displayed first. If the page's
    X-Frame-Options header is "deny" or "sameorigin", embedding is not
    allowed: a message is printed and None is returned.
    """
    if fid_or_url.startswith("http"):
        url = fid_or_url
    else:
        # Assume we have an fid; only authenticate with Drive in this case.
        x = (
            gdrive()
            .files()
            .get(fileId=fid_or_url, supportsAllDrives=True, fields="webViewLink")
            .execute()
        )
        url = x.get("webViewLink", "No web link found.")

    # `display` is the IPython.display module (`from IPython import display`),
    # so the callable is display.display; calling the module raises TypeError.
    display.display(HTML(f"""<a href="{url}" target="_blank">Link</a><br>"""))

    g = requests.get(url)
    xframeoptions = g.headers.get("X-Frame-Options", "").lower()
    if xframeoptions in ["deny", "sameorigin"]:
        print(f"X-Frame-Option = {xframeoptions}\nEmbedding in IFrame is not allowed for {url}.")
        return None
    return IFrame(url, width, height)