Metadata-Version: 2.4
Name: graph_attention_student
Version: 1.3.0
Summary: MEGAN: Multi Explanation Graph Attention Network
Author-email: Jonas Teufel <jonseb1998@gmail.com>
Maintainer-email: Jonas Teufel <jonseb1998@gmail.com>
License-Expression: MIT
Keywords: attention,chemistry,deep learning,explainable ai,graph attention,graph neural network,interpretability,machine learning,molecular property prediction,pytorch
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Chemistry
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: <3.14,>=3.8
Requires-Dist: cairosvg>=2.5.2
Requires-Dist: click>=7.1.2
Requires-Dist: hdbscan>=0.8.33
Requires-Dist: imageio>=2.19.0
Requires-Dist: jinja2>=3.0.3
Requires-Dist: lightning>=2.1.3
Requires-Dist: matplotlib>=3.5.3
Requires-Dist: nltk>=3.7
Requires-Dist: numpy<2.0.0,>=1.22.0
Requires-Dist: orjson>=3.8.0
Requires-Dist: poetry-bumpversion>=0.3.0
Requires-Dist: polars>=1.0.0
Requires-Dist: pycomex>=0.27.0
Requires-Dist: pydyf<0.11.0; python_version < '3.9'
Requires-Dist: rdkit>=2022.9.1
Requires-Dist: seaborn>=0.13.1
Requires-Dist: torch-geometric>=2.4.0
Requires-Dist: torch<2.5.1,>=2.1.2
Requires-Dist: umap-learn>=0.5.3
Requires-Dist: vgd-counterfactuals>=0.1.0
Requires-Dist: visual-graph-datasets>=0.17.0
Requires-Dist: weasyprint>=57.1
Description-Content-Type: text/x-rst

|made-with-python| |made-with-pytorch| |python-version| |os-linux|

.. |os-linux| image:: https://img.shields.io/badge/os-linux-orange.svg
   :target: https://www.python.org/

.. |python-version| image:: https://img.shields.io/badge/Python-3.8.0-green.svg
   :target: https://www.python.org/

.. |made-with-pytorch| image:: https://img.shields.io/badge/Made%20with-PyTorch-orange.svg
   :target: https://pytorch.org/

.. |made-with-python| image:: https://img.shields.io/badge/Made%20with-Python-1f425f.svg
   :target: https://www.python.org/

.. image:: architecture.png
    :width: 800
    :alt: Architecture Overview

👩‍🏫 MEGAN: Multi Explanation Graph Attention Student
======================================================

**Abstract.** Explainable artificial intelligence (XAI) methods are expected to improve trust during human-AI interactions,
provide tools for model analysis and extend human understanding of complex problems. Attention-based models
are an important subclass of XAI methods, partly due to their full differentiability and the potential to
improve explanations by means of explanation-supervised training. We propose the novel multi-explanation
graph attention network (MEGAN). Our graph regression and classification model features multiple explanation
channels, which can be chosen independently of the task specifications. We first validate our model on a
synthetic graph regression dataset, where our model produces single-channel explanations with quality
similar to GNNExplainer. Furthermore, we demonstrate the advantages of multi-channel explanations on one
synthetic and two real-world datasets: The prediction of water solubility of molecular graphs and
sentiment classification of movie reviews. We find that our model produces explanations consistent
with human intuition, opening the way to learning from our model in less well-understood tasks.

🔔 News
-------

- **September 2025** - Version ``1.1.0`` of the package has finally been released!
- **April 2024** - The follow-up paper about *global concept explanations using an extension of MEGAN* is now available on arXiv: https://arxiv.org/abs/2404.16532
- **October 2023** - The `paper`_ has been published with Springer in the xAI conference proceedings: https://link.springer.com/chapter/10.1007/978-3-031-44067-0_18
- **June 2023** - Check out the `MeganExplains`_ web interface at https://megan.aimat.science/. The interface allows users to query MEGAN models trained on
  different graph prediction tasks and to visualize the corresponding explanations provided by the model.
- **March 2023** - The `paper`_ was accepted at the `1st xAI world conference <https://xaiworldconference.com/2023/>`_

📦 Package Dependencies
-----------------------

- The package is designed to run with a Python version ``3.8 <= python <= 3.13``.
- A graphics card with CUDA support (cuDNN) is recommended for model training.
- A **Linux** operating system is recommended for development.
 
📦 Installation by Package
--------------------------

The package is also published as a library on PyPI and can be installed like this:

.. code-block:: shell

     uv pip install graph_attention_student


📦 Installation from Source
---------------------------

Clone the repository from GitHub:

.. code-block:: shell

    git clone https://github.com/aimat-lab/graph_attention_student

Then, in the main folder, run an editable ``pip install``:

.. code-block:: shell

    cd graph_attention_student
    uv pip install -e .

.. warning::
   **Note for Windows Users**

   The visualization libraries ``cairosvg`` and ``weasyprint`` require additional system dependencies on Windows.
   Install MSYS2 from https://www.msys2.org/ and run:

   .. code-block:: bash

      pacman -S mingw-w64-ucrt-x86_64-cairo mingw-w64-ucrt-x86_64-gtk3 mingw-w64-ucrt-x86_64-glib2 mingw-w64-ucrt-x86_64-pango

   Add ``C:\msys64\ucrt64\bin`` to your PATH and set the environment variable ``WEASYPRINT_DLL_DIRECTORIES=C:\msys64\ucrt64\bin``.


🚀 Quickstart
-------------

The fastest way to train a MEGAN model is to use the built-in experiment scripts. Prepare a CSV file with
SMILES strings and target values, then run:

.. code-block:: bash

    # Clone and install
    git clone https://github.com/aimat-lab/graph_attention_student
    cd graph_attention_student
    uv pip install -e .

    # Train a regression model
    python graph_attention_student/experiments/train_model__megan.py \
        --CSV_FILE_PATH='"/path/to/your/data.csv"' \
        --VALUE_COLUMN_NAME='"smiles"' \
        --TARGET_COLUMN_NAMES='["target"]' \
        --DATASET_TYPE='"regression"' \
        --EPOCHS=150

Your CSV should have a ``smiles`` column and your target column(s):

.. code-block:: text

    smiles,target
    CCO,1.23
    CCN,2.45
    CCC,0.89
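
As a minimal sketch using only the standard library, the following snippet writes a file in the layout shown above and then checks that the columns referenced by the training command (``smiles`` and ``target``) are present. The file name ``data.csv`` and the three rows are illustrative only; replace them with your own data.

.. code-block:: python

    import csv
    from pathlib import Path

    # Illustrative rows copied from the example above; replace them with your own data.
    rows = [
        {"smiles": "CCO", "target": 1.23},
        {"smiles": "CCN", "target": 2.45},
        {"smiles": "CCC", "target": 0.89},
    ]

    path = Path("data.csv")
    with path.open("w", newline="") as file:
        # The header names must match VALUE_COLUMN_NAME and TARGET_COLUMN_NAMES.
        writer = csv.DictWriter(file, fieldnames=["smiles", "target"])
        writer.writeheader()
        writer.writerows(rows)

    # Read the file back and make sure the required columns exist.
    with path.open(newline="") as file:
        reader = csv.DictReader(file)
        missing = {"smiles", "target"} - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"CSV is missing columns: {missing}")
        num_rows = sum(1 for _ in reader)

    print(f"{path} has {num_rows} rows")
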

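Key parameters: ``CSV_FILE_PATH`` (path to the data), ``VALUE_COLUMN_NAME`` (column containing the SMILES strings),
``TARGET_COLUMN_NAMES`` (list of target columns) and ``DATASET_TYPE`` (``'regression'`` or ``'classification'``).
See ``python graph_attention_student/experiments/train_model__megan.py --help`` for all options.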
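
The nested quotes in the command above (for example ``'"smiles"'``) suggest that each parameter value is read as a Python literal, so strings carry their own quotes and lists use Python syntax. Assuming that convention, the following sketch builds the same command from a plain dictionary. It renders each value with ``json.dumps``, which yields valid Python literals for strings, numbers and lists (but not for booleans or ``None``), and leaves the shell quoting to ``shlex.join``. The helper ``build_arguments`` is hypothetical and not part of the package.

.. code-block:: python

    import ast
    import json
    import shlex
    from typing import Any, Dict, List

    # Parameter values from the quickstart command, as ordinary Python objects.
    parameters: Dict[str, Any] = {
        "CSV_FILE_PATH": "/path/to/your/data.csv",
        "VALUE_COLUMN_NAME": "smiles",
        "TARGET_COLUMN_NAMES": ["target"],
        "DATASET_TYPE": "regression",
        "EPOCHS": 150,
    }


    def build_arguments(params: Dict[str, Any]) -> List[str]:
        """Hypothetical helper: turn a parameter dict into --NAME=<literal> arguments.

        json.dumps produces text that is also a valid Python literal for str, int,
        float and list values, but NOT for True/False/None (true/false/null).
        """
        return [f"--{name}={json.dumps(value)}" for name, value in params.items()]


    arguments = build_arguments(parameters)

    # Assumption: each value is parsed as a Python literal. Under that assumption,
    # literal_eval recovers exactly the original value from every argument.
    for argument in arguments:
        name, literal = argument[2:].split("=", 1)
        assert ast.literal_eval(literal) == parameters[name]

    command = ["python", "graph_attention_student/experiments/train_model__megan.py", *arguments]
    # shlex.join (Python 3.8+) quotes each argument so the shell passes it through unchanged.
    print(shlex.join(command))
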

📄 Config Files
---------------

Instead of passing parameters on the command line, you can create a YAML config file:

.. code-block:: yaml

    # config.yml
    extend: train_model__megan.py
    parameters:
      CSV_FILE_PATH: /path/to/your/data.csv
      TARGET_COLUMN_NAMES:
        - target
      VALUE_COLUMN_NAME: smiles
      DATASET_TYPE: regression
      EPOCHS: 100
      BATCH_SIZE: 64
      LEARNING_RATE: 0.0001

Then run the experiment with:

.. code-block:: bash

    pycomex run graph_attention_student/experiments/config.yml

.. _`GATv2`: https://github.com/tech-srl/how_attentive_are_gats

💻 Command Line Interface
-------------------------

For quick training and predictions from the terminal, use the ``megan`` CLI:

.. code-block:: bash

    # Train from CSV
    megan train dataset.csv

    # Make predictions with explanations
    # Optionally pass the path to a model checkpoint to use for the prediction.
    megan predict "CCO"

Use ``megan --help`` for all options.

🤖 Python API
-------------

For custom workflows, use the Python API directly:

.. code-block:: python

    import pytorch_lightning as pl
    from torch_geometric.loader import DataLoader
    from visual_graph_datasets.processing.molecules import MoleculeProcessing
    from graph_attention_student import Megan, SmilesDataset

    # Setup
    processing = MoleculeProcessing()
    dataset = SmilesDataset(
        dataset="data.csv",
        smiles_column='smiles',
        target_columns=['target'],
        processing=processing,
    )
    loader = DataLoader(dataset, batch_size=64)

    # Create model
    model = Megan(
        node_dim=processing.get_num_node_attributes(),
        edge_dim=processing.get_num_edge_attributes(),
        units=[64, 64, 64],
        final_units=[64, 32, 1],
        prediction_mode='regression',
        importance_factor=1.0,
    )

    # Train
    trainer = pl.Trainer(max_epochs=150, accelerator='auto')
    trainer.fit(model, train_dataloaders=loader)
    model.eval()
    model.save("model.ckpt")

**Loading and Using Models:**

.. code-block:: python

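    from visual_graph_datasets.processing.molecules import MoleculeProcessing
    from graph_attention_student import Megan
    from graph_attention_student.torch.advanced import megan_prediction_report

    # Use the same kind of processing as during training to convert SMILES into graphs
    processing = MoleculeProcessing()

    model = Megan.load("model.ckpt")
    model.eval()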

    # Make prediction
    results = model.forward_graph(processing.process("CCO"))
    print(f"Prediction: {results['graph_output'].item():.3f}")

    # Generate explanation PDF
    megan_prediction_report(
        value="CCO",
        model=model,
        processing=processing,
        output_path="report.pdf"
    )

🔍 Examples
-----------

The following are some *cherry-picked* examples that illustrate the explanatory capabilities of
the model.

RB-Motifs Dataset
~~~~~~~~~~~~~~~~~

This synthetic dataset consists of randomly generated graphs whose nodes have different colors. Some of the
graphs contain special sub-graph motifs, which are either blue-heavy or red-heavy structures. The blue-heavy
motifs contribute a specific negative value to the overall value of the graph, while the red-heavy motifs
contribute a specific positive value.

As a result, every graph is associated with a value between -3 and 3. The network was trained to predict
this value for each graph.
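
To make the labeling rule concrete, here is a toy sketch that assumes the graph value is simply the sum of the contributions of the motifs a graph contains. The motif labels ``"red"`` and ``"blue"``, the contribution of +1 / -1 and the helper ``graph_value`` are hypothetical; the actual contribution values used to build the dataset are not stated here.

.. code-block:: python

    from typing import Dict, List

    # Hypothetical contribution per motif type (not the dataset's real values).
    MOTIF_CONTRIBUTIONS: Dict[str, float] = {
        "red": +1.0,   # red-heavy motifs push the graph value up
        "blue": -1.0,  # blue-heavy motifs push the graph value down
    }


    def graph_value(motifs: List[str]) -> float:
        """Sum the contributions of all motifs contained in one graph.

        A graph without any motif gets a value of 0.
        """
        return sum(MOTIF_CONTRIBUTIONS[motif] for motif in motifs)


    print(graph_value([]))
    print(graph_value(["red"]))
    print(graph_value(["red", "blue"]))

Under this assumption, a graph containing one red-heavy and one blue-heavy motif has a value of zero, which is why a good explanation has to find both motifs rather than only one of them.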

.. image:: rb_motifs_example.png
    :width: 800
    :alt: Rb-Motifs Example

The example shows, from left to right: (1) the ground truth explanations, (2) a baseline MEGAN model trained
only on the prediction task, (3) an explanation-supervised MEGAN model and (4) GNNExplainer explanations for a
basic GCN. While the baseline MEGAN and GNNExplainer focus on only one of the ground truth motifs,
the explanation-supervised MEGAN model correctly finds both.

Water Solubility Dataset
~~~~~~~~~~~~~~~~~~~~~~~~

This is the `AqSolDB`_ dataset, which consists of ~10,000 molecules and their measured solubility in
water (logS values).

The network was trained to predict the solubility value for each molecule.

.. image:: solubility_example.png
    :width: 800
    :alt: Solubility Example

.. _`AqSolDB`: https://www.nature.com/articles/s41597-019-0151-1

Movie Reviews
~~~~~~~~~~~~~

The *MovieReviews* dataset is originally a natural language processing dataset from the `ERASER`_ benchmark.
The task is to classify the sentiment of ~2000 movie reviews collected from IMDb into the classes
"positive" and "negative". This dataset was converted into a graph dataset by treating all words as
nodes of a graph and connecting adjacent words by undirected edges, using a sliding window of size 2.
Words were converted into numeric feature vectors using a pre-trained `GloVe`_ model.
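
The text-to-graph conversion described above can be sketched in a few lines. This is a minimal illustration, not the preprocessing code that was actually used for the dataset: the whitespace tokenizer, the tiny embedding table standing in for the pre-trained GloVe vectors and the zero vector for unknown words are all assumptions, and ``review_to_graph`` is a hypothetical helper. Each undirected edge is stored in both directions, which is the usual convention for PyTorch Geometric ``edge_index`` arrays.

.. code-block:: python

    from typing import Dict, List, Tuple

    import numpy as np

    # Stand-in for pre-trained GloVe vectors (real GloVe vectors have 50-300 dimensions).
    EMBEDDINGS: Dict[str, np.ndarray] = {
        "a": np.array([0.1, 0.0, 0.2]),
        "great": np.array([0.9, 0.4, 0.1]),
        "movie": np.array([0.3, 0.8, 0.5]),
    }
    EMBEDDING_DIM = 3


    def review_to_graph(text: str) -> Tuple[np.ndarray, np.ndarray]:
        """Hypothetical helper: turn a review into (node_features, edge_index).

        Every word becomes one node; a sliding window of size 2 connects each pair
        of neighboring words with an undirected edge, stored in both directions.
        """
        words: List[str] = text.lower().split()  # naive whitespace tokenization (assumption)
        if not words:
            raise ValueError("review contains no words")
        # Words without an embedding get a zero vector (assumption).
        node_features = np.stack(
            [EMBEDDINGS.get(word, np.zeros(EMBEDDING_DIM)) for word in words]
        )
        edges = []
        for i in range(len(words) - 1):  # the window covers positions (i, i + 1)
            edges.append((i, i + 1))
            edges.append((i + 1, i))
        # Shape (2, num_edges); a single-word review yields an empty (2, 0) array.
        edge_index = np.array(edges, dtype=np.int64).T.reshape(2, -1)
        return node_features, edge_index


    features, edge_index = review_to_graph("A great movie")
    print(features.shape)  # (3, 3)
    print(edge_index)
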

Example for a positive review:

.. image:: movie_reviews_pos.png
    :width: 800
    :alt: Positive Movie Review

Example for a negative review:

.. image:: movie_reviews_neg.png
    :width: 800
    :alt: Negative Movie Review

Each example shows the explanation channel for the "negative" class on the left and for the "positive" class on the right.
Sentences containing negative or positive adjectives are correctly attributed to the corresponding channels.

📖 Referencing
--------------

If you use, extend or otherwise mention our work, please cite the `paper`_ as follows:

.. code-block:: bibtex

    @inproceedings{teufel2023megan,
        title={MEGAN: Multi-Explanation Graph Attention Network},
        author={Teufel, Jonas and Torresi, Luca and Reiser, Patrick and Friederich, Pascal},
        booktitle={Explainable Artificial Intelligence (xAI 2023)},
        publisher={Springer},
        year={2023},
        doi={10.1007/978-3-031-44067-0_18},
        url={https://link.springer.com/chapter/10.1007/978-3-031-44067-0_18},
    }


Credits
------------

* **PyTorch Lightning** provides the high-level training framework that powers the modern MEGAN implementation,
  offering easy GPU acceleration, distributed training, and experiment management.
* **PyTorch Geometric** supplies the fundamental graph neural network building blocks and efficient graph data handling
  that enable MEGAN's attention mechanisms and message passing operations.
* VisualGraphDataset_ is a library that establishes a special dataset format for graph XAI applications.
  It aims to streamline the visualization of graph explanations and to make them more comparable by
  packaging canonical graph visualizations directly with the dataset.
* PyComex_ is a micro framework which simplifies the setup, processing and management of computational
  experiments. It is also used to auto-generate the command line interface that can be used to interact
  with these experiments.

.. _PyComex: https://github.com/the16thpythonist/pycomex
.. _VisualGraphDataset: https://github.com/aimat-lab/visual_graph_datasets
.. _MEGAN: https://github.com/aimat-lab/graph_attention_student

.. _`ERASER`: https://www.eraserbenchmark.com/
.. _`GLOVE`: https://nlp.stanford.edu/projects/glove/

.. _`paper`: https://link.springer.com/chapter/10.1007/978-3-031-44067-0_18
.. _`poetry`: https://python-poetry.org/
.. _`MeganExplains`: https://megan.aimat.science/ 
.. _`visual_graph_dataset`: https://github.com/aimat-lab/visual_graph_datasets 