JAX: Accelerated machine-learning research via composable function transformations in Python

by · Dec 14, 2019

Machine learning researchers often express complex models as programs, relying on program transformations to add functionality. New languages and transformations (e.g., TorchScript and TensorFlow AutoGraph) are becoming core capabilities of ML libraries. However, existing transformations, such as automatic differentiation (AD or autodiff), inference in probabilistic programming languages (PPLs), and optimizing compilers, are often built in isolation and are limited in scope. This workshop aims to view program transformations in ML in a unified light, to make these capabilities more accessible, and to build entirely new ones.

Program transformations are an area of active study. AD transforms a program that performs a numerical computation into one that computes the gradient of that computation. In probabilistic programming, a program describing a sampling procedure can be modified to perform inference on model parameters given observations. Other examples include vectorizing a program written for a single data point so that it operates on a whole batch, and learned transformations, in which ML models take programs as inputs or produce them as outputs.

This workshop will bring together researchers in AD, probabilistic programming, programming languages, compilers, and ML, with the goal of understanding the commonalities between disparate approaches and views, and of sharing ways to make these techniques broadly available. Broader availability would let ML practitioners iterate faster on novel models and architectures (e.g., those naturally expressed through high-level constructs like recursion).
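To make two of the transformations named above concrete (AD, and vectorizing a program written for one data point), here is a minimal sketch using JAX's `jax.grad`, `jax.vmap`, and `jax.jit`. The squared-error `loss` function, the weights, and the tiny dataset are hypothetical illustrations chosen for this sketch, not taken from the talk. The point is composition: `grad` turns a per-example loss into a per-example gradient function, `vmap` lifts that function to a batch, and `jit` compiles the result with XLA.

```python
import jax
import jax.numpy as jnp

# Hypothetical model: squared error of a linear predictor on ONE example.
def loss(w, x, y):
    pred = jnp.dot(w, x)          # scalar prediction for a single input x
    return (pred - y) ** 2        # scalar loss

# AD: jax.grad returns a new function computing d(loss)/d(w) (first argument).
grad_loss = jax.grad(loss)

# Vectorization: jax.vmap maps the single-example gradient over a batch.
# in_axes=(None, 0, 0) shares w across the batch and maps over rows of x and y.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# Compilation: jax.jit traces the composed function and compiles it with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

# Hypothetical data: 3 examples with 2 features each.
w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0],
                [0.0, 1.0],
                [1.0, 1.0]])
ys = jnp.array([0.0, 1.0, 2.0])

g = fast_per_example_grads(w, xs, ys)
# Each row g[i] = 2 * (w . x_i - y_i) * x_i; here every residual is 1,
# so g is [[2, 0], [0, 2], [2, 2]].
print(g)
```

The order of composition matters: wrapping `grad` inside `vmap` yields one gradient per example, whereas taking `grad` of a batched, averaged loss would yield a single aggregate gradient. Neither `loss` nor the data needed to be rewritten for either transformation.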