In [85]:
import numpy as np
import math
import theano
import theano.tensor as T

# Symbolic inputs for unit 1, grouped by rank.
# NOTE(review): the i/f/o names suggest input/forget/output gates of an
# LSTM-style cell, but nothing here confirms that -- verify.

# Vector-valued symbols.
x_1 = T.dvector('x_1')
h_1_prev = T.dvector('h_1_prev')
V_a1 = T.dvector('V_a1')
U_a1 = T.dvector('U_a1')

# Scalar-valued symbols.
i_1 = T.dscalar('i_1')
f_1 = T.dscalar('f_1')
o_1 = T.dscalar('o_1')
b_1 = T.dscalar('b_1')
s_1_prev = T.dscalar('s_1_prev')

# a_1: tanh of the two weighted sums plus the bias b_1.
a_1 = T.tanh(
    T.sum(V_a1 * x_1)
    + T.sum(U_a1 * h_1_prev)
    + b_1
)
assert a_1.ndim == 0  # scalar is a 0-dim tensor

# s_1 mixes a_1 with the previous s value; h_1 scales tanh(s_1) by o_1.
s_1 = i_1 * a_1 + f_1 * s_1_prev
h_1 = o_1 * T.tanh(s_1)

Automatic differentiation of $\frac{\partial s^1_1}{\partial a^1_1}$

In [91]:
# ds_1/da_1. Since s_1 = i_1 * a_1 + f_1 * s_1_prev, this should equal i_1.
gs_1 = T.grad(s_1, a_1)

# All of these symbols are passed as inputs; per the original author's note,
# Theano returns an error when not enough of them are provided.
f_gs_1 = theano.function([i_1, f_1, s_1_prev, V_a1, x_1, U_a1, h_1_prev, b_1], gs_1)

# Evaluate at i_1 = 2 and then i_1 = 1, with every other input set to 1.
for i_1_value in (2, 1):
    print(f_gs_1(i_1_value, 1, 1, [1, 1], [1, 1], [1, 1], [1, 1], 1))  # ds_1/da_1
2.0
1.0

Automatic differentiation of $\frac{\partial h^1_1}{\partial s^1_1}$

In [87]:
# dh_1/ds_1, with the intermediate expression s_1 supplied directly as an
# input of the compiled function.
# NOTE: gh_1 and f_gh_1 are rebound by a later cell (dh_1/dV_a1).
gh_1 = T.grad(cost=h_1, wrt=s_1)
f_gh_1 = theano.function(inputs=[o_1, s_1], outputs=gh_1)

# Evaluation point: o_1 = 1, s_1 = 1.9998.
print(f_gh_1(1, 1.9998))  # dh_1/ds_1
0.07067807364430378

Automatic differentiation of $\frac{\partial a^1_1}{\partial V_{a1}}$

In [88]:
# da_1/dV_a1: gradient of the scalar a_1 with respect to the vector V_a1.
ga_1 = T.grad(cost=a_1, wrt=V_a1)
f_ga_1 = theano.function(
    inputs=[V_a1, x_1, U_a1, h_1_prev, b_1],
    outputs=ga_1,
)

# Evaluate with every vector set to [1, 1] and b_1 = 1.
unit_vec = [1, 1]
print(f_ga_1(unit_vec, unit_vec, unit_vec, unit_vec, 1))  # da_1/dV_a1
[ 0.00018158  0.00018158]

Automatic differentiation of $\frac{\partial h^1_1}{\partial V_{a1}}$

In [89]:
# dh_1/dV_a1: gradient of h_1 with respect to the vector V_a1.
# NOTE: this rebinds gh_1 and f_gh_1, which an earlier cell used for dh_1/ds_1.
gh_1 = T.grad(cost=h_1, wrt=V_a1)
f_gh_1 = theano.function(
    inputs=[V_a1, x_1, U_a1, h_1_prev, b_1, f_1, i_1, o_1, s_1_prev],
    outputs=gh_1,
)

# Evaluate with every vector set to [1, 1] and every scalar set to 1.
unit_vec = [1, 1]
print(f_gh_1(unit_vec, unit_vec, unit_vec, unit_vec, 1, 1, 1, 1, 1))  # dh_1/dV_a1
[  1.28312511e-05   1.28312511e-05]

Checking that it matches the hand-computed result

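By the chain rule, $$ \frac{\partial h^1_1}{\partial V_{a1}} = \frac{\partial h^1_1}{\partial s^1_1} \frac{\partial s^1_1}{\partial a^1_1} \frac{\partial a^1_1}{\partial V_{a1}} $$ so the product of the three factors computed above should reproduce the automatic gradient, up to the rounding of the hand-typed values.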

In [93]:
# Hand-computed chain rule, using values copied from the outputs printed above.
dh_ds = 0.0707  # dh_1/ds_1, rounded from the printed 0.0706780...
ds_da = 1  # ds_1/da_1 at i_1 = 1
da_dV = np.array([0.00018158, 0.00018158])  # da_1/dV_a1
# Rounding dh_ds to 0.0707 is why the product differs slightly from the
# automatic gradient printed earlier.
dh_ds * ds_da * da_dV
Out[93]:
array([  1.28377060e-05,   1.28377060e-05])