示例代码 (Sample code)

#!/usr/bin/env python
# coding:utf-8

import numpy as np

def sigmoid(x, deriv=False):
    """Logistic sigmoid, or its derivative expressed via the sigmoid output.

    Args:
        x: scalar or NumPy array. With ``deriv`` false this is the
            pre-activation; with ``deriv`` true it must already be a
            sigmoid *output* ``s``.
        deriv: when true, return ``s * (1 - s)`` instead of the sigmoid.

    Returns:
        ``1 / (1 + exp(-x))`` elementwise, or ``x * (1 - x)`` when ``deriv``
        is true.
    """
    if not deriv:
        return 1 / (1 + np.exp(-x))
    # Derivative of the sigmoid written in terms of its own output s,
    # which is why callers pass the activated value here.
    s = x
    return s * (1 - s)

# Training inputs: 5 samples x 3 binary features. For the first four rows
# the label (y below) equals the XOR of the first two columns; the third
# column is presumably a constant bias input, since the network below has
# no separate bias terms.
x = np.array([
    [0, 0, 1],
    [0, 1, 1],
    [1, 0, 1],
    [1, 1, 1],
    [0, 1, 0],
])

# NOTE: the label says "size" but the value printed is the array *shape*;
# the text is kept verbatim so output matches the recorded run.
# Single-argument print(...) with %s formatting prints identically on
# Python 2 and Python 3 (the original print statements are Python 2 only).
print('x.size = %s' % (x.shape,))

# Target outputs, one column per sample.
y = np.array([
    [0],
    [1],
    [1],
    [0],
    [0],
])

print('y.size = %s' % (y.shape,))

# Fixed seed so the initial weights (and hence the whole run) are reproducible.
np.random.seed(1)

# Initial weights drawn uniformly from [-1, 1). Draw order matters for
# reproducibility: w0 first, then w1.
w0 = 2 * np.random.random((3, 4)) - 1  # input (3) -> hidden (4)
w1 = 2 * np.random.random((4, 1)) - 1  # hidden (4) -> output (1)

print('w0 = %s' % (w0,))
print('w1 = %s' % (w1,))


# Full-batch backpropagation on a 3-4-1 sigmoid network with an implicit
# learning rate of 1 and no bias terms.
for j in range(600000):
    # Forward pass.
    l0 = x
    l1 = sigmoid(np.dot(l0, w0))
    l2 = sigmoid(np.dot(l1, w1))

    l2_error = y - l2
    if j % 10000 == 0:
        print('Error: %s' % np.mean(np.abs(l2_error)))

    # Backward pass. sigmoid(..., deriv=True) takes the *activated* values.
    l2_delta = l2_error * sigmoid(l2, deriv=True)

    # Uses w1 *before* it is updated below; keep this ordering.
    l1_error = l2_delta.dot(w1.T)
    l1_delta = l1_error * sigmoid(l1, deriv=True)

    # In-place weight updates.
    w1 += l1.T.dot(l2_delta)
    w0 += l0.T.dot(l1_delta)

print('w0 = %s' % (w0,))
print('w1 = %s' % (w1,))

执行结果 (Execution output)

$ python ml.py 
x.size = (5, 3)
y.size = (5, 1)
w0 = [[-0.16595599  0.44064899 -0.99977125 -0.39533485]
 [-0.70648822 -0.81532281 -0.62747958 -0.30887855]
 [-0.20646505  0.07763347 -0.16161097  0.370439  ]]
w1 = [[-0.5910955 ]
 [ 0.75623487]
 [-0.94522481]
 [ 0.34093502]]
Error: 0.489742959567
Error: 0.00811002819041
Error: 0.00554076512683
Error: 0.00445345411335
Error: 0.00381961242594
Error: 0.00339315597104
Error: 0.00308153875904
Error: 0.00284122583509
Error: 0.0026487182317
Error: 0.00249008114534
Error: 0.00235646378531
Error: 0.00224194151724
Error: 0.00214238045858
Error: 0.00205479752848
Error: 0.00197697929925
Error: 0.00190724447155
Error: 0.0018442901377
Error: 0.00178708902161
Error: 0.00173481886308
Error: 0.00168681270788
Error: 0.00164252316607
Error: 0.00160149622746
Error: 0.00156335175559
Error: 0.00152776873721
Error: 0.00149447397598
Error: 0.00146323331919
Error: 0.00143384477324
Error: 0.00140613304587
Error: 0.00137994517845
Error: 0.00135514702019
Error: 0.00133162035914
Error: 0.00130926057011
Error: 0.00128797467303
Error: 0.00126767971977
Error: 0.00124830144571
Error: 0.00122977313628
Error: 0.00121203466902
Error: 0.00119503170009
Error: 0.00117871497008
Error: 0.00116303970897
Error: 0.00114796512406
Error: 0.00113345395743
Error: 0.00111947210206
Error: 0.00110598826773
Error: 0.00109297368914
Error: 0.00108040187012
Error: 0.00106824835887
Error: 0.00105649054972
Error: 0.0010451075079
Error: 0.00103407981424
Error: 0.00102338942709
Error: 0.00101301955931
Error: 0.00100295456844
Error: 0.000993179858371
Error: 0.000983681791144
Error: 0.000974447607648
Error: 0.00096546535621
Error: 0.000956723828133
Error: 0.000948212499411
Error: 0.000939921477942
w0 = [[ 5.43675976  3.49120763 -6.66258203 -4.09547167]
 [ 4.11269646 -4.1215545  -7.6969078  -2.48695267]
 [-7.45206992 -2.39024462  3.25187366  5.24884334]]
w1 = [[ -9.72915103]
 [  4.17180398]
 [-14.97548701]
 [  7.44465433]]
$

results matching ""

    No results matching ""