{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hierarchical Time Series\n",
    "\n",
    "Many real-world datasets are naturally organised as trees: a country's electricity consumption breaks into regions, which break into cities.  \n",
    "`HierarchicalTimeSeries` lets you model that structure directly and **aggregate bottom-up** through the tree.\n",
    "\n",
    "| Class | Purpose |\n",
    "| --- | --- |\n",
    "| `HierarchyNode` | A single node — key, level, children, and an optional `TimeSeriesList` |\n",
    "| `HierarchicalTimeSeries` | The tree container — traversal, aggregation, conversion |\n",
    "| `AggregationMethod` | `SUM`, `MEAN`, `MIN`, `MAX` |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "from datetime import datetime, timedelta, timezone\n\nimport numpy as np\n\nimport timedatamodel as tdm\n\nbase = datetime(2024, 1, 15, tzinfo=timezone.utc)\ntimestamps = [base + timedelta(hours=i) for i in range(24)]\nrng = np.random.default_rng(42)"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create leaf time series\n",
    "\n",
    "Each leaf in the hierarchy holds a `TimeSeriesList`. Here we model electricity consumption for five Norwegian cities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "def make_consumption(name: str, base_mw: float) -> tdm.TimeSeriesList:\n    pattern = base_mw * (1 + 0.3 * np.sin(np.linspace(0, 2 * np.pi, 24)))\n    noise = rng.normal(0, base_mw * 0.05, 24)\n    return tdm.TimeSeriesList(\n        tdm.Frequency.PT1H,\n        timezone=\"Europe/Oslo\",\n        timestamps=timestamps,\n        values=(pattern + noise).tolist(),\n        name=name,\n        unit=\"MW\",\n    )\n\nts_oslo = make_consumption(\"Oslo\", 500)\nts_bergen = make_consumption(\"Bergen\", 200)\nts_stavanger = make_consumption(\"Stavanger\", 150)\nts_tromsoe = make_consumption(\"Tromsø\", 80)\nts_bodoe = make_consumption(\"Bodø\", 50)"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building a hierarchy with HierarchyNode\n",
    "\n",
    "Construct the tree by nesting `HierarchyNode` objects. Leaves hold a `TimeSeriesList`; interior nodes have `children`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "root = tdm.HierarchyNode(\n    key=\"Norway\",\n    level=\"country\",\n    children=[\n        tdm.HierarchyNode(\n            key=\"South\",\n            level=\"region\",\n            children=[\n                tdm.HierarchyNode(key=\"Oslo\", level=\"city\", timeseries=ts_oslo),\n                tdm.HierarchyNode(key=\"Bergen\", level=\"city\", timeseries=ts_bergen),\n                tdm.HierarchyNode(key=\"Stavanger\", level=\"city\", timeseries=ts_stavanger),\n            ],\n        ),\n        tdm.HierarchyNode(\n            key=\"North\",\n            level=\"region\",\n            children=[\n                tdm.HierarchyNode(key=\"Tromsø\", level=\"city\", timeseries=ts_tromsoe),\n                tdm.HierarchyNode(key=\"Bodø\", level=\"city\", timeseries=ts_bodoe),\n            ],\n        ),\n    ],\n)\n\nhierarchy = tdm.HierarchicalTimeSeries(\n    root,\n    name=\"Norway Consumption\",\n    description=\"Hourly electricity consumption by city\",\n    levels=[\"country\", \"region\", \"city\"],\n    aggregation=tdm.AggregationMethod.SUM,\n)\nhierarchy"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inspecting the tree\n",
    "\n",
    "Basic properties tell you the shape of the hierarchy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Name:     {hierarchy.name}\")\n",
    "print(f\"Levels:   {hierarchy.levels}\")\n",
    "print(f\"# levels: {hierarchy.n_levels}\")\n",
    "print(f\"# nodes:  {hierarchy.n_nodes}\")\n",
    "print(f\"# leaves: {hierarchy.n_leaves}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Navigating the tree\n",
    "\n",
    "`get_node(*path)` walks by key. `get_level(name)` returns all nodes at a given level."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "south = hierarchy.get_node(\"South\")\n",
    "print(f\"Node:     {south.key}\")\n",
    "print(f\"Level:    {south.level}\")\n",
    "print(f\"Is leaf:  {south.is_leaf}\")\n",
    "print(f\"Children: {[c.key for c in south.children]}\")\n",
    "print(f\"Leaves:   {south.leaf_count}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "oslo_node = hierarchy.get_node(\"South\", \"Oslo\")\n",
    "print(f\"Node:       {oslo_node.key}\")\n",
    "print(f\"Is leaf:    {oslo_node.is_leaf}\")\n",
    "print(f\"Has series: {oslo_node.timeseries is not None}\")\n",
    "print(f\"Path:       {oslo_node.path}\")\n",
    "print(f\"Depth:      {oslo_node.depth}\")\n",
    "print(f\"Siblings:   {[s.key for s in oslo_node.siblings]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "city_nodes = hierarchy.get_level(\"city\")\n",
    "print(\"City-level nodes:\")\n",
    "for node in city_nodes:\n",
    "    print(f\"  {node.key} — {len(node.timeseries)} data points\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Leaves and walking\n",
    "\n",
    "`leaves()` returns all leaf nodes. `walk()` yields nodes in pre-order (default) or post-order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"All leaves:\")\n",
    "for leaf in hierarchy.leaves():\n",
    "    print(f\"  {leaf.key:12s}  path={leaf.path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Pre-order walk:\")\n",
    "for node in hierarchy.walk(order=\"pre\"):\n",
    "    indent = \"  \" * node.depth\n",
    "    label = f\"{node.key} [{node.level}]\"\n",
    "    if node.is_leaf:\n",
    "        label += f\" — {len(node.timeseries)} pts\"\n",
    "    print(f\"{indent}{label}\")"
   ]
  },
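  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two traversal orders concrete, here is a standalone sketch that walks a tiny tree of plain dicts rather than `HierarchyNode` objects. The dict layout and this `walk` helper are illustrative assumptions, not `timedatamodel` code.\n",
    "\n",
    "```python\n",
    "from typing import Iterator\n",
    "\n",
    "# Hypothetical toy tree: each node is {'key': str, 'children': [nodes]}.\n",
    "tree = {\n",
    "    'key': 'Norway',\n",
    "    'children': [\n",
    "        {'key': 'South', 'children': [\n",
    "            {'key': 'Oslo', 'children': []},\n",
    "            {'key': 'Bergen', 'children': []},\n",
    "        ]},\n",
    "        {'key': 'North', 'children': [\n",
    "            {'key': 'Bodø', 'children': []},\n",
    "        ]},\n",
    "    ],\n",
    "}\n",
    "\n",
    "\n",
    "def walk(node: dict, order: str = 'pre') -> Iterator[str]:\n",
    "    # Yield keys depth-first: parent before children ('pre') or after ('post').\n",
    "    if order == 'pre':\n",
    "        yield node['key']\n",
    "    for child in node['children']:\n",
    "        yield from walk(child, order)\n",
    "    if order == 'post':\n",
    "        yield node['key']\n",
    "\n",
    "\n",
    "pre_order = list(walk(tree, 'pre'))\n",
    "post_order = list(walk(tree, 'post'))\n",
    "print(pre_order)   # ['Norway', 'South', 'Oslo', 'Bergen', 'North', 'Bodø']\n",
    "print(post_order)  # ['Oslo', 'Bergen', 'South', 'Bodø', 'North', 'Norway']\n",
    "```\n",
    "\n",
    "Post-order visits every child before its parent, which is the order a bottom-up computation relies on."
   ]
  },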
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bottom-up aggregation\n",
    "\n",
    "`aggregate()` recursively combines leaf series using the chosen method (default: `SUM`).  \n",
    "Calling it on the root gives the total for the whole hierarchy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total = hierarchy.aggregate()\n",
    "print(f\"Name:   {total.name}\")\n",
    "print(f\"Length: {len(total)} data points\")\n",
    "print(f\"Mean:   {np.nanmean(total.arr):.1f} MW\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "south_total = hierarchy.aggregate(south)\n",
    "print(f\"South region total — mean: {np.nanmean(south_total.arr):.1f} MW\")\n",
    "\n",
    "north = hierarchy.get_node(\"North\")\n",
    "north_total = hierarchy.aggregate(north)\n",
    "print(f\"North region total — mean: {np.nanmean(north_total.arr):.1f} MW\")"
   ]
  },
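  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The recursion itself can be sketched without the library. The example below is a minimal illustration on plain dicts and NumPy arrays; the tree layout, the `aggregate` helper, and its `reduce` parameter are assumptions made for this sketch and say nothing about how `timedatamodel` implements aggregation internally.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy tree: leaves carry 'values', interior nodes carry 'children'.\n",
    "tree = {\n",
    "    'key': 'Norway',\n",
    "    'children': [\n",
    "        {'key': 'South', 'children': [\n",
    "            {'key': 'Oslo', 'values': np.array([500.0, 520.0, 480.0])},\n",
    "            {'key': 'Bergen', 'values': np.array([200.0, 210.0, 190.0])},\n",
    "        ]},\n",
    "        {'key': 'North', 'children': [\n",
    "            {'key': 'Bodø', 'values': np.array([50.0, 55.0, 45.0])},\n",
    "        ]},\n",
    "    ],\n",
    "}\n",
    "\n",
    "\n",
    "def aggregate(node, reduce=np.sum):\n",
    "    # A leaf is its own aggregate.\n",
    "    if 'values' in node:\n",
    "        return node['values']\n",
    "    # Aggregate each child first, then combine them timestamp by timestamp.\n",
    "    stacked = np.stack([aggregate(child, reduce) for child in node['children']])\n",
    "    return reduce(stacked, axis=0)\n",
    "\n",
    "\n",
    "total = aggregate(tree)\n",
    "south = aggregate(tree['children'][0])\n",
    "print(total)  # [750. 785. 715.]\n",
    "print(south)  # [700. 730. 670.]\n",
    "```\n",
    "\n",
    "With `reduce=np.mean`, this sketch computes a mean of child aggregates, which differs from the mean over all leaves when subtrees have different sizes. Check which of the two a given library computes before relying on `MEAN` at higher levels."
   ]
  },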
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Level-wise aggregation\n",
    "\n",
    "`aggregate_level(level)` aggregates every node at the named level, returning a dict."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "region_agg = hierarchy.aggregate_level(\"region\")\n",
    "\n",
    "for name, ts in region_agg.items():\n",
    "    print(f\"{name:8s}  mean={np.nanmean(ts.arr):7.1f} MW  max={np.nanmax(ts.arr):7.1f} MW\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choosing an aggregation method\n",
    "\n",
    "Override the default method by passing a different `AggregationMethod`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "for method in tdm.AggregationMethod:\n    agg = hierarchy.aggregate(method=method)\n    vals = agg.arr\n    print(f\"{method.value:5s}  mean={np.nanmean(vals):7.1f}  min={np.nanmin(vals):7.1f}  max={np.nanmax(vals):7.1f}\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Subtree extraction\n",
    "\n",
    "`subtree(*path)` creates a new `HierarchicalTimeSeries` rooted at the specified node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "south_tree = hierarchy.subtree(\"South\")\n",
    "print(south_tree)\n",
    "print(f\"\\nLevels: {south_tree.levels}\")\n",
    "print(f\"Leaves: {south_tree.n_leaves}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Converting to other containers\n",
    "\n",
    "Flatten the hierarchy into a `TimeSeriesCollection` or `TimeSeriesTable`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection = hierarchy.to_collection()\n",
    "print(f\"Leaf-level collection: {list(collection.keys())}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection_regions = hierarchy.to_collection(level=\"region\")\n",
    "print(\"Region-level collection (aggregated):\")\n",
    "for key, ts in collection_regions.items():\n",
    "    print(f\"  {key}: {len(ts)} pts, mean={np.nanmean(ts.arr):.1f} MW\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "table = hierarchy.to_table()\n",
    "print(f\"Table shape: {len(table)} rows × {table.n_columns} columns\")\n",
    "print(f\"Columns: {table.names}\")\n",
    "table"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building from a DataFrame\n",
    "\n",
    "`from_dataframe` builds the tree from a long-format DataFrame with hierarchy columns.  \n",
    "Each unique combination of level columns becomes a leaf."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "rows = []\n",
    "for ts_dt in timestamps:\n",
    "    for region, cities in [(\"South\", [\"Oslo\", \"Bergen\"]), (\"North\", [\"Tromsø\", \"Bodø\"])]:\n",
    "        for city in cities:\n",
    "            rows.append({\n",
    "                \"timestamp\": ts_dt,\n",
    "                \"region\": region,\n",
    "                \"city\": city,\n",
    "                \"consumption_mw\": float(rng.normal(200, 30)),\n",
    "            })\n",
    "\n",
    "df = pd.DataFrame(rows)\n",
    "print(f\"DataFrame shape: {df.shape}\")\n",
    "df.head(8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "h_from_df = tdm.HierarchicalTimeSeries.from_dataframe(\n    df,\n    level_columns=[\"region\", \"city\"],\n    value_column=\"consumption_mw\",\n    timestamp_column=\"timestamp\",\n    name=\"Consumption from DataFrame\",\n    frequency=tdm.Frequency.PT1H,\n    timezone=\"Europe/Oslo\",\n)\nh_from_df"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_df = h_from_df.aggregate()\n",
    "print(f\"Total from DataFrame hierarchy: mean={np.nanmean(total_df.arr):.1f} MW\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Another example: energy production by source\n",
    "\n",
    "Hierarchies can model any tree-shaped relationship — here, power production broken down by energy source."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "energy_root = tdm.HierarchyNode(\n    key=\"Total\",\n    level=\"total\",\n    children=[\n        tdm.HierarchyNode(\n            key=\"Wind\",\n            level=\"source\",\n            children=[\n                tdm.HierarchyNode(key=\"Farm A\", level=\"farm\", timeseries=make_consumption(\"Farm A\", 100)),\n                tdm.HierarchyNode(key=\"Farm B\", level=\"farm\", timeseries=make_consumption(\"Farm B\", 80)),\n            ],\n        ),\n        tdm.HierarchyNode(\n            key=\"Solar\",\n            level=\"source\",\n            children=[\n                tdm.HierarchyNode(key=\"Plant X\", level=\"farm\", timeseries=make_consumption(\"Plant X\", 60)),\n            ],\n        ),\n    ],\n)\n\nenergy = tdm.HierarchicalTimeSeries(\n    energy_root,\n    name=\"Energy Production\",\n    levels=[\"total\", \"source\", \"farm\"],\n    aggregation=tdm.AggregationMethod.SUM,\n)\nenergy"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "source_agg = energy.aggregate_level(\"source\")\n",
    "for name, ts in source_agg.items():\n",
    "    print(f\"{name:8s}  mean={np.nanmean(ts.arr):.1f} MW\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sequence protocol\n",
    "\n",
    "`HierarchicalTimeSeries` supports `len`, `in`, and bracket indexing with slash-separated paths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Total nodes:        {len(hierarchy)}\")\n",
    "print(f\"'Oslo' in tree:     {'Oslo' in hierarchy}\")\n",
    "print(f\"'Helsinki' in tree: {'Helsinki' in hierarchy}\")\n",
    "\n",
    "node = hierarchy[\"South/Oslo\"]\n",
    "print(f\"\\nBracket access:     {node.key} (leaf={node.is_leaf})\")"
   ]
  },
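  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One plausible way to read a slash-separated path is as shorthand for a sequence of child keys. The sketch below shows that idea on plain dicts; `get_node`, `get_by_path`, and the tree layout are hypothetical stand-ins written for this illustration, not the library's implementation.\n",
    "\n",
    "```python\n",
    "# Hypothetical toy tree: children are stored in a dict keyed by child key.\n",
    "tree = {\n",
    "    'key': 'Norway',\n",
    "    'children': {\n",
    "        'South': {'key': 'South', 'children': {\n",
    "            'Oslo': {'key': 'Oslo', 'children': {}},\n",
    "        }},\n",
    "        'North': {'key': 'North', 'children': {}},\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "def get_node(root, *path):\n",
    "    # Follow one child key per path element, starting below the root.\n",
    "    node = root\n",
    "    for key in path:\n",
    "        node = node['children'][key]  # raises KeyError for an unknown key\n",
    "    return node\n",
    "\n",
    "\n",
    "def get_by_path(root, path):\n",
    "    # 'South/Oslo' is split into the keys 'South' and 'Oslo'.\n",
    "    return get_node(root, *path.split('/'))\n",
    "\n",
    "\n",
    "node = get_by_path(tree, 'South/Oslo')\n",
    "print(node['key'])  # Oslo\n",
    "```\n",
    "\n",
    "A consequence of this scheme is that node keys cannot themselves contain the separator character."
   ]
  },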
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "| Feature | API |\n",
    "| --- | --- |\n",
    "| Build manually | `HierarchyNode(key, level, children, timeseries)` |\n",
    "| Build from DataFrame | `HierarchicalTimeSeries.from_dataframe(df, level_columns, value_column)` |\n",
    "| Build from dict | `HierarchicalTimeSeries.from_dict(tree, series_map, levels=...)` |\n",
    "| Navigate | `get_node(*path)`, `get_level(name)`, `leaves()` |\n",
    "| Walk | `walk(order=\"pre\")` / `walk(order=\"post\")` |\n",
    "| Aggregate | `aggregate(node, method)` — bottom-up recursion |\n",
    "| Level aggregate | `aggregate_level(level)` → `dict[str, TimeSeriesList]` |\n",
    "| Subtree | `subtree(*path)` → new `HierarchicalTimeSeries` |\n",
    "| Convert | `to_collection(level)`, `to_table(level)` |\n",
    "| Sequence ops | `len(h)`, `key in h`, `h[\"path/to/node\"]` |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}