Introduction to Theano

Credits: Forked from summerschool2015 by mila-udem

Slides

Refer to the associated Introduction to Theano slides and use this notebook for hands-on practice of the concepts.

Basic usage

Defining an expression

# Build the symbolic graph for out = sigmoid(dot(x, W) + b).
# Nothing is computed here; these are symbolic placeholders and expressions.
import theano
from theano import tensor as T
x = T.vector('x')  # symbolic 1-D input vector
W = T.matrix('W')  # symbolic 2-D weight matrix
b = T.vector('b')  # symbolic 1-D bias vector
dot = T.dot(x, W)  # symbolic vector-matrix product
out = T.nnet.sigmoid(dot + b)  # elementwise logistic of the affine transform

Graph visualization

# Print a textual representation of each (unoptimized) expression graph.
from theano.printing import debugprint
debugprint(dot)
debugprint(out)

Compiling a Theano function

# Compile the symbolic graphs into callable functions.
f = theano.function(inputs=[x, W], outputs=dot)  # keyword-argument form, single output
g = theano.function([x, W, b], out)              # positional form
h = theano.function([x, W, b], [dot, out])       # multiple outputs in one function
i = theano.function([x, W, b], [dot + b, out])   # outputs may be fresh expressions

Graph visualization

# debugprint on a compiled function shows the optimized graph,
# which can differ from the graph built at definition time.
debugprint(f)
debugprint(g)
from theano.printing import pydotprint
# Render each compiled graph to a PNG file and display it inline.
pydotprint(f, outfile='pydotprint_f.png')
from IPython.display import Image
Image('pydotprint_f.png', width=1000)
pydotprint(g, outfile='pydotprint_g.png')
Image('pydotprint_g.png', width=1000)
pydotprint(h, outfile='pydotprint_h.png')
Image('pydotprint_h.png', width=1000)

Executing a Theano function

# Create concrete numpy inputs and call the compiled functions.
import numpy as np
np.random.seed(42)  # reproducible sample values
W_val = np.random.randn(4, 3)  # 4x3 weights: length-4 input -> length-3 output
x_val = np.random.rand(4)
b_val = np.ones(3)

f(x_val, W_val)         # returns dot(x, W)
g(x_val, W_val, b_val)  # returns sigmoid(dot + b)
h(x_val, W_val, b_val)  # returns [dot, out]
i(x_val, W_val, b_val)  # returns [dot + b, out]

Graph definition and Syntax

Graph structure

# compact=False keeps every intermediate variable node visible in the plot.
pydotprint(f, compact=False, outfile='pydotprint_f_notcompact.png')
Image('pydotprint_f_notcompact.png', width=1000)

Strong typing

Broadcasting tensors

# row (shape (1, N)) and col (shape (M, 1)) types carry broadcastable
# flags, so r + c broadcasts to an (M, N) result.
r = T.row('r')
print(r.broadcastable)  # first dimension is broadcastable
c = T.col('c')
print(c.broadcastable)  # second dimension is broadcastable
f = theano.function([r, c], r + c)  # NOTE: rebinds f from the earlier section
print(f([[1, 2, 3]], [[.1], [.2]]))  # 1x3 + 2x1 -> 2x3

Graph Transformations

Substitution and Cloning

The givens keyword

# 'givens' substitutes an expression for a variable at compile time:
# here the compiled function standardizes its input before the dot product.
x_ = T.vector('x_')
x_n = (x_ - x_.mean()) / x_.std()  # zero mean, unit standard deviation
f_n = theano.function([x_, W], dot, givens={x: x_n})
f_n(x_val, W_val)

Cloning with replacement

# theano.clone rebuilds the graphs with x replaced by its standardized form,
# producing new symbolic outputs without touching the originals.
dot_n, out_n = theano.clone([dot, out], replace={x: (x - x.mean()) / x.std()})
f_n = theano.function([x, W], dot_n)
f_n(x_val, W_val)

Gradient

Using theano.grad

y = T.vector('y')  # symbolic target vector
C = ((out - y) ** 2).sum()  # sum-of-squared-errors cost
dC_dW = theano.grad(C, W)  # symbolic gradient of the cost w.r.t. W
dC_db = theano.grad(C, b)  # symbolic gradient of the cost w.r.t. b
# dC_dW, dC_db = theano.grad(C, [W, b])  # equivalent single call

Using the gradients

# Compile one function returning cost + gradients, and another
# returning cost + already-updated parameter values.
cost_and_grads = theano.function([x, W, b, y], [C, dC_dW, dC_db])
y_val = np.random.uniform(size=3)
print(cost_and_grads(x_val, W_val, b_val, y_val))
upd_W = W - 0.1 * dC_dW  # one gradient-descent step, learning rate 0.1
upd_b = b - 0.1 * dC_db
cost_and_upd = theano.function([x, W, b, y], [C, upd_W, upd_b])
print(cost_and_upd(x_val, W_val, b_val, y_val))
pydotprint(cost_and_upd, outfile='pydotprint_cost_and_upd.png')
Image('pydotprint_cost_and_upd.png', width=1000)

Shared variables

Update values

# Manual parameter update: get gradients, then step the numpy arrays yourself.
C_val, dC_dW_val, dC_db_val = cost_and_grads(x_val, W_val, b_val, y_val)
W_val -= 0.1 * dC_dW_val
b_val -= 0.1 * dC_db_val

# Or take the updated values directly from the second compiled function.
C_val, W_val, b_val = cost_and_upd(x_val, W_val, b_val, y_val)

Using shared variables

# Rebuild the graph with W and b as shared variables: they wrap the
# numpy arrays and keep their state inside Theano between calls.
x = T.vector('x')
y = T.vector('y')
W = theano.shared(W_val)  # shared variable initialized from the numpy array
b = theano.shared(b_val)
dot = T.dot(x, W)
out = T.nnet.sigmoid(dot + b)
f = theano.function([x], dot)  # W is an implicit input
g = theano.function([x], out)  # W and b are implicit inputs
print(f(x_val))
print(g(x_val))

Updating shared variables

C = ((out - y) ** 2).sum()  # same sum-of-squares cost, now on shared params
dC_dW, dC_db = theano.grad(C, [W, b])
upd_W = W - 0.1 * dC_dW  # gradient-descent step expressions
upd_b = b - 0.1 * dC_db

# 'updates' assigns the new values to the shared variables on every call,
# so each call of this function performs one training step as a side effect.
cost_and_perform_updates = theano.function(
    inputs=[x, y],
    outputs=C,
    updates=[(W, upd_W),
             (b, upd_b)])
pydotprint(cost_and_perform_updates, outfile='pydotprint_cost_and_perform_updates.png')
Image('pydotprint_cost_and_perform_updates.png', width=1000)

Advanced Topics

Extending Theano

The easy way: Python

import theano
import numpy
from theano.compile.ops import as_op

def infer_shape_numpy_dot(node, input_shapes):
    """Shape-inference helper for the numpy_dot op.

    For operand shapes (..., K) and (K, N) the matrix product has
    shape (..., N): every dimension of the first operand except the
    last, followed by the last dimension of the second operand.
    """
    a_shape, b_shape = input_shapes
    result_shape = a_shape[:-1] + b_shape[-1:]
    return [result_shape]

@as_op(itypes=[theano.tensor.fmatrix, theano.tensor.fmatrix],
       otypes=[theano.tensor.fmatrix], infer_shape=infer_shape_numpy_dot)
def numpy_dot(a, b):
    """Matrix product of two float32 matrices, computed via numpy.dot.

    as_op wraps this plain Python function as a Theano op with the
    declared input/output types and shape inference.
    """
    # Fixed: the original body was indented with 3 spaces instead of 4.
    return numpy.dot(a, b)