# Jaxpr-Viz

JAX Computation Graph Visualisation Tool

JAX has built-in functionality to visualise the
HLO graph it generates, but I've found this rather
low-level for some use cases.

The intention of this package is to visualise how
sub-functions are connected in JAX programs. It does
this by converting the [JaxPr](https://jax.readthedocs.io/en/latest/jaxpr.html)
representation into a pydot graph. See [here](.github/docs/gallery.md)
for examples.
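
To get a feel for the representation being converted, a Jaxpr can be inspected directly with `jax.make_jaxpr`. The following is a minimal sketch that uses only JAX itself (not this package); the function `scale_and_shift` is a hypothetical example, and the printed text format may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


# Hypothetical example function, not part of jpviz
def scale_and_shift(x):
    return 2 * x - 1


# Trace the function with a concrete argument to obtain its Jaxpr
closed_jaxpr = jax.make_jaxpr(scale_and_shift)(jnp.arange(10))
print(closed_jaxpr)

# Each equation applies one primitive; these are the kind of
# operations that end up as nodes in a computation graph
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```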

> **NOTE:** This project is still at an early stage and may not
> support all JAX functionality (or permutations thereof). If you spot
> some strange behaviour, please create a [GitHub issue](https://github.com/zombie-einstein/jaxpr-viz/issues).

## Installation

Install with pip:

```bash
pip install jpviz
```

Depending on your system, you may also need to install [Graphviz](https://www.graphviz.org/).

## Usage

Jaxpr-viz can be used to visualise jit-compiled (and nested)
functions. `jpviz.draw` wraps a jit-compiled function, and the wrapped
function, when called with concrete values, returns a
[pydot](https://github.com/pydot/pydot) graph.

For example, this simple computation graph

```python
import jax
import jax.numpy as jnp

import jpviz

@jax.jit
def foo(x):
    return 2 * x

@jax.jit
def bar(x):
    x = foo(x)
    return x - 1

# Wrap function and call with concrete arguments
#  here dot_graph is a pydot object
dot_graph = jpviz.draw(bar)(jnp.arange(10))
# This renders the graph to a png file
dot_graph.write_png("computation_graph.png")
```

produces this image

![bar computation graph](.github/images/bar_collapsed.png)

Pydot has a number of options for rendering graphs, see
[here](https://github.com/pydot/pydot#output).

> **NOTE:** For sub-functions to show as nodes/sub-graphs they
> need to be marked with `@jax.jit`, otherwise they will just be
> merged into their parent graph.
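
The effect of `@jax.jit` on sub-functions can be seen in the Jaxpr itself. The sketch below uses only `jax.make_jaxpr`; the function names are hypothetical, and it assumes a recent JAX version in which a jit-compiled call appears as a single equation carrying its own nested Jaxpr under the `jaxpr` parameter.

```python
import jax
import jax.numpy as jnp


def double_plain(x):
    return 2 * x


@jax.jit
def double_jitted(x):
    return 2 * x


def uses_plain(x):
    return double_plain(x) - 1


def uses_jitted(x):
    return double_jitted(x) - 1


x = jnp.arange(10)
plain_eqns = jax.make_jaxpr(uses_plain)(x).jaxpr.eqns
jitted_eqns = jax.make_jaxpr(uses_jitted)(x).jaxpr.eqns

# Without jit the helper is inlined: only its primitives remain
plain_names = [eqn.primitive.name for eqn in plain_eqns]
print(plain_names)

# With jit the helper survives as one equation with a nested Jaxpr
# (assumption: the parameter keys "jaxpr" and "name" as in recent JAX)
nested_names = [eqn.params["name"] for eqn in jitted_eqns if "jaxpr" in eqn.params]
print(nested_names)
```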

### Jupyter Notebook

To show the rendered graph in a Jupyter notebook, you can use the
helper function `view_pydot`:

```python
...
dot_graph = jpviz.draw(bar)(jnp.arange(10))
jpviz.view_pydot(dot_graph)
```

### Visualisation Options

#### Collapse Nodes

By default, functions that are composed of only primitive functions
are collapsed into a single node (like `foo` in the above example).
The full computation graph can be rendered by setting the
`collapse_primitives` flag to `False`. Doing so in the above example

```python
...
dot_graph = jpviz.draw(bar, collapse_primitives=False)(jnp.arange(10))
...
```

produces

![bar computation graph](.github/images/bar_expanded.png)

#### Show Types

By default, type information is included in the node labels. This
can be hidden by setting the `show_avals` flag to `False`

```python
...
dot_graph = jpviz.draw(bar, show_avals=False)(jnp.arange(10))
...
```

produces

![bar computation graph](.github/images/bar_no_types.png)

> **NOTE:** The node labels don't currently correspond
> to argument/variable names in the original Python code. However,
> since JAX unpacks arguments and outputs into tuples, the labels do
> correspond to the positions of the arguments and outputs.
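
To make the positional ordering concrete, the sketch below flattens nested arguments with `jax.tree_util.tree_flatten` and compares the result with the inputs of the traced Jaxpr. It uses only JAX; `weighted_sum` and its arguments are hypothetical, and how this package labels each position is not shown here.

```python
import jax
import jax.numpy as jnp


# Hypothetical function taking a nested (dict) argument
def weighted_sum(params, x):
    return params["w"] * x + params["b"]


args = ({"b": jnp.array(1.0), "w": jnp.array(2.0)}, jnp.arange(3.0))

# JAX flattens nested arguments into a flat sequence of leaves
# (dictionary entries are ordered by key)
leaves, treedef = jax.tree_util.tree_flatten(args)

# The traced Jaxpr has one input variable per flattened leaf
closed_jaxpr = jax.make_jaxpr(weighted_sum)(*args)
n_inputs = len(closed_jaxpr.jaxpr.invars)
print(n_inputs, [leaf.shape for leaf in leaves])
```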

## Examples

See [here](.github/docs/gallery.md) for more examples of rendered computation graphs.

## Developers

Developer notes can be found [here](.github/docs/developers.md).
