SFU course 'Deep learning with JAX - Part 2'
Date: 10 April 2025 @ 17:00 - 19:00
JAX is a fast open-source Python library for function transformations (including differentiation) and array computations on accelerators (GPUs/TPUs). These attributes make it ideal for deep learning, but JAX is not, in itself, a deep learning library: it provides a structural framework on which libraries can be built, without providing domain-specific tooling.

To make full use of JAX's flexible autodiff and efficiency for deep learning while keeping a syntax familiar to PyTorch users, a solid approach is to use TorchData, TensorFlow Datasets, Grain, or Hugging Face Datasets to load the data, Flax to build neural networks, Optax for optimization, and Orbax for checkpointing.

This introductory course does not require any prior knowledge.
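As an illustration of what "function transformations" means in practice, here is a minimal sketch that uses only core JAX (`jax.grad`, `jax.jit`, `jax.vmap`) on a toy linear model with made-up data. It is not taken from the course material. The names `loss`, `grad_loss` and `per_example_loss` are hypothetical, and the hand-written gradient-descent loop stands in for the optimizer a library such as Optax would normally supply.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    """Mean squared error of a toy linear model x @ w (illustrative only)."""
    return jnp.mean((x @ w - y) ** 2)

# jax.grad turns `loss` into a function returning d(loss)/dw (w.r.t. the first
# argument); jax.jit then compiles that gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

# jax.vmap maps a per-example function over the leading (batch) axis of x and y,
# while w is shared across examples (in_axes=None).
per_example_loss = jax.vmap(lambda xi, yi, w: (xi @ w - yi) ** 2, in_axes=(0, 0, None))

# Made-up data for the sketch.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

print(grad_loss(w, x, y))            # gradient at w = 0 is [-7., -10.]
print(per_example_loss(x, y, w))     # one squared error per example

# Plain gradient descent; in a real project Optax would provide the update rule.
learning_rate = 0.01
for _ in range(100):
    w = w - learning_rate * grad_loss(w, x, y)

print(loss(w, x, y))                 # lower than the initial loss of 2.5
```

The design point is that JAX only transforms functions: the model, loss and update rule are ordinary Python, which is why libraries like Flax and Optax can be layered on top without JAX itself providing neural-network tooling.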
Keywords: GPU, HPC, Machine Learning, AI, Python, Programming
Venue: online