Jraph: A library for Graph Neural Networks in Jax
Jraph (pronounced "giraffe") is a lightweight library for working with graph neural networks in jax. It provides a data structure for graphs, a set of utilities for working with graphs, and a 'zoo' of forkable graph neural network models.
Installation
```shell
pip install jraph
```
Alternatively, Jraph can be installed directly from GitHub using the following command:
```shell
pip install git+https://github.com/deepmind/jraph.git
```
The examples require additional dependencies. To install them, please run:
```shell
pip install "jraph[examples, ogb_examples] @ git+https://github.com/deepmind/jraph.git"
```
Overview
Jraph is designed to provide utilities for working with graphs in jax, but doesn't prescribe a way to write or develop graph neural networks.
* `graph.py` provides a lightweight data structure, `GraphsTuple`, for working with graphs.
* `utils.py` provides utilities for working with `GraphsTuples` in jax:
  * Utilities for batching datasets of `GraphsTuples`.
  * Utilities to support jit compilation of variable-shaped graphs via padding and masking.
  * Utilities for defining losses on partitions of inputs.
* `models.py` provides examples of different types of graph neural network message passing. These are designed to be lightweight and easy to fork and adapt. They do not manage parameters for you; for that, consider using `haiku` or `flax`. See the examples for more details.
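To illustrate what a loss defined on partitions of inputs can look like, here is a minimal hand-rolled sketch that computes a mean squared error separately for each graph. It assumes every node is labelled with the index of the graph it belongs to; the name `per_graph_mse` is hypothetical and not part of jraph, and only `jax.ops.segment_sum` is used.

```python
import jax
import jax.numpy as jnp


def per_graph_mse(predictions, targets, graph_index, num_graphs):
    """Hypothetical helper: mean squared error computed per graph.

    predictions, targets: [num_nodes, feature_dim] arrays.
    graph_index: [num_nodes] int array giving the graph of each node.
    """
    # Squared error per node, summed over the feature axis.
    per_node = jnp.sum((predictions - targets) ** 2, axis=-1)
    # Sum the errors and count the nodes inside each partition (graph).
    error_sum = jax.ops.segment_sum(per_node, graph_index, num_segments=num_graphs)
    node_count = jax.ops.segment_sum(
        jnp.ones_like(per_node), graph_index, num_segments=num_graphs)
    # Guard against division by zero for graphs without nodes.
    return error_sum / jnp.maximum(node_count, 1.0)


predictions = jnp.array([[1.], [2.], [3.], [4.], [5.]])
targets = jnp.array([[1.], [2.], [2.], [4.], [7.]])
graph_index = jnp.array([0, 0, 0, 1, 1])
losses = per_graph_mse(predictions, targets, graph_index, num_graphs=2)
# Graph 0 has errors [0, 0, 1] -> 1/3; graph 1 has errors [0, 4] -> 2.0.
```

Because `num_segments` is fixed, the output shape does not depend on the data, which keeps the loss compatible with `jax.jit`.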
Quick Start
Jraph takes inspiration from the TensorFlow graph_nets library in defining a `GraphsTuple` data structure, which is a namedtuple that contains one or more directed graphs.
Representing Graphs: The `GraphsTuple`
```python
import jax.numpy as jnp
import jraph

# Define a three-node graph; each node has a scalar as its feature.
node_features = jnp.array([[0.], [1.], [2.]])

# We will construct a graph in which there is a directed edge between each
# node and its successor (the last node points back to the first). We define
# this with `senders` (source nodes) and `receivers` (destination nodes).
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])

# You can optionally add edge attributes.
edges = jnp.array([[5.], [6.], [7.]])

# We then save the number of nodes and the number of edges.
# This information is used to make running GNNs over multiple graphs
# in a GraphsTuple possible.
n_node = jnp.array([3])
n_edge = jnp.array([3])

# Optionally you can add `global` information, such as a graph label.
# Like nodes and edges, globals have a leading axis (one row per graph).
global_context = jnp.array([[1]])

graph = jraph.GraphsTuple(nodes=node_features, senders=senders,
                          receivers=receivers, edges=edges, n_node=n_node,
                          n_edge=n_edge, globals=global_context)
```
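The `senders` and `receivers` arrays are all a message passing step needs to route information along edges. As a minimal sketch in plain jax (an illustration of the idea, not jraph's internals), the edge features of the graph above can be summed into their receiving and sending nodes with `jax.ops.segment_sum`:

```python
import jax
import jax.numpy as jnp

# The same three-node cycle as above: 0 -> 1 -> 2 -> 0.
edges = jnp.array([[5.], [6.], [7.]])
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])
num_nodes = 3

# Sum every edge feature into the node that receives it.
incoming = jax.ops.segment_sum(edges, receivers, num_segments=num_nodes)
# Sum every edge feature into the node that sends it.
outgoing = jax.ops.segment_sum(edges, senders, num_segments=num_nodes)
# incoming == [[7.], [5.], [6.]]: node 0 receives edge 2, node 1 edge 0, node 2 edge 1.
# outgoing == [[5.], [6.], [7.]]
```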
A `GraphsTuple` can contain more than one graph; `jraph.batch` combines a list of graphs into one:
```python
two_graph_graphstuple = jraph.batch([graph, graph])
```
The node and edge features are stacked on the leading axis.
```python
>>> jraph.batch([graph, graph]).nodes
DeviceArray([[0.], [1.], [2.], [0.], [1.], [2.]], dtype=float32)
```
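Because node features are stacked, the node indices of the second graph must be shifted so that its `senders` and `receivers` still point at its own nodes. The sketch below reproduces that index bookkeeping by hand for two copies of the example graph. It assumes the shift for each graph is the total node count of the graphs before it; it illustrates the idea and is not jraph's implementation of `jraph.batch`.

```python
import jax.numpy as jnp

senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])
n_node = jnp.array([3])

# Two copies of the same graph, described by (senders, receivers, n_node).
graphs = [(senders, receivers, n_node), (senders, receivers, n_node)]

batched_senders, batched_receivers = [], []
offset = 0
for s, r, n in graphs:
    # Shift indices by the number of nodes in all preceding graphs.
    batched_senders.append(s + offset)
    batched_receivers.append(r + offset)
    offset += int(jnp.sum(n))

batched_senders = jnp.concatenate(batched_senders)
batched_receivers = jnp.concatenate(batched_receivers)
# batched_senders == [0, 1, 2, 3, 4, 5]
# batched_receivers == [1, 2, 0, 4, 5, 3]
```

The edges of the second copy now connect nodes 3, 4 and 5, which are exactly the rows its node features occupy after stacking.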
You can tell which nodes belong to which graph by looking at `n_node`.
```python
>>> jraph.batch([graph, graph]).n_node
DeviceArray([3, 3], dtype=int32)
```
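`n_node` can be turned into an explicit node-to-graph index with `jnp.repeat`, which is handy for per-graph reductions. A minimal sketch; passing `total_repeat_length` fixes the output size, which `jax.jit` requires when `n_node` is a traced value:

```python
import jax.numpy as jnp

n_node = jnp.array([3, 3])  # as returned by jraph.batch([graph, graph]).n_node
num_graphs = n_node.shape[0]
total_nodes = int(jnp.sum(n_node))

# Repeat each graph's index once per node that it owns.
graph_index = jnp.repeat(jnp.arange(num_graphs), n_node,
                         total_repeat_length=total_nodes)
# graph_index == [0, 0, 0, 1, 1, 1]
```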
You can store nests of features in `nodes`, `edges` and `globals`. This makes it possible to store multiple sets of features for each node, edge or graph, with potentially different types and semantically different meanings (for example 'training' and 'testing' nodes). The only requirement is that all arrays within each nest must have a common leading dimension size, matching the total number of nodes, edges or graphs within the `GraphsTuple` respectively.
```python
node_targets = jnp.array([[True], [False], [True]])
graph = graph._replace(nodes={'inputs': graph.nodes, 'targets': node_targets})
```
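The leading-dimension requirement can be checked with `jax.tree_util`, which treats the dict of node features as a nest of arrays. A minimal sketch; `check_node_nest` is a hypothetical helper, not part of jraph:

```python
import jax
import jax.numpy as jnp

nodes = {
    'inputs': jnp.array([[0.], [1.], [2.]]),
    'targets': jnp.array([[True], [False], [True]]),
}
n_node = jnp.array([3])


def check_node_nest(nodes, n_node):
    """Hypothetical helper: every leaf must have one row per node."""
    total_nodes = int(jnp.sum(n_node))
    leading_dims = [leaf.shape[0] for leaf in jax.tree_util.tree_leaves(nodes)]
    return all(dim == total_nodes for dim in leading_dims)


# Functions mapped over the nest see each feature array separately.
shapes = jax.tree_util.tree_map(lambda x: x.shape, nodes)
# shapes == {'inputs': (3, 1), 'targets': (3, 1)}

is_valid = check_node_nest(nodes, n_node)  # True
# An array with only two rows breaks the requirement.
is_invalid = check_node_nest({**nodes, 'extra': jnp.zeros((2, 1))}, n_node)  # False
```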
Using the Model Zoo
Jraph provides a set of implemented reference models for you to use.
A Jraph model defines a message passing algorithm between the nodes, edges and global attributes of a graph. The user defines `update` functions that update graph features, which are typically neural networks but can be arbitrary jax functions.
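Let's go through a `GraphNetwork` example, following the paper "Relational inductive biases, deep learning, and graph networks" (Battaglia et al., 2018).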
A GraphNet's first update function updates the edges using the `edge` features, the node features of the `sender` and `receiver`, and the `global` features.
```python
# As one example, we just pass the edge features straight through.
def update_edge_fn(edge, sender, receiver, globals_):
    return edge
```
Often we use the concatenation of these features, and `jraph` provides an easy way of doing this with the `concatenated_args` decorator.
```python
@jraph.concatenated_args
def update_edge_fn(concatenated_features):
    return concatenated_features
```
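Conceptually, such a decorator flattens whatever features are passed to the update function and concatenates them along the last axis before calling it. The sketch below, `concat_args_sketch`, is a hypothetical re-implementation written only to illustrate the idea, assuming every feature array has the same leading dimension; in real code use `jraph.concatenated_args` itself.

```python
import functools

import jax
import jax.numpy as jnp


def concat_args_sketch(update_fn):
    """Hypothetical stand-in for jraph.concatenated_args."""
    @functools.wraps(update_fn)
    def wrapper(*args, **kwargs):
        # Flatten positional and keyword arguments (which may be nests) into
        # a list of arrays; dict leaves come out in sorted-key order.
        leaves = jax.tree_util.tree_leaves((args, kwargs))
        return update_fn(jnp.concatenate(leaves, axis=-1))
    return wrapper


@concat_args_sketch
def update_edge_fn(concatenated_features):
    return concatenated_features


edge = jnp.array([[5.], [6.], [7.]])
sender = jnp.array([[0.], [1.], [2.]])
receiver = jnp.array([[1.], [2.], [0.]])
globals_ = jnp.array([[1.], [1.], [1.]])
out = update_edge_fn(edge, sender, receiver, globals_)
# out.shape == (3, 4): one row per edge, the four features side by side.
```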
Typically, a learned model such as a multilayer perceptron (MLP) is used within an update function.
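A minimal sketch of such an update function in plain jax, with hand-managed parameters (jraph leaves parameter management to libraries such as haiku or flax). The layer sizes and the helpers `init_mlp` and `mlp` are illustrative assumptions, not part of jraph:

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Initialise (weights, bias) pairs for an MLP with the given layer sizes."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def mlp(params, x):
    """ReLU hidden layers followed by a linear output layer."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


# Edge, sender, receiver and global features are 1-dimensional here, so the
# concatenated input has 4 features; the updated edges get 8 features.
params = init_mlp(jax.random.PRNGKey(0), [4, 16, 8])


def update_edge_fn(edge, sender, receiver, globals_):
    inputs = jnp.concatenate([edge, sender, receiver, globals_], axis=-1)
    return mlp(params, inputs)


edge = jnp.array([[5.], [6.], [7.]])
sender = jnp.array([[0.], [1.], [2.]])
receiver = jnp.array([[1.], [2.], [0.]])
globals_ = jnp.ones((3, 1))
new_edges = update_edge_fn(edge, sender, receiver, globals_)
# new_edges.shape == (3, 8)
```

Capturing `params` in a closure keeps the update function's signature unchanged; a library such as haiku or flax would manage the same parameters for you.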
The user similarly defines functions that update the nodes and globals. These are then used to configure a `GraphNetwork`. To see the arguments to the node and global `update_fns`, please take a look at the model zoo.
```python
net = jraph.GraphNetwork(update_edge_fn=update_edge_fn,
                         update_node_fn=update_node_fn,
                         update_global_fn=update_global_fn)
```
`net` is a function that sends messages according to the `GraphNetwork` algorithm and applies the update functions. It takes a graph and returns a graph.
```python
updated_graph = net(graph)
```
Examples
For a deeper dive, the best place to start is the examples. In particular:
* `examples/basic.py` provides an introduction to the features of the library.
* `ogb_examples/train.py` provides an end-to-end example of training a `GraphNet` on the `molhiv` Open Graph Benchmark dataset. Please note that you need to have downloaded the dataset to run this example.
The rest of the examples are short scripts demonstrating how to use various models from our model zoo, as well as how to make models go fast with `jax.jit` and how to deal with jax's static shape requirement.
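The static shape requirement is easiest to see by counting traces: `jax.jit` compiles a new version of a function for every new input shape, so graphs with different node counts trigger recompilation unless they are padded to a shared size. A minimal sketch; the padding target of 8 rows and the helper `pad_nodes` are arbitrary assumptions, and jraph's own utilities (such as `jraph.pad_with_graphs`) handle padding for complete `GraphsTuples`.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def total_feature(nodes):
    global trace_count
    trace_count += 1  # Python side effect: runs only while jax traces the function.
    return jnp.sum(nodes)


small = jnp.ones((3, 1))
large = jnp.ones((5, 1))

total_feature(small)
total_feature(small)  # Same shape: reuses the compiled version.
total_feature(large)  # New shape: traced and compiled again.
traces_without_padding = trace_count  # 2


def pad_nodes(nodes, size):
    # Hypothetical helper: append zero rows so every input has `size` rows.
    return jnp.pad(nodes, ((0, size - nodes.shape[0]), (0, 0)))


trace_count = 0
padded_small = total_feature(pad_nodes(small, 8))
padded_large = total_feature(pad_nodes(large, 8))  # Same padded shape: no retrace.
traces_with_padding = trace_count  # 1
```

Zero rows leave a sum unchanged, but reductions such as means also need a mask that marks the real nodes, which is why padding and masking go together.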
Citing Jraph
To cite this repository:
```bibtex
@software{jraph2020github,
author = {Jonathan Godwin* and Thomas Keck* and Peter Battaglia and Victor Bapst and Thomas Kipf and Yujia Li and Kimberly Stachenfeld and Petar Veli\v{c}kovi\'{c} and Alvaro Sanchez-Gonzalez},
title = {{J}raph: {A} library for graph neural networks in jax.},
url = {http://github.com/deepmind/jraph},
version = {0.0.1.dev},
year = {2020},
}
```