4.5 手写神经网络
解释什么是神经网络,并使用 Python 从头实现!
创建日期: 2024-03-18
你可能会惊讶,神经网络 (Neural Network) 并不复杂。神经网络这个术语无处不在,但实际上它比人们想象中要简单得多。
本节内容完全针对初学者,假设之前没有机器学习方面的经验。我们会理解神经网络是如何工作的,并使用 Python 从头实现一个。让我们开始吧!
4.5.1 神经元
首先需要讨论 神经元 (Neuron) , 它是神经网络的基本单位。神经元接收输入,对其进行一些数学运算,然后产生一个输出。以下是一个神经元有两个输入的样子:
在这个神经元中有三种计算:
-
1. 每个输入乘以权重(红色表示):
\(x_1 \to x_1 \times w_1, \quad x_2 \to x_2 \times x_2\)
-
2. 乘以权重后的输入和偏置(绿色表示)相加:
\((x_1 \times w_1) + (x_2 \times w_2) + b\)
-
3. 相加得到的和通过激活函数(黄色表示):
\( y = f((x_1 \times w_1) + (x_2 \times w_2) + b)\)
激活函数的作用就是将一个无界的输入转换为一个具有良好、可预测形式的输出。常用的激活函数有 sigmoid 函数,在 第 4.3.1 小节 Sigmoid 中已经介绍,它将 \((-, +)\) 的输入数据压缩到 \((0, 1)\) 之间。
4.5.2 神经网络
神经网络就是一组神经元的相互连接,一个简单的神经网络可能是如下的样子:
这个神经网络有两个输入 \((x_1, x_2)\) ,一个由神经元 \((h_1, h_2)\) 组成隐藏层,一个由神经元 \(o_1\) 组成的输出层。注意 \(o_1\) 的输入来自于 \(h_1\), \(h_2\) 的输出,这就构成了网络。
假设所有的神经元的有相同的初始权重 \(w = [0, 1]\) ,相同的初始偏置 \(b = 0\) ,使用相同的 sigmoid 激活函数,\(h_1\), \(h_2\), \(o_1\) 表示神经元的输出。
将输入 \(x = [2, 3]\) 传递到神经网络中:
\(h_1 = h_2 = f(w \times x + b) = f((0 \times 2) + (1 \times 3) + 0) = f(3) = 0.952574\)
\(o_1 = f(w \times [h_1, h_2] + b) = f((0 \times h_1) + (1 \times h_2) + 0) = f(0.9526) = 0.7216\)
对于输入 \(x = [2, 3]\) 来说,神经网络的输出是 0.7216,是不是看上去非常简单?
神经网络可以是任意数量的层,这些层可以有任意数量的神经元。基本思想保持不变:将输入通过网络中的神经元向前传播,最终获得输出。为简单起见,我们将在本文的剩余部分继续使用上图所示的网络。
文件 5_2_neuron_network.py 使用代码展示简单的神经网络的样子:
import numpy
def sigmoid(x):
return 1 / (1 + numpy.exp(-x))
class Neuron:
def __init__(self, weights, bias):
self.weights = weights
self.bias = bias
def feedforward(self, inputs):
total = numpy.dot(self.weights, inputs) + self.bias
return sigmoid(total)
class OurNeuralNetwork:
def __init__(self):
weights = numpy.array([0, 1])
bias = 0
self.h1 = Neuron(weights, bias)
self.h2 = Neuron(weights, bias)
self.o1 = Neuron(weights, bias)
def feedforward(self, x):
out_h1 = self.h1.feedforward(x)
out_h2 = self.h2.feedforward(x)
out_o1 = self.o1.feedforward(numpy.array([out_h1, out_h2]))
return out_o1
if __name__ == '__main__':
network = OurNeuralNetwork()
x = numpy.array([2, 3])
# 0.7216325609518421
print(network.feedforward(x))
我们再次获得 0.7216 ,看上去网络是工作的!