Coverage for src/pycse/colab.py: 13.17%
319 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 16:23 -0400
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 16:23 -0400
1"""Module for use in Google Colab."""
3from datetime import datetime
4import glob
5import io
6import os
7import shlex
8import shutil
9import subprocess
10import tempfile
11from urllib.parse import urlparse
12from socket import gethostname, gethostbyname
14from IPython import display
15from IPython.core.magic import register_line_magic
16from IPython.display import HTML, IFrame
17from nbconvert import HTMLExporter, PDFExporter
18import nbformat
20import requests
22try:
23 from google.colab import drive
24 from google.colab import files
25 from googleapiclient.http import MediaIoBaseDownload
27 from google.colab import auth
28 from googleapiclient.discovery import build
29except ModuleNotFoundError:
30 pass
# Cached Google Drive v3 service shared by the Drive helpers in this module;
# built lazily by gdrive().
DRIVE = None


def gdrive():
    """Return the Google Drive v3 service, authenticating on first use.

    Once built, the service is stored in the module-level DRIVE and later
    calls return it directly without authenticating again.
    """
    global DRIVE
    if DRIVE is not None:
        return DRIVE
    auth.authenticate_user()
    DRIVE = build("drive", "v3")
    return DRIVE
44##################################################################
45# Utilities
46##################################################################
def aptupdate():
    """Refresh the apt package index by running ``apt-get update``.

    Raises
    ------
    Exception
        If apt-get exits with a non-zero status; the message contains its
        stdout and stderr.
    """
    result = subprocess.run(["apt-get", "update"], capture_output=True)
    if result.returncode == 0:
        return
    out, err = result.stdout.decode(), result.stderr.decode()
    raise Exception(f"apt-get update failed.\n{out}\n{err}")
def aptinstall(apt_pkg):
    """Install an apt package non-interactively and check for success.

    Parameters
    ----------
    apt_pkg : str
        Name of the apt package to install.

    Raises
    ------
    Exception
        If apt-get exits with a non-zero status; the message contains its
        stdout and stderr.
    """
    print(f"Installing {apt_pkg}. Please be patient.")
    s = subprocess.run(
        # -y answers apt-get's "Do you want to continue?" confirmation, which
        # a non-interactive subprocess cannot answer otherwise.
        ["apt-get", "install", "-y", apt_pkg],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        # Keep debconf from trying to prompt during package configuration.
        env={**os.environ, "DEBIAN_FRONTEND": "noninteractive"},
    )
    if s.returncode != 0:
        raise Exception(f"{apt_pkg} installation failed.\n{s.stdout.decode()}\n{s.stderr.decode()}")
68# @register_line_magic
69# def gimport(fid_or_url):
70# '''Load the python code at fid or url.
71# Also a line magic.'''
72# with gopen(fid_or_url) as f:
73# py = f.read()
74# g = globals()
75# exec(py, g)
77##################################################################
78# Exporting functions
79##################################################################
def current_notebook():
    """Return the (name, file id) of the notebook attached to this runtime.

    The first session listed by the Jupyter sessions API on port 9000 of this
    host is used. The file id is the piece of the session path that follows
    the first "=" (up to any later "=").
    """
    host = gethostbyname(gethostname())
    sessions = requests.get(f"http://{host}:9000/api/sessions").json()
    session = sessions[0]
    file_id = session["path"].split("=")[1]
    return session["name"], file_id
def notebook_string(fid):
    """Download the notebook stored in Drive under file id FID.

    Returns the raw file contents (the notebook JSON) as bytes.
    """
    media_request = gdrive().files().get_media(fileId=fid)
    buffer = io.BytesIO()
    downloader = MediaIoBaseDownload(buffer, media_request)
    finished = False
    while not finished:
        # The progress object is ignored; notebooks are small.
        _status, finished = downloader.next_chunk()
    return buffer.getvalue()
def pdf_from_html(pdf=None, verbose=False, plotly=False, javascript_delay=10000):
    """Export the current notebook as a PDF and download it.

    The notebook is exported to HTML with nbconvert, and the HTML is converted
    to PDF with wkhtmltopdf running under xvfb-run (both are installed with
    apt if missing). The pdf is not saved in GDrive; it is written to a fresh
    temporary directory and offered as a download.

    Parameters
    ----------
    pdf : str or None
        Name of the PDF to export. Defaults to the notebook name with its
        extension replaced by ".pdf".
    verbose : bool
        Print progress, and wkhtmltopdf's output when it fails.
    plotly : bool
        Install plotlyhtmlexporter and use its exporter instead of nbconvert's
        HTMLExporter.
    javascript_delay : int
        Time in ms that wkhtmltopdf waits to let javascript, especially
        MathJax, finish.
    """
    import sys

    if verbose:
        print("PDF via wkhtmltopdf")

    fname, fid = current_notebook()
    ipynb = notebook_string(fid)

    if plotly:
        # Install with the kernel's own interpreter so the import below sees it.
        subprocess.run([sys.executable, "-m", "pip", "install", "plotlyhtmlexporter"])
        from plotlyhtmlexporter import PlotlyHTMLExporter

        exporter = PlotlyHTMLExporter()
    else:
        exporter = HTMLExporter()

    nb = nbformat.reads(ipynb, as_version=4)
    body, resources = exporter.from_notebook_node(nb)

    if verbose:
        print(f"args: pdf={pdf}, verbose={verbose}")

    # Build both names from one stem. With str.replace, a name lacking the
    # expected extension produced identical HTML and PDF paths.
    stem = os.path.splitext(fname if pdf is None else pdf)[0]
    html = stem + ".html"
    if pdf is None:
        pdf = stem + ".pdf"

    if verbose:
        print(f"using html = {html}")

    # mkdtemp creates a directory that persists until removed. The previous
    # TemporaryDirectory().name idiom deleted the directory as soon as the
    # unreferenced TemporaryDirectory object was garbage-collected.
    tmpdirname = tempfile.mkdtemp()

    ahtml = os.path.join(tmpdirname, html)
    apdf = os.path.join(tmpdirname, pdf)
    css = os.path.join(tmpdirname, "custom.css")

    with open(ahtml, "w", encoding="utf-8") as f:
        f.write(body)

    with open(css, "w", encoding="utf-8") as f:
        f.write("\n".join(resources["inlining"]["css"]))

    # (executable, apt package providing it)
    tools = [("xvfb-run", "xvfb"), ("wkhtmltopdf", "wkhtmltopdf")]
    missing = [pkg for exe, pkg in tools if not shutil.which(exe)]
    if missing:
        # Only refresh the package index when something must be installed.
        aptupdate()
        for pkg in missing:
            aptinstall(pkg)

    if verbose:
        print(f"Running with delay: {javascript_delay}")

    s = subprocess.run(
        [
            "xvfb-run",
            "wkhtmltopdf",
            "--enable-javascript",
            "--no-stop-slow-scripts",
            "--javascript-delay",
            str(javascript_delay),
            ahtml,
            apdf,
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    if verbose and s.returncode != 0:
        print(
            f"Conversion exited with non-zero status: {s.returncode}.\n"
            f"{s.stdout.decode()}\n"
            f"{s.stderr.decode()}"
        )

    if os.path.exists(apdf):
        files.download(apdf)
    else:
        print("no pdf found.")
        print(ahtml)
        print(apdf)
def pdf_from_latex(pdf=None, verbose=False):
    """Export the current notebook to PDF via LaTeX and download it.

    This is not fast because texlive-xetex is installed with apt when xelatex
    is not already available.

    Parameters
    ----------
    pdf : str or None
        File name of the PDF. Defaults to the notebook name with its extension
        replaced by ".pdf".
    verbose : bool
        Not used right now; accepted because the %pdf magic passes it.
    """
    print("PDF via LaTeX")
    if not shutil.which("xelatex"):
        # Refresh the package index before installing, as pdf_from_html does.
        aptupdate()
        aptinstall("texlive-xetex")

    fname, fid = current_notebook()
    ipynb = notebook_string(fid)

    exporter = PDFExporter()
    nb = nbformat.reads(ipynb, as_version=4)
    body, _ = exporter.from_notebook_node(nb)  # body is written as binary below

    if pdf is None:
        pdf = os.path.splitext(fname)[0] + ".pdf"

    # mkdtemp returns a fresh directory that persists. TemporaryDirectory().name
    # is removed as soon as the unreferenced object is garbage-collected.
    apdf = os.path.join(tempfile.mkdtemp(), pdf)

    with open(apdf, "wb") as f:
        f.write(body)

    files.download(apdf)
def pdf(line=""):
    """Line magic to export a colab to PDF.

    Usage: ``%pdf [-l] [-d SECONDS] [-p] [-v] [name.pdf]``

    You can have an optional arg -l to use LaTeX, defaults to html->PDF. This
    takes longer to install, and may not work if you use non-standard LaTeX
    code. I do not know how to add custom LaTeX packages to use arbitrary
    commands.

    You can have an optional arg -d integer for a delay in seconds for the html
    to pdf conversion (default 10). This is helpful when some equations are not
    rendering with html->PDF. The rendering is done by MathJax, and notebooks
    with a lot of equations take longer to render. -d is ignored with -l.

    -p uses the plotly HTML exporter (html->PDF only); -v prints diagnostics.

    You can have an optional last argument ending in .pdf for the filename of
    the pdf to save to.

    Known limitations:
    1. If your notebook name doesn't end with .ipynb this does not work.

    Raises
    ------
    ValueError
        If -d is not followed by an integer.
    """
    args = shlex.split(line)

    pdf = args[-1] if args and args[-1].endswith(".pdf") else None

    verbose = "-v" in args

    if verbose:
        print(f"%pdf args = {args}")

    if "-l" in args:
        pdf_from_latex(pdf, verbose)
        return

    delay = 10000  # ms, i.e. 10 seconds
    if "-d" in args:
        i = args.index("-d")
        try:
            seconds = int(args[i + 1])
        except (IndexError, ValueError):
            raise ValueError("-d must be followed by an integer number of seconds") from None
        # wkhtmltopdf's --javascript-delay is in milliseconds.
        delay = seconds * 1000

    plotly = "-p" in args
    pdf_from_html(pdf, verbose, plotly, delay)
# This is hackery so that CI works: registering a line magic is an error when
# there is no IPython shell, so registration is best-effort. Exception (not a
# bare except) still lets KeyboardInterrupt/SystemExit propagate.
try:
    pdf = register_line_magic(pdf)
except Exception:  # noqa: BLE001 - deliberately best-effort
    pass
298##################################################################
299# File utilities
300##################################################################
def fid_from_url(url):
    """Return a file ID for a file on GDrive from its url.

    Recognized forms (examples):

    - sharing link: https://drive.google.com/file/d/<ID>/view?usp=sharing
    - download/open link: https://drive.google.com/uc?id=<ID>
      (also https://drive.google.com/open?id=<ID>)
    - folder: https://drive.google.com/drive/u/0/folders/<ID>
    - colab: https://colab.research.google.com/drive/<ID>#scrollTo=...
    - docs: https://docs.google.com/document/d/<ID>/edit?usp=sharing
    - sheets: https://docs.google.com/spreadsheets/d/<ID>/edit#gid=...
    - slides: https://docs.google.com/presentation/d/<ID>/edit#slide=id.p

    Raises
    ------
    Exception
        If the url is not a recognized form or carries no id.
    """
    from urllib.parse import parse_qs

    u = urlparse(url)
    # Non-empty path segments, e.g. "/file/d/ID/view" -> ["file", "d", "ID", "view"].
    parts = [segment for segment in u.path.split("/") if segment]

    if u.netloc == "drive.google.com":
        # Sharing link: /file/d/<ID>/...
        if parts[:2] == ["file", "d"] and len(parts) >= 3:
            return parts[2]

        # Download / open link; the query may carry other &-separated params.
        if u.path in ("/uc", "/open"):
            ids = parse_qs(u.query).get("id")
            if ids:
                return ids[0]

        # Folder: .../folders/<ID>
        if "folders" in parts:
            i = parts.index("folders")
            if i + 1 < len(parts):
                return parts[i + 1]

    elif u.netloc == "colab.research.google.com":
        # Drive-backed colab: /drive/<ID>
        if parts[:1] == ["drive"] and len(parts) >= 2:
            return parts[1]

    elif u.netloc == "docs.google.com":
        # /<kind>/d/<ID>/<anything>
        if (
            len(parts) >= 3
            and parts[0] in ("document", "spreadsheets", "presentation")
            and parts[1] == "d"
        ):
            return parts[2]

    raise Exception(f"Cannot parse {url} yet.")
def gopen(fid_or_url_or_path, mode="r", encoding="utf-8"):
    """Open a file on GDrive by its ID, sharing link or path.

    Returns a file-like object you can read from. Note this reads the whole
    file into memory, so it may not be good for large files.

    Parameters
    ----------
    fid_or_url_or_path : str
        A Drive url (anything starting with "http"), a path to an existing
        file (e.g. on the mounted drive), or otherwise a Drive file id.
    mode : {"r", "rb"}
        "r" returns an io.TextIOWrapper decoding the content with `encoding`;
        "rb" returns an io.BytesIO positioned at the start.
    encoding : str
        Text encoding used in "r" mode (explicit, not the locale default).

    Raises
    ------
    Exception
        If `mode` is not "r" or "rb".
    """
    if mode not in ["r", "rb"]:
        raise Exception('mode must be "r" or "rb"')

    if fid_or_url_or_path.startswith("http"):
        fid = fid_from_url(fid_or_url_or_path)
    elif os.path.isfile(fid_or_url_or_path):
        # A path to a file, e.g. on the mounted drive.
        fid = get_id(fid_or_url_or_path)
    else:
        # Assume it is an fid.
        fid = fid_or_url_or_path

    drive_service = gdrive()
    request = drive_service.files().get_media(fileId=fid)
    downloaded = io.BytesIO()
    downloader = MediaIoBaseDownload(downloaded, request)
    done = False
    while done is False:
        # _ is a progress object that we ignore.
        _, done = downloader.next_chunk()

    downloaded.seek(0)
    if mode == "r":
        return io.TextIOWrapper(downloaded, encoding=encoding)
    return downloaded
396# Path utilities
# This is tricky: paths are not deterministic in GDrive the way we are used to.
# There are also some differences between My Drive, Shared drives, and files
# shared with you.
def get_path(fid_or_url):
    """Return the path to an fid or url under the /gdrive mount point.

    The path is assembled by walking the item's parents with the Drive API.
    A warning is printed when an item has several parents, and the first one
    is followed. A top-level parent with no parents named "Drive" is treated
    as a shared-drive root and mapped to "Shared drives/<drive name>". The
    drive is mounted at /gdrive if that directory does not exist.

    Sometimes the Drive name lacks the file extension. When the assembled path
    does not exist, paths that start with it are tried, and the one whose id
    (from get_id) equals the requested id is returned.

    Raises
    ------
    FileNotFoundError
        If no existing path matches the id.
    """
    fid = fid_from_url(fid_or_url) if fid_or_url.startswith("http") else fid_or_url

    drive_service = gdrive()

    def shared_drive_name(drive_id):
        """Return the name of the shared drive with id DRIVE_ID, or None."""
        # drives().list is paginated, so follow nextPageToken through all pages.
        params = {"pageSize": 100}
        while True:
            response = drive_service.drives().list(**params).execute()
            for drv in response.get("drives", []):
                if drv["id"] == drive_id:
                    return drv["name"]
            token = response.get("nextPageToken")
            if not token:
                return None
            params["pageToken"] = token

    x = (
        drive_service.files()
        .get(fileId=fid, supportsAllDrives=True, fields="parents,name")
        .execute()
    )

    dirs = [x["name"]]  # start with the document name

    while x.get("parents"):
        if len(x["parents"]) > 1:
            print(f"Warning, multiple parents found {x['parents']}")

        x = (
            drive_service.files()
            .get(
                fileId=x["parents"][0],
                supportsAllDrives=True,
                fields="id,parents,name",
            )
            .execute()
        )

        if "parents" not in x and x["name"] == "Drive":
            # This appears to mean the file is in a shared drive.
            name = shared_drive_name(x["id"])
            if name is not None:
                dirs += [name, "Shared drives"]
        else:
            dirs.append(x["name"])

    if not os.path.isdir("/gdrive"):
        drive.mount("/gdrive")

    dirs.append("/gdrive")
    dirs.reverse()
    p = os.path.sep.join(dirs)

    if os.path.exists(p):
        return p

    # Escape p so brackets, '*' or '?' in Drive names are matched literally.
    for candidate in glob.glob(glob.escape(p) + "*"):
        if get_id(candidate) == fid:
            return candidate

    raise FileNotFoundError(f"No path found for Drive id {fid} (looked for {p}*)")
def get_id(path):
    """Given a path, return the Drive id of the file or folder there.

    Files: the id is read from the ``user.drive.id`` extended attribute with
    the ``xattr`` command (installed with apt if missing).

    Directories: the absolute path is expected to look like
    ``/<mount>/My Drive/...`` or ``/<mount>/Shared drives/<drive name>/...``.
    It is resolved one component at a time by listing the non-trashed
    children of each folder with the Drive API. ``My Drive`` itself resolves
    to "root", and a shared drive's top folder resolves to the shared drive's
    id.

    Raises
    ------
    Exception
        If the xattr lookup fails, a path component cannot be found, the
        directory is not under My Drive or Shared drives, or the path is
        neither a file nor a directory.
    """
    drive_service = gdrive()

    if not shutil.which("xattr"):
        aptinstall("xattr")

    path = os.path.abspath(path)

    if os.path.isfile(path):
        # Pass a list (no shell) so quotes or spaces in the path are safe.
        s = subprocess.run(
            ["xattr", "-p", "user.drive.id", path],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        if s.returncode != 0:
            raise Exception(f"Could not read the drive id of {path}.\n{s.stderr.decode()}")
        return s.stdout.decode().strip()

    if not os.path.isdir(path):
        raise Exception(f"{path} does not seem to be a file or directory")

    def iter_items(list_method, key, **params):
        """Yield every entry under KEY across all pages of a Drive list call."""
        while True:
            response = list_method(**params).execute()
            yield from response.get(key, [])
            token = response.get("nextPageToken")
            if not token:
                return
            params["pageToken"] = token

    # Drop the leading "" and the mount directory; assumes the drive is
    # mounted one level below "/" (e.g. /gdrive).
    parts = path.split("/")[2:]

    if parts and parts[0] == "My Drive":
        drive_id = "root"
        remaining = parts[1:]
    elif len(parts) >= 2 and parts[0] == "Shared drives":
        drive_id = next(
            (
                drv["id"]
                for drv in iter_items(drive_service.drives().list, "drives", pageSize=100)
                if drv["name"] == parts[1]
            ),
            None,
        )
        if drive_id is None:
            raise Exception(f"Shared drive {parts[1]!r} not found for {path}")
        # Skip both "Shared drives" and the drive name: the drive id is the
        # id of the drive's top folder.
        remaining = parts[2:]
    else:
        raise Exception(f"{path} is not under My Drive or Shared drives")

    folder_id = drive_id
    for name in remaining:
        params = {"q": f"'{folder_id}' in parents and trashed = false"}
        if drive_id != "root":
            params.update(
                corpora="drive",
                supportsAllDrives=True,
                includeItemsFromAllDrives=True,
                driveId=drive_id,
            )
        folder_id = next(
            (
                item["id"]
                for item in iter_items(drive_service.files().list, "files", **params)
                if item["name"] == name
            ),
            None,
        )
        if folder_id is None:
            raise Exception(f"Something went wrong with {path}: {name!r} not found")

    return folder_id
def get_link(path):
    """Return a clickable HTML link to the Drive web view of PATH.

    The link text is PATH. If Drive reports no webViewLink, the href is the
    text "No web link found". Both values are HTML-escaped and the attributes
    are quoted.
    """
    from html import escape

    fid = get_id(os.path.abspath(path))
    drive_service = gdrive()
    x = (
        drive_service.files()
        .get(fileId=fid, supportsAllDrives=True, fields="webViewLink")
        .execute()
    )
    url = x.get("webViewLink", "No web link found")
    return HTML(f'<a href="{escape(url)}" target="_blank">{escape(str(path))}</a>')
def gchdir(path=None):
    """Make PATH the current working directory.

    When PATH is None, use the directory that contains the current notebook
    on the mounted drive. A relative PATH is made absolute before changing
    into it.
    """
    if path is None:
        notebook_id = current_notebook()[1]
        path = os.path.dirname(get_path(notebook_id))

    target = path if os.path.isabs(path) else os.path.abspath(path)
    os.chdir(target)
def gdownload(*FILES, **kwargs):
    """Download files. Each arg can be a path, or glob pattern.

    If exactly one regular file matches, it is downloaded directly. Otherwise
    (several matches, or a directory) the matches are zipped with ``zip -r``
    and the zip is downloaded. You can specify the zip file name as a kwarg:

        gdownload("*", zip="test.zip")

    Without it, the name is a timestamp such as "10-23-2025T16-23-00.zip". An
    existing file with the zip's name is replaced. The zip file is not
    deleted afterwards; other kwargs (e.g. keep=True) are accepted but unused.

    Raises
    ------
    FileNotFoundError
        If no file matches any of the patterns.
    """
    matches = [match for pattern in FILES for match in glob.glob(pattern)]
    if not matches:
        raise FileNotFoundError(f"No files match {FILES}")

    if len(matches) == 1 and os.path.isfile(matches[0]):
        files.download(matches[0])
        return

    zip_name = kwargs.get("zip") or datetime.now().strftime("%m-%d-%YT%H-%M-%S.zip")

    if os.path.exists(zip_name):
        os.unlink(zip_name)

    s = subprocess.run(
        ["zip", "-r", zip_name, *matches],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if s.returncode != 0:
        print(f"zip did not fully succeed:\n{s.stdout.decode()}\n{s.stderr.decode()}\n")
    files.download(zip_name)
590# if not kwargs.get('keep', False):
591# os.unlink(zip)
594##################################################################
595# Get to a shell
596##################################################################
def gconsole():
    """Open a shell in colab and return an IFrame showing it.

    Downloads and extracts the teleconsole 0.4.0 release, moves the binary to
    /usr/local/bin, and appends prompt, working-directory and PATH settings to
    /root/.bashrc. It then starts teleconsole and uses the last word of its
    sixth output line as the console URL.

    Adapted from
    https://github.com/airesearch-in-th/kora/blob/master/kora/console.py

    Raises
    ------
    RuntimeError
        If teleconsole's sixth output line is empty (e.g. output ended early).
    """
    release_url = (
        "https://github.com/gravitational/teleconsole/releases/download"
        "/0.4.0/teleconsole-v0.4.0-linux-amd64.tar.gz"
    )
    os.system(f"curl -L {release_url} | tar xz")  # download & extract
    os.system("mv teleconsole /usr/local/bin/")  # in PATH

    # Set PS1, directory and PATH.
    with open("/root/.bashrc", "a") as f:
        # The raw string keeps bash's \e and \w escapes. The newline must be a
        # real one: r'...\n' wrote a literal backslash-n, which glued the
        # following "cd /content" onto the PS1 line.
        f.write(r'PS1="\e[1;36m\w\e[m# "' + "\n")
        f.write("cd /content \n")
        f.write(
            "PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:"
            "/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/tools/node/bin:"
            "/tools/google-cloud-sdk/bin:/opt/bin \n"
        )

    process = subprocess.Popen(
        "teleconsole",
        shell=True,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    # The console URL is expected as the last word of the sixth output line.
    line = b""
    for _ in range(6):
        line = process.stdout.readline()

    words = line.decode().split()
    if not words:
        raise RuntimeError("Could not read the console URL from teleconsole's output.")
    console_url = words[-1]
    print("Console URL:", console_url)
    return IFrame(console_url, width=800, height=600)
635##################################################################
636# Fancy outputs
637##################################################################
def gsuite(fid_or_url, width=1200, height=1000, timeout=30):
    """Display a link to a GSuite/Drive item and return an IFrame embedding it.

    Parameters
    ----------
    fid_or_url : str
        A url (anything starting with "http"), or a Drive file id whose
        webViewLink is looked up with the Drive API.
    width, height : int
        Size of the returned IFrame.
    timeout : float
        Seconds to wait when fetching the url to read its X-Frame-Options.

    Returns
    -------
    IFrame or None
        None when no web link is found, or when the site forbids embedding
        (X-Frame-Options of DENY or SAMEORIGIN).
    """
    from html import escape

    if fid_or_url.startswith("http"):
        url = fid_or_url
    else:
        # Assume we have an fid; only this case needs the Drive service.
        drive_service = gdrive()
        x = (
            drive_service.files()
            .get(fileId=fid_or_url, supportsAllDrives=True, fields="webViewLink")
            .execute()
        )
        url = x.get("webViewLink")
        if url is None:
            print("No web link found.")
            return None

    # `display` is the IPython.display module here (from IPython import
    # display), so the display function is reached through it.
    display.display(HTML(f'<a href="{escape(url)}" target="_blank">Link</a><br>'))

    g = requests.get(url, timeout=timeout)
    xframeoptions = g.headers.get("X-Frame-Options", "").strip().lower()
    if xframeoptions in ["deny", "sameorigin"]:
        print(f"X-Frame-Option = {xframeoptions}\nEmbedding in IFrame is not allowed for {url}.")
        return None
    return IFrame(url, width, height)