backprop

Heterogeneous, type-safe automatic backpropagation in Haskell

https://github.com/mstksg/backprop

Version on this page: 0.0.3.0
LTS Haskell 22.13: 0.2.6.5
Stackage Nightly 2024-03-14: 0.2.6.5
Latest on Hackage: 0.2.6.5

BSD-3-Clause licensed by Justin Le
Maintained by [email protected]
This version can be pinned in stack with: backprop-0.0.3.0@sha256:54e62773bcc423aa829ca19d8fd2d1d852467d13496f7b955fdbbe91d5d0408f,3902

Literate Haskell Tutorial/Demo on MNIST data set (and PDF rendering)

Automatic heterogeneous back-propagation that can be used either implicitly (in the style of the ad library) or explicitly, using graphs built in monadic style. It implements reverse-mode automatic differentiation, and differs from ad by offering full heterogeneity: each intermediate step and the resulting value can have different types. It is intended mostly for use with tensor manipulation libraries, to implement automatic back-propagation for gradient descent and other optimization techniques.

The library is up on Hackage (with 100% documentation coverage), but more up-to-date documentation is rendered on GitHub Pages!

At the moment this project is in pre-alpha, and is published on Hackage as a call for comments and thoughts. Performance has not been a priority so far (the highest priority has been the API and usability), but it will be from now on. See the Todo section below for more information on what’s missing and how you can contribute!

MNIST Digit Classifier Example

A tutorial and example of training on the MNIST data set is available here as a literate Haskell file, or rendered here as a PDF. Read this first!

The literate Haskell file is a standalone file that you can compile (preferably with -O2) on its own with stack or some other dependency manager. It can also be compiled with the build script in the project directory (if stack and the appropriate dependencies are installed), using

$ ./Build.hs exe
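
Alternatively, compiling the file directly with stack might look something like the following (the file name here is only illustrative; substitute the actual name of the literate Haskell file in the repository):

$ stack ghc -- -O2 backprop-mnist.lhs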

Brief example

The quick example below shows a neural network with one hidden layer, parameterized by two weight matrices and two bias vectors, computing its squared error with respect to a target targ. The vector/matrix types are from the hmatrix package.

logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (-x))

-- matrix-vector multiplication Op (implementation omitted)
matVec
    :: (KnownNat m, KnownNat n)
    => Op '[ L m n, R n ] (R m)

neuralNetImplicit
      :: (KnownNat m, KnownNat n, KnownNat o)
      => R m
      -> BPOpI s '[ L n m, R n, L o n, R o ] (R o)
neuralNetImplicit inp = \(w1 :< b1 :< w2 :< b2 :< Ø) ->
    let z = logistic (liftB2 matVec w1 x + b1)
    in  logistic (liftB2 matVec w2 z + b2)
  where
    x = constRef inp

neuralNetExplicit
      :: (KnownNat m, KnownNat n, KnownNat o)
      => R m
      -> BPOp s '[ L n m, R n, L o n, R o ] (R o)
neuralNetExplicit inp = withInps $ \(w1 :< b1 :< w2 :< b2 :< Ø) -> do
    y1  <- matVec ~$ (w1 :< x1 :< Ø)
    let x2 = logistic (y1 + b1)
    y2  <- matVec ~$ (w2 :< x2 :< Ø)
    return $ logistic (y2 + b2)
  where
    x1 = constVar inp

Now neuralNetExplicit and neuralNetImplicit can be “run” with an input vector and the parameters (an L n m, an R n, an L o n, and an R o) to calculate the output of the neural net.

runNet
    :: (KnownNat m, KnownNat n, KnownNat o)
    => R m
    -> Tuple '[ L n m, R n, L o n, R o ]
    -> R o
runNet inp = evalBPOp (neuralNetExplicit inp)
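
For instance, a throwaway invocation might look something like the sketch below. The sizes and constant values are made up purely for illustration, and the sketch assumes the same extensions and imports as the snippets above, the I wrapper and :< / Ø constructors that backprop’s Tuple type is built from (via the type-combinators package), and konst from hmatrix’s static interface:

-- A hypothetical smoke test: every parameter initialized to a constant,
-- run on a constant input (input size 10, hidden size 20, output size 5).
smokeTest :: R 5
smokeTest = runNet (konst 1 :: R 10)
                   (  I (konst 0.1 :: L 20 10)
                   :< I (konst 0   :: R 20)
                   :< I (konst 0.1 :: L 5 20)
                   :< I (konst 0   :: R 5)
                   :< Ø
                   )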

But in defining neuralNetExplicit, we also generated a graph that backprop can use to do back-propagation!

-- dot product between two vectors, as an Op (implementation omitted)
dot :: KnownNat n
    => Op '[ R n  , R n ] Double

netGrad
    :: forall m n o. (KnownNat m, KnownNat n, KnownNat o)
    => R m
    -> R o
    -> Tuple '[ L n m, R n, L o n, R o ]
    -> Tuple '[ L n m, R n, L o n, R o ]
netGrad inp targ params = gradBPOp opError params
  where
    -- calculate squared error, in *explicit* style
    opError :: BPOp s '[ L n m, R n, L o n, R o ] Double
    opError = do
        res <- neuralNetExplicit inp
        err <- bindRef (res - t)
        dot ~$ (err :< err :< Ø)
      where
        t = constRef targ

The result is the gradient of the input tuple’s components, with respect to the Double result of opError (the squared error). We can then use this gradient to do gradient descent.
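
As a concrete sketch of such an update step (not something the library provides), one could pattern match on the parameter and gradient tuples directly. This again assumes the I wrapper and :< / Ø constructors underlying Tuple, together with konst and the element-wise Num instances of hmatrix’s static vector and matrix types:

-- One step of gradient descent: subtract the scaled gradient from each
-- parameter.  The learning rate and the shape of this helper are
-- illustrative only.
stepNet
    :: (KnownNat m, KnownNat n, KnownNat o)
    => Double                             -- learning rate
    -> R m                                -- input
    -> R o                                -- target
    -> Tuple '[ L n m, R n, L o n, R o ]  -- current parameters
    -> Tuple '[ L n m, R n, L o n, R o ]  -- updated parameters
stepNet rate inp targ params@(I w1 :< I b1 :< I w2 :< I b2 :< Ø) =
    case netGrad inp targ params of
      I gW1 :< I gB1 :< I gW2 :< I gB2 :< Ø ->
           I (w1 - konst rate * gW1)
        :< I (b1 - konst rate * gB1)
        :< I (w2 - konst rate * gW2)
        :< I (b2 - konst rate * gB2)
        :< Ø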

For a more fleshed-out example, see the MNIST tutorial (also rendered as a PDF).

Benchmarks

The current version isn’t optimized, but here are some basic benchmarks comparing the library’s automatic differentiation to “manual” differentiation by hand, using the MNIST tutorial as an example:

(benchmark results chart)

Calculating the gradient using backprop and calculating it by hand (via manual symbolic differentiation) are within an order of magnitude of each other, time-wise: using the backprop library takes about 6.5x as long in this case.

However, a full update step (calculating the gradient and updating the neural net) adds a lot of constant costs, so for a full training step the backprop library takes only 2.7x as long as manual symbolic differentiation.

This means using this library only slows down your program by a factor of about 2.5x, compared to using only hmatrix.

It’s still definitely not ideal that more than half of the computation time is overhead from the library, but this is just where we stand at the moment. Optimization is just now starting!

Note that at the moment, simply running the network is only slightly slower when using backprop.

Todo

  1. Profiling, to gauge where the overhead comes from (compared to “manual” back-propagation) and how to bring it down.

  2. Some simple performance and API tweaks that are probably possible now and would clearly benefit (if you want to contribute):

    a. Providing optimized Num/Fractional/Floating instances for BVal by supplying known gradients directly instead of relying on ad. (Now finished, as of b3898ae.)

    b. Switching from ST s to IO, and using unsafePerformIO to automatically bind BVals (like ad does) when using liftB. This might remove some overhead during graph building and, from an API standpoint, remove the need for explicit binding.

    c. Switching from STRefs/IORefs to Array. (It is unclear whether this would help.)

  3. Benchmarking against competing back-propagation libraries like ad, and against auto-differentiating tensor libraries like grenade.

  4. Exploring opportunities for parallelization. There are some naive ways of directly parallelizing right now, but the potential overhead should be investigated.

  5. Some open questions:

    a. Is it possible to offer pattern matching on sum types/with different constructors for implicit-graph backprop? It’s already possible for explicit-graph versions, with choicesVar, but not yet with the implicit-graph interface. This could be similar to an “Applicative vs. Monad” issue, where only pre-determined, fixed computation paths are possible when using Applicative, but I’m not sure. Still, it would be nice, because if this were possible, we could potentially do away with explicit-graph mode completely.

    b. Though we already have safe sum-type support in explicit-graph mode, we can’t yet support GADTs safely. It’d be nice to see whether this is possible, because a lot of dependently typed neural network code is made much simpler with GADTs.

    As of v0.0.3.0, we have a way of dealing with GADTs in explicit-graph mode (using withGADT), but it is unsafe and requires some ugly manual plumbing by the user that could potentially be confusing. It would still be nice to have a way that is safe, doesn’t require the manual plumbing, and isn’t as easy to mess up.

Changelog

Version 0.0.3.0

https://github.com/mstksg/backprop/releases/tag/v0.0.3.0

  • Removed the samples as registered executables in the cabal file, to reduce dependencies to a bare minimum. For convenience, the build script now also compiles the samples into the local directory if stack is installed.

  • Added an experimental (unsafe) combinator, withGADT, for working with GADTs with existential types, to the Numeric.Backprop module.

  • Fixed broken links in Changelog.

Version 0.0.2.0

https://github.com/mstksg/backprop/releases/tag/v0.0.2.0

  • Added optimized numeric Ops, and rewrote the Num/Fractional/Floating instances in terms of them.

  • Removed all traces of Summer/Unity from the library, eliminating a whole swath of “explicit-Summer”/“explicit-Unity” versions of functions. As a consequence, the library now only works with Num instances. The API, however, is now much simpler.

  • Benchmark suite added for MNIST example.

Version 0.0.1.0

https://github.com/mstksg/backprop/releases/tag/v0.0.1.0

  • Initial pre-release, as a request for comments. The API is in a usable form and everything is fully documented, but there are definitely some things left to be done. (See README.md.)