Metadata-Version: 2.4
Name: digeo
Version: 0.0.2
Summary: A library for differentiable geometry processing.
Author: Hippolyte Verninas
License-Expression: MIT
Classifier: Programming Language :: Python :: 3
Classifier: Operating System :: OS Independent
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: numpy
Requires-Dist: trimesh
Requires-Dist: scipy
Requires-Dist: robust_laplacian
Dynamic: license-file

# DiGeo

DiGeo is a Python package for tracing differentiable geodesics (i.e., computing the exponential map) on triangle meshes.
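A common way to evaluate the exponential map on a triangle mesh is to walk in a straight line inside the current face until the path leaves it through an edge, then continue in the neighbouring face after unfolding it into the same plane. The sketch below shows only that single-face step, in 2D, with the face already laid flat. The function name `exit_triangle` and its signature are hypothetical. They are not DiGeo's API, and this is not necessarily how DiGeo implements tracing.

```python
import numpy as np


def exit_triangle(tri, p, d, eps=1e-12):
    """Find where a straight walk from p along d leaves a flat 2D triangle.

    Hypothetical helper for illustration (not part of DiGeo).
    tri: (3, 2) array of vertex positions in the triangle's plane.
    p:   (2,) starting point inside the triangle.
    d:   (2,) walking direction (need not be unit length).
    Returns (t, i): p + t * d lies on edge (tri[i], tri[(i + 1) % 3]),
    or (np.inf, -1) if no exit edge is found.
    """
    best_t, best_edge = np.inf, -1
    for i in range(3):
        a, b = tri[i], tri[(i + 1) % 3]
        e = b - a
        # Solve p + t * d = a + s * e for (t, s) as a 2x2 linear system.
        m = np.column_stack([d, -e])
        if abs(np.linalg.det(m)) < eps:
            continue  # walking parallel to this edge: no crossing
        t, s = np.linalg.solve(m, a - p)
        # Keep forward crossings that land on the edge segment itself.
        if t > eps and -eps <= s <= 1 + eps and t < best_t:
            best_t, best_edge = t, i
    return best_t, best_edge
```

A full tracer would compare the returned `t` with the remaining geodesic length: if the length runs out first, the end point lies inside this face; otherwise it would unfold the adjacent face across the returned edge and repeat.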

DiGeo is implemented in PyTorch and CUDA and supports batched inputs. It is designed to be efficient and to scale to large meshes. It supports both single- and double-precision floating-point numbers and runs on both CPU and GPU. Gradients are computed either with parallel transport or with finite differences; finite differences are more costly but more precise.
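To see why finite differences cost more, note that a central-difference gradient with respect to an n-dimensional input needs two extra forward evaluations per input coordinate, and for geodesics each evaluation is a full re-trace. The helper below is a generic, hypothetical NumPy sketch of central differences, not DiGeo's code; `f` stands in for any map such as "initial direction to geodesic end point".

```python
import numpy as np


def finite_difference_jacobian(f, x, eps=1e-6):
    """Central-difference Jacobian of f: R^n -> R^m at x.

    Hypothetical helper for illustration (not part of DiGeo).
    Costs 2 * n evaluations of f, one pair per input coordinate.
    """
    x = np.asarray(x, dtype=np.float64)
    fx = np.asarray(f(x), dtype=np.float64)
    jac = np.empty((fx.size, x.size))
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = eps
        # (f(x + h e_j) - f(x - h e_j)) / 2h has O(h^2) truncation error.
        diff = np.asarray(f(x + step)) - np.asarray(f(x - step))
        jac[:, j] = diff.ravel() / (2 * eps)
    return jac
```

Parallel-transport gradients avoid these repeated traces, which is consistent with finite differences being described above as the more costly option.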


<p align="center">
    <img src="img/geodesic_traces_bunny.png" height="400"/>
</p>

## Installation

Requirements:
- Python 3.10+
- CUDA toolkit
- PyTorch 1.10+
- NumPy
- SciPy
- Trimesh
- robust_laplacian


Install the package with:
```bash
pip install --no-build-isolation -e "git+ssh://git@github.com/Etyl/digeo.git"
```

## Usage

```python
import torch
import trimesh

from digeo import (
    load_mesh_from_file,
    load_mesh_from_trimesh,
    trace_geodesics,
    uniform_sampling,
    MeshPointBatch,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
num_samples = 1024

# Load a mesh from a file
mesh = load_mesh_from_file("path/to/mesh.obj").to(device)
# Or load it from an existing trimesh object
trimesh_object = trimesh.load("path/to/mesh.obj", force="mesh")
mesh = load_mesh_from_trimesh(trimesh_object).to(device)

# Sample points uniformly on the mesh
meshpoints = uniform_sampling(mesh, num_samples).to(device)
# Or provide your own points on the mesh:
# `faces` holds the face index of each point, `uvs` its coordinates within that face
meshpoints = MeshPointBatch(
    faces=faces,
    uvs=uvs,
).to(device)

# Define initial directions (requires_grad so that backward() reaches them)
directions = torch.randn((num_samples, 3), device=device, requires_grad=True)

# Trace geodesics
end_meshpoints, geodesic_info = trace_geodesics(
    mesh,
    meshpoints,
    directions,
    gradient="finite_difference",  # or "parallel_transport"
    debug=True,  # Optional, to save the full paths (requires additional memory)
)

# end_meshpoints is a MeshPointBatch, which can be converted to a tensor of 3D points
end_points = end_meshpoints.interpolate(mesh)

# Compute a loss and backpropagate
loss = end_points.norm(dim=-1).mean()  # placeholder: replace with your own loss
loss.backward()

# With debug=True, the full paths, directions and normals at each step are available
k = 0  # index of the geodesic to inspect
path = geodesic_info.get_path(k)  # trace of the k-th geodesic
path_directions = geodesic_info.get_directions(k)
path_normals = geodesic_info.get_normals(k)
```
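`MeshPointBatch` appears to store each point as a face index plus coordinates inside that face, and `interpolate(mesh)` turns these into 3D positions. Assuming `uvs` are barycentric coordinates, the conversion is the weighted sum sketched below in NumPy. The weight convention (vertex 0 gets `1 - u - v`, vertices 1 and 2 get `u` and `v`) and the helper name `interpolate_points` are assumptions for illustration, not taken from DiGeo.

```python
import numpy as np


def interpolate_points(vertices, faces, face_ids, uvs):
    """Map (face index, barycentric uv) pairs to 3D positions.

    Hypothetical helper (not part of DiGeo). Assumed convention:
    uv = (u, v) weights vertices 1 and 2 of each face, vertex 0 gets 1 - u - v.
    vertices: (V, 3) floats, faces: (F, 3) ints,
    face_ids: (B,) ints, uvs: (B, 2) floats.
    """
    tri = vertices[faces[face_ids]]  # (B, 3, 3): corner positions per point
    u, v = uvs[:, 0:1], uvs[:, 1:2]  # (B, 1) each, broadcast over xyz
    return (1 - u - v) * tri[:, 0] + u * tri[:, 1] + v * tri[:, 2]
```

Because the operation is a linear combination of vertex positions, it is differentiable with respect to both the coordinates and the vertices, which is what makes a loss on the interpolated end points usable for backpropagation.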

It is also possible to get the rotation matrices used for parallel transport by setting `save_parallel_transport=True` in `trace_geodesics`.
```python
end_meshpoints, geodesic_info = trace_geodesics(
    mesh,
    meshpoints,
    directions,
    gradient="finite_difference",  # or "parallel_transport"
    save_parallel_transport=True,  # Save the rotation matrices for parallel transport
)

rotation_matrices = geodesic_info.rotation  # Shape (num_samples, 3, 3)
```
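How these rotation matrices are built is not described here. One standard discrete construction for transporting a tangent vector from a face to an adjacent face is the smallest rotation that takes the first face normal onto the second; because both normals are perpendicular to the shared edge, this is a rotation about that edge by the dihedral angle. The NumPy sketch below builds it with Rodrigues' formula. `rotation_between` is a hypothetical helper, and DiGeo may compose or store its rotations differently.

```python
import numpy as np


def rotation_between(n1, n2, eps=1e-12):
    """Smallest rotation matrix R with R @ n1 == n2, for unit vectors n1, n2.

    Hypothetical helper for illustration (not part of DiGeo).
    """
    axis = np.cross(n1, n2)
    s = np.linalg.norm(axis)  # sin(theta)
    c = float(np.dot(n1, n2))  # cos(theta)
    if s < eps:
        if c > 0:
            return np.eye(3)  # normals already aligned: identity transport
        raise ValueError("antiparallel normals: rotation axis is ambiguous")
    k = axis / s
    # Cross-product matrix of the unit rotation axis k.
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])
    # Rodrigues' formula: R = I + sin(theta) K + (1 - cos(theta)) K^2
    return np.eye(3) + s * K + (1 - c) * (K @ K)
```

To transport a tangent vector `t` lying in the first face, apply `R @ t`: its component along the shared edge is left unchanged, while the component perpendicular to the edge is rotated into the plane of the second face.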
