
Neural Network Training: A Step-by-Step Computation Graph Example

Tracing forward propagation, branching gradients, and parameter updates by hand.

Tags: Backpropagation · Computation Graph · Gradient Descent · Math Intuition

Many tutorials introduce backpropagation by diving straight into matrix calculus. But before memorizing the formulas for deep linear layers and cross-entropy, it helps to walk through a small, concrete computation graph by hand. Tracing a single forward and backward pass reveals the mechanics of branching gradients and regularization without losing them in the notation.

In this example, we will track four parameters a, b, c, d through a small artificial network across two parameter update steps.

The computation graph and the constants

To keep the arithmetic clean, we use a simple scalar input and a set of elementary functions. The given constants for our training steps are:

  • Input: x = 2
  • Target label: y = 1
  • Regularization coefficient: λ = 0.1
  • Learning rate: η = 0.1

The forward pass is defined by the following sequence of operations:

u = ax + b (Linear transform)
h = u² (Nonlinear activation)
v = ch + u + d (Second linear branch)
p = 1/(1 + v) (Output function)
L = (p − y)² (Data loss)
R = (λ/2)(a² + c²) (L2 Regularization)
J = L + R (Total objective)
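Transcribed directly into plain Python (a sketch, not code from the original post; the variable names mirror the symbols above), the forward pass is:

```python
# Forward pass of the example graph; names mirror the equations above.
def forward(a, b, c, d, x, y, lam=0.1):
    u = a * x + b                       # linear transform
    h = u ** 2                          # nonlinear activation
    v = c * h + u + d                   # u enters twice: via h and directly
    p = 1.0 / (1.0 + v)                 # output function
    L = (p - y) ** 2                    # data loss
    R = 0.5 * lam * (a ** 2 + c ** 2)   # L2 regularization
    return u, h, v, p, L + R            # last value is J = L + R
```

With the initial parameters used later in the article, `forward(1, 0, 1, 0, 2, 1)` reproduces u = 2, h = 4, v = 6, p ≈ 0.1429, and J ≈ 0.8347.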

Notice the equation v = ch + u + d. The variable u flows into v through two paths: directly, and indirectly via h. This is a classic "branch and merge" in a computation graph. During backpropagation, the gradients along these two paths must be added together.

Backpropagation: Mapping the paths

Before calculating derivatives, it is useful to map out the routes from the final objective J back to the parameters.

For parameter a, there are two sources of gradient:

  1. Data path: J → L → p → v → u → a
  2. Regularization path: J → R → a

For parameter c, we also have two sources:

  1. Data path: J → L → p → v → h → c
  2. Regularization path: J → R → c

The most critical detail is calculating how v depends on u. Because of the branching paths, the chain rule dictates that we sum the direct and indirect sensitivities:

∂v/∂u = ∂/∂u (ch + u + d) = c·(∂h/∂u) + 1 = 2cu + 1
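The summed sensitivity is easy to sanity-check numerically. The sketch below (my own check, not from the post) perturbs u with c and d held fixed and compares a central difference against 2cu + 1:

```python
# Central-difference check of dv/du against the analytic 2*c*u + 1.
def v_of_u(u, c=1.0, d=0.0):
    return c * u ** 2 + u + d   # both paths: u -> h -> v and u -> v

u, c, eps = 2.0, 1.0, 1e-6
numeric = (v_of_u(u + eps, c) - v_of_u(u - eps, c)) / (2 * eps)
analytic = 2 * c * u + 1        # summed direct + indirect sensitivities
```

At u = 2, c = 1 both values come out to 5: the direct path contributes 1 and the indirect path contributes 2cu = 4.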

Deriving the gradient formulas

To avoid writing the same sequence repeatedly, we can define an intermediate variable representing the sensitivity of the loss to the pre-output v:

δ = ∂L/∂v = 2(p − y) · (−1/(1 + v)²)

Using δ, the gradients for each parameter are:

  • ∂J/∂d = δ
  • ∂J/∂c = δ·h + λc
  • ∂J/∂b = δ·(2cu + 1)
  • ∂J/∂a = δ·(2cu + 1)·x + λa

Notice how the gradient for a combines the data gradient (which explicitly depends on the input x) and the regularization gradient.
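Collected into one backward function (a sketch; `lam` stands for λ and the helper name `gradients` is mine):

```python
# Backward pass: all four gradients from the shared sensitivity delta.
def gradients(a, b, c, d, x, y, lam=0.1):
    u = a * x + b
    h = u ** 2
    v = c * h + u + d
    p = 1.0 / (1.0 + v)
    delta = 2 * (p - y) * (-1.0 / (1.0 + v) ** 2)   # dL/dv
    g_d = delta
    g_c = delta * h + lam * c
    g_b = delta * (2 * c * u + 1)                   # branch-and-merge factor
    g_a = delta * (2 * c * u + 1) * x + lam * a     # data path + weight decay
    return g_a, g_b, g_c, g_d
```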

Iteration 1: Setting the foundation

We start with initial parameters a = 1, b = 0, c = 1, d = 0, and input x = 2.

1. Forward Pass

  • u = 1(2) + 0 = 2
  • h = 2² = 4
  • v = 1(4) + 2 + 0 = 6
  • p = 1/(1 + 6) ≈ 0.142857
  • J = L + R ≈ 0.7347 + 0.1 = 0.8347
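These five numbers can be reproduced line by line (a minimal check script; the variable names are mine):

```python
# Iteration-1 forward pass, spelled out step by step.
a, b, c, d, x, y, lam = 1.0, 0.0, 1.0, 0.0, 2.0, 1.0, 0.1
u = a * x + b                                       # 2.0
h = u ** 2                                          # 4.0
v = c * h + u + d                                   # 6.0
p = 1.0 / (1.0 + v)                                 # 1/7 ~ 0.142857
J = (p - y) ** 2 + 0.5 * lam * (a ** 2 + c ** 2)    # ~0.8347
```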

2. Backward Pass

Computing δ gives δ ≈ 0.034985. We then calculate the parameter gradients:

  • ∂J/∂d ≈ 0.034985
  • ∂J/∂c ≈ 0.034985 × 4 + 0.1(1) = 0.239942
  • ∂J/∂b ≈ 0.034985 × (2(1)(2) + 1) = 0.174927
  • ∂J/∂a ≈ 0.174927 × 2 + 0.1(1) = 0.449854
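The gradients follow mechanically from δ (again a sketch, reusing the iteration-1 forward values):

```python
# Iteration-1 backward pass: delta first, then each parameter gradient.
p, v, h, u = 1.0 / 7.0, 6.0, 4.0, 2.0
y, lam, a, c = 1.0, 0.1, 1.0, 1.0
delta = 2 * (p - y) * (-1.0 / (1.0 + v) ** 2)   # 12/343 ~ 0.034985
g_d = delta
g_c = delta * h + lam * c                        # ~0.239942
g_b = delta * (2 * c * u + 1)                    # ~0.174927
g_a = g_b * 2.0 + lam * a                        # ~0.449854  (x = 2)
```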

3. Parameter Update

Using gradient descent θ ← θ − η∇J:

  • a₁ ≈ 1 − 0.1(0.449854) = 0.955015
  • b₁ ≈ 0 − 0.1(0.174927) = −0.017493
  • c₁ ≈ 1 − 0.1(0.239942) = 0.976006
  • d₁ ≈ 0 − 0.1(0.034985) = −0.003499
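In code, the update is one subtraction per parameter (a sketch using the rounded gradients above):

```python
# Gradient-descent update theta <- theta - eta * grad, with eta = 0.1.
eta = 0.1
params = {"a": 1.0, "b": 0.0, "c": 1.0, "d": 0.0}
grads = {"a": 0.449854, "b": 0.174927, "c": 0.239942, "d": 0.034985}
updated = {k: params[k] - eta * grads[k] for k in params}
# e.g. updated["a"] = 1 - 0.1 * 0.449854 ~ 0.955015
```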

The first iteration behaves exactly as expected. The loss creates a signal, it flows backwards, and the network parameters adjust to make p slightly closer to y = 1.

Iteration 2: The zero-input scenario

Let us perform a second update, but this time assume the training loop feeds in a new data point where x = 0.

1. Forward Pass

  • u = 0.955(0) − 0.017493 = −0.017493
  • h ≈ 0.000306
  • v ≈ −0.02069
  • p ≈ 1.0211
  • J ≈ 0.0937

2. Backward Pass

Because p has overshot the target (p ≈ 1.0211 > y = 1), the term p − y changes sign and our δ reverses sign: δ ≈ −0.044064.

  • ∂J/∂d ≈ −0.044064
  • ∂J/∂c ≈ −0.044064 × 0.000306 + 0.1(0.976006) ≈ 0.097587
  • ∂J/∂b ≈ −0.044064 × (2(0.976006)(−0.017493) + 1) ≈ −0.042560
  • ∂J/∂a = δ·(2cu + 1)·x + λa = 0 + 0.1(0.955) ≈ 0.0955

Look closely at the gradient for a. Because x = 0, the entire error signal flowing back from the loss function is multiplied by zero. The data provides no information about how to update a. However, the gradient is not zero: the regularization path λa is still active, gently pulling the parameter back toward the origin.
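Splitting ∂J/∂a into its two terms makes this concrete (a sketch; `data_term` and `reg_term` are my names for the two paths):

```python
# With x = 0 the data-path term of dJ/da vanishes exactly,
# while the weight-decay term stays active.
a, b, c, d = 0.955015, -0.017493, 0.976006, -0.003499
x, y, lam = 0.0, 1.0, 0.1
u = a * x + b
v = c * u ** 2 + u + d
p = 1.0 / (1.0 + v)
delta = 2 * (p - y) * (-1.0 / (1.0 + v) ** 2)
data_term = delta * (2 * c * u + 1) * x   # 0.0: x zeroes the error signal
reg_term = lam * a                        # ~0.0955: still pulls a toward 0
```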

3. Parameter Update

  • a₂ ≈ 0.955015 − 0.1(0.0955) = 0.945464
  • b₂ ≈ −0.017493 − 0.1(−0.042560) = −0.013237
  • c₂ ≈ 0.976006 − 0.1(0.097587) = 0.966247
  • d₂ ≈ −0.003499 − 0.1(−0.044064) = 0.000908

The main takeaway

Stepping through the computation graph manually grounds two extremely important concepts in deep learning:

  1. Branch and merge means sum the gradients. Because u flowed into v through two separate operational paths, finding the gradient required summing the partial derivatives across both branches. This is the structural heart of backpropagation.
  2. Regularization acts as an independent force. The weight decay path is disconnected from the data. Even when a training sample completely zeroes out the data gradient for a parameter, the regularization gradient continues to pull the weight toward zero.
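As a closing sketch, a single step function (my own consolidation of the formulas above, not code from the post) reproduces both iterations end to end:

```python
# Two SGD steps on the scalar graph; reproduces the walkthrough numbers.
def step(params, x, y, lam=0.1, eta=0.1):
    a, b, c, d = params
    u = a * x + b
    h = u ** 2
    v = c * h + u + d
    p = 1.0 / (1.0 + v)
    delta = 2 * (p - y) * (-1.0 / (1.0 + v) ** 2)
    grads = (delta * (2 * c * u + 1) * x + lam * a,   # dJ/da
             delta * (2 * c * u + 1),                 # dJ/db
             delta * h + lam * c,                     # dJ/dc
             delta)                                   # dJ/dd
    return tuple(t - eta * g for t, g in zip(params, grads))

params = (1.0, 0.0, 1.0, 0.0)
params = step(params, x=2.0, y=1.0)   # iteration 1: x = 2
params = step(params, x=0.0, y=1.0)   # iteration 2: x = 0
# params is now approximately (0.945464, -0.013237, 0.966247, 0.000908)
```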