TensorFlow 手动构建一个神经网络

news/2025/2/1 6:26:05 标签: tensorflow, 神经网络, 人工智能


TensorFlowKeras 来构建和训练一个简单的神经网络模型。我们来逐行解析它的功能

import tensorflow as tf
import numpy as np
  • tensorflow:导入 TensorFlow 库,TensorFlow 是一个开源的机器学习框架。
  • numpy:导入 NumPy 库,它是 Python 中用于进行数组操作和科学计算的基础库。

构建神经网络模型

l1 = tf.keras.layers.Dense(units=3, activation='sigmoid')
l2 = tf.keras.layers.Dense(units=1, activation='sigmoid')
model = tf.keras.Sequential([l1, l2])
  • l1 = tf.keras.layers.Dense(units=3, activation='sigmoid'):创建一个 全连接层(Dense Layer),该层有 3 个输出节点(units=3),使用 sigmoid 激活函数。这意味着该层将输出 3 个值,每个值在 0 和 1 之间。

  • l2 = tf.keras.layers.Dense(units=1, activation='sigmoid'):创建第二个全连接层,包含 1 个输出节点,并使用 sigmoid 激活函数。该层输出一个值,表示模型的最终输出。

  • model = tf.keras.Sequential([l1, l2]):创建一个 顺序模型(Sequential Model),将之前定义的两层(l1l2)按顺序连接起来。这样,输入数据会先经过 l1 层,然后传递到 l2 层,最后得到输出。

定义优化器与模型编译

sgd = tf.keras.optimizers.SGD(learning_rate=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])
  • sgd = tf.keras.optimizers.SGD(learning_rate=0.9):定义了一个 随机梯度下降优化器(SGD Optimizer),它的学习率为 0.9。学习率决定了在每次更新参数时步长的大小,较高的学习率可以加快训练,但也有可能导致模型震荡或不稳定。

  • model.compile(...):编译模型,指定:

    • optimizer=sgd:使用之前定义的 SGD 优化器。
    • loss='binary_crossentropy':选择 二元交叉熵损失函数,适用于二分类问题。它衡量了预测概率和实际标签之间的差异。
    • metrics=['accuracy']:设置模型评估指标为 准确率(accuracy)。

准备数据

x = np.array([[1, 1], [1, -1], [-1, 1], [-1, -1], [0.7, 0.7], [0.7, -0.7], [-0.7, -0.7], [-0.7, 0.7]])
y = np.array([1, 1, 1, 1, 0, 0, 0, 0])
  • x:这是输入数据集,每一行是一个样本,包含两个特征。可以理解为每个样本是一个二维空间中的坐标点(例如,[1,1][1,-1] 等)。
  • y:这是标签数据集,表示每个输入数据对应的类别。1 表示正类,0 表示负类。

训练模型

model.fit(x, y, epochs=1000)
  • model.fit(x, y, epochs=1000):开始训练模型,使用 x 作为输入数据,y 作为目标标签,训练 1000 个周期(epochs)。每个周期,模型都会根据输入数据和真实标签进行前向传播和反向传播,更新模型权重。

保存模型(使用 Keras 格式):

# 保存训练好的模型(Keras 格式)
model.save('my_model.keras')

总结:

这段代码构建了一个简单的 二分类神经网络模型,其输入数据有两个特征,模型有两层:

  1. 第一层有 3 个节点,使用 sigmoid 激活函数。
  2. 第二层有 1 个节点,输出一个值(0 或 1),表示类别。

训练过程中,使用 二元交叉熵损失函数随机梯度下降优化器,并计算 准确率 作为评估指标。

模型将根据输入的 8 个样本训练 1000 个周期,以便学习如何将输入数据映射到相应的类别(10)。

完整代码
test.py 训练模型

import tensorflow as tf
import numpy as np
# 创建int32类型的0维张量,即标量
l1=tf.keras.layers.Dense(units=3,activation='sigmoid')
l2=tf.keras.layers.Dense(units=1,activation='sigmoid')
model=tf.keras.Sequential([l1,l2])
sgd = tf.keras.optimizers.SGD(learning_rate=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])
x=np.array([[1,1],[1,-1],[-1,1],[-1,-1],[0.7,0.7],[0.7,-0.7],[-0.7,-0.7],[-0.7,0.7]])
y=np.array([1,1,1,1,0,0,0,0])
model.fit(x,y,epochs=2000)
# 保存训练好的模型(Keras 格式)
model.save('my_model.keras')

 test2.py加载模型并进行预测:

import tensorflow as tf
import numpy as np

# 加载训练好的模型
model = tf.keras.models.load_model('my_model.keras')

# 预测数据
nx = np.array([[2, 2], [0.1, 0.1], [1.1, 1.2], [0.3, 0.3]])

# 获取预测结果
predictions = model.predict(nx)

# 输出预测结果
print(predictions)

# 如果需要将概率转化为类别(0或1)
predicted_classes = (predictions > 0.9).astype(int)

# 输出最终的类别预测
print(predicted_classes)

视频分享
初识TensorFlow 
https://v.douyin.com/ifG2mmLH/
复制此链接,打开Dou音搜索,直接观看视频!


http://www.niftyadmin.cn/n/5839066.html

相关文章

27.Word:财务软件应用的书稿【10】

目录 NO1.2 NO3 NO5.6​ NO7.8​ NO9​ 存在页码链接关系,只是页码格式不同 NO1.2 另存为/F12:考生文件夹布局→页面设置对话框→页边距:上下内外/装订线→纸张大小→布局:页眉页脚 NO3 样式的应用:超快速❗ 开…

RRT_STAR路径规划代码

这是一段使用MATLAB编写的代码,实现了一个基于RRT*(Rapidly-exploring Random Trees Star)算法的路径规划。RRT*是一种用于在配置空间中搜索路径的采样算法,常用于机器人路径规划等领域。以下是代码的主要功能和结构: …

vue之pinia组件的使用

1、搭建pinia环境 cnpm i pinia #安装pinia的组件 cnpm i nanoid #唯一id,相当于uuid cnpm install axios #网络请求组件 2、存储读取数据 存储数据 >> Count.ts文件import {defineStore} from piniaexport const useCountStore defineStore(count,{// a…

【蓝桥杯】43697.机器人塔

题目描述 X 星球的机器人表演拉拉队有两种服装,A 和 B。 他们这次表演的是搭机器人塔。 类似: A B B A B A A A B B B B B A B A B A B B A 队内的组塔规则是: A 只能站在 AA 或 BB 的肩上。 B 只能站在 AB 或 BA 的肩上。 你的…

1 HDFS

1 HDFS 1. HDFS概述2. HDFS架构3. HDFS的特性4. HDFS 的命令行使用5. hdfs的高级使用命令6. HDFS 的 block 块和副本机制6.1 抽象为block块的好处6.2 块缓存6.3 hdfs的文件权限验证6.4 hdfs的副本因子 7. HDFS 文件写入过程(非常重要)7.1 网络拓扑概念7.…

如何获取Springboot项目运行路径 (idea 启动以及打包为jar均可) 针对无服务器容器新建上传文件路径(适用于win 与 linunix)

public class Constants {public static String getUploadDir() {// 获取 JAR 包所在目录ApplicationHome home new ApplicationHome(Constants.class);File jarDir home.getDir();// 构建上传文件存储路径(JAR 同级目录下的 uploads 文件夹)File uplo…

漏洞扫描工具之xray

下载地址:https://github.com/chaitin/xray/releases 1.9.11 使用文档:https://docs.xray.cool/tools/xray/Scanning 与burpsuite联动: https://xz.aliyun.com/news/7563 参考:https://blog.csdn.net/lza20001103/article/details…

sublime_text的快捷键

sublime_text的快捷键 向下复制, 复制光标所在整行并插入到下一行:通过 CtrlShiftD 实现快速复制当前行的功能。 可选多行, 不选则复制当前行 ctrl Shift D 删除当前行:通过 CtrlShiftK 实现快速删除当前行的功能。 可选多行, 不选则删当前行 ctrl S…