While Julia is great, there is still a lot of useful, existing differentiable Python code in PyTorch, JAX, etc. Given that PyCall.jl already makes calling Python so seamless, one might wonder what it takes to differentiate through those calls to Python functions. PyCallChainRules.jl aims for that ideal. It leverages DLPack.jl to pass CPU or GPU arrays between Julia and Python without copying.
Automatic differentiation interfaces are rapidly converging, with functorch and JAX on the Python side and, in Julia, more explicit interfaces for dealing with gradients in Functors.jl and Optimisers.jl, as well as more explicit machine learning layers in Lux.jl. While it is relatively easy to implement new functionality in Julia, Python remains the standard interface layer for most state-of-the-art functionality, especially GPU kernels. Even if it is not the most performant option, there is value in being able to call existing differentiable Python functions while a developer gradually implements equivalent functionality in Julia.