示例代码
#!/usr/bin/env python
# coding:utf-8
import numpy as np
def sigmoid(x, deriv=False):
    """Evaluate the logistic sigmoid, or its derivative from a sigmoid output.

    Args:
        x: A scalar or NumPy array. With ``deriv=False`` this is the raw
            input. With ``deriv=True`` it must ALREADY be a sigmoid output
            (a value ``s = sigmoid(z)``), not the raw input ``z``.
        deriv: When true, return ``x * (1 - x)`` instead of the sigmoid.

    Returns:
        ``1 / (1 + exp(-x))`` when ``deriv`` is false, otherwise
        ``x * (1 - x)``; for array input the result is element-wise.
    """
    if deriv:
        # d/dz sigmoid(z) = s * (1 - s) where s = sigmoid(z); the caller
        # passes s directly, so no exponential is evaluated here.
        return x * (1 - x)
    return 1 / (1 + np.exp(-x))
# Input matrix: five rows of three binary values each.
x = np.array([
    [0, 0, 1],
    [0, 1, 1],
    [1, 0, 1],
    [1, 1, 1],
    [0, 1, 0],
])
# NOTE: the label says "size" but the value printed is the shape tuple.
# A single pre-formatted argument keeps this line valid, with identical
# output, under both the Python 2 print statement and Python 3 print().
print('x.size = %s' % (x.shape,))

# Target column: five rows, one value per row of x.
y = np.array([
    [0],
    [1],
    [1],
    [0],
    [0],
])
print('y.size = %s' % (y.shape,))
# Seed the global NumPy generator so the initial weights are the same on
# every run.
np.random.seed(1)

# np.random.random draws from [0, 1); multiplying by 2 and subtracting 1
# maps the draws to [-1, 1).
w0 = 2 * np.random.random((3, 4)) - 1  # shape (3, 4)
w1 = 2 * np.random.random((4, 1)) - 1  # shape (4, 1)

# Single pre-formatted argument: same text under the Python 2 print
# statement and the Python 3 print() function.
print('w0 = %s' % (w0,))
print('w1 = %s' % (w1,))
# Train the two weight matrices on the whole of x for a fixed number of
# iterations, updating w0 and w1 in place.
l0 = x  # the input layer never changes, so bind it once outside the loop
for j in range(600000):
    # Forward pass: input -> hidden activations -> output activations.
    l1 = sigmoid(np.dot(l0, w0))
    l2 = sigmoid(np.dot(l1, w1))

    # Difference between the targets and the current output.
    l2_error = y - l2
    if j % 10000 == 0:
        # Progress report: mean absolute error, once every 10000 iterations.
        print('Error: %s' % (np.mean(np.abs(l2_error)),))

    # Backward pass. l2 and l1 are already sigmoid outputs, which is what
    # sigmoid(..., deriv=True) takes as its argument.
    l2_delta = l2_error * sigmoid(l2, deriv=True)
    # Uses w1 as it was during the forward pass, so this must run before
    # w1 is updated below.
    l1_error = l2_delta.dot(w1.T)
    l1_delta = l1_error * sigmoid(l1, deriv=True)

    # In-place weight updates; the deltas are added without any scaling
    # factor.
    w1 += l1.T.dot(l2_delta)
    w0 += l0.T.dot(l1_delta)

# Final weights after training.
print('w0 = %s' % (w0,))
print('w1 = %s' % (w1,))
执行结果
$ python ml.py
x.size = (5, 3)
y.size = (5, 1)
w0 = [[-0.16595599 0.44064899 -0.99977125 -0.39533485]
[-0.70648822 -0.81532281 -0.62747958 -0.30887855]
[-0.20646505 0.07763347 -0.16161097 0.370439 ]]
w1 = [[-0.5910955 ]
[ 0.75623487]
[-0.94522481]
[ 0.34093502]]
Error: 0.489742959567
Error: 0.00811002819041
Error: 0.00554076512683
Error: 0.00445345411335
Error: 0.00381961242594
Error: 0.00339315597104
Error: 0.00308153875904
Error: 0.00284122583509
Error: 0.0026487182317
Error: 0.00249008114534
Error: 0.00235646378531
Error: 0.00224194151724
Error: 0.00214238045858
Error: 0.00205479752848
Error: 0.00197697929925
Error: 0.00190724447155
Error: 0.0018442901377
Error: 0.00178708902161
Error: 0.00173481886308
Error: 0.00168681270788
Error: 0.00164252316607
Error: 0.00160149622746
Error: 0.00156335175559
Error: 0.00152776873721
Error: 0.00149447397598
Error: 0.00146323331919
Error: 0.00143384477324
Error: 0.00140613304587
Error: 0.00137994517845
Error: 0.00135514702019
Error: 0.00133162035914
Error: 0.00130926057011
Error: 0.00128797467303
Error: 0.00126767971977
Error: 0.00124830144571
Error: 0.00122977313628
Error: 0.00121203466902
Error: 0.00119503170009
Error: 0.00117871497008
Error: 0.00116303970897
Error: 0.00114796512406
Error: 0.00113345395743
Error: 0.00111947210206
Error: 0.00110598826773
Error: 0.00109297368914
Error: 0.00108040187012
Error: 0.00106824835887
Error: 0.00105649054972
Error: 0.0010451075079
Error: 0.00103407981424
Error: 0.00102338942709
Error: 0.00101301955931
Error: 0.00100295456844
Error: 0.000993179858371
Error: 0.000983681791144
Error: 0.000974447607648
Error: 0.00096546535621
Error: 0.000956723828133
Error: 0.000948212499411
Error: 0.000939921477942
w0 = [[ 5.43675976 3.49120763 -6.66258203 -4.09547167]
[ 4.11269646 -4.1215545 -7.6969078 -2.48695267]
[-7.45206992 -2.39024462 3.25187366 5.24884334]]
w1 = [[ -9.72915103]
[ 4.17180398]
[-14.97548701]
[ 7.44465433]]
$