自定义数据集使用scikit-learn中的包实现线性回归方法对其进行拟合

news/2025/1/31 12:06:19 标签: scikit-learn, 线性回归, python

一、导入必要的库

python">import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

二、加载自定义数据集

python"># 创建自定义数据集
# 假设我们有一个简单的线性关系 y = 2x + 1,并添加一些随机噪声
np.random.seed(42)  # 为了结果的可重复性设置随机种子
X = 2 * np.random.rand(100, 1)  # 100个样本,每个样本1个特征(随机生成在0到2之间的数)
y = 4 + 3 * X + np.random.randn(100, 1)  # 目标变量,添加了一些随机噪声

三、划分数据集

python">X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

四、训练线性回归模型

python">lin_reg = LinearRegression()
lin_reg.fit(X_train, y_train)

五、预测并评估模型

python"># 进行预测
y_pred = lin_reg.predict(X_test)

# 评估模型
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

六、图形展示

python">plt.scatter(X, y, color='blue', label='point')
plt.plot(X_test, y_pred, color='red', label='line')
plt.xlabel('X')
plt.ylabel('y')
plt.title('show')
plt.legend()
plt.show()

七、完整代码即结果演示

 

python">import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

# 创建自定义数据集
# 假设我们有一个简单的线性关系 y = 2x + 1,并添加一些随机噪声
np.random.seed(42)  # 为了结果的可重复性设置随机种子
X = 2 * np.random.rand(100, 1)  # 100个样本,每个样本1个特征(随机生成在0到2之间的数)
y = 4 + 3 * X + np.random.randn(100, 1)  # 目标变量,添加了一些随机噪声

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 训练线性回归模型
lin_reg = LinearRegression()
lin_reg.fit(X_train, y_train)

# 进行预测
y_pred = lin_reg.predict(X_test)

# 评估模型
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"均方误差: {mse}")

plt.scatter(X, y, color='blue', label='point')
plt.plot(X_test, y_pred, color='red', label='line')
plt.xlabel('X')
plt.ylabel('y')
plt.title('show')
plt.legend()
plt.show()

 

 

 

 

 


http://www.niftyadmin.cn/n/5838643.html

相关文章

Java基础知识总结(三十二)--API--- java.lang.Runtime

类中没有构造方法,不能创建对象。 但是有非静态方法。说明该类中应该定义好了对象,并可以通过一个static方法获取这个对象。用这个对象来调用非静态方法。这个方法就是 static Runtime getRuntime(); 这个Runtime其实使用单例设计模式进行设计。 class …

基于Springboot的社区药房管理系统

博主介绍:java高级开发,从事互联网行业多年,熟悉各种主流语言,精通java、python、php、爬虫、web开发,已经做了多年的设计程序开发,开发过上千套设计程序,没有什么华丽的语言,只有实…

Python NumPy(8):NumPy 位运算、NumPy 字符串函数

1 NumPy 位运算 位运算是一种在二进制数字的位级别上进行操作的一类运算,它们直接操作二进制数字的各个位,而不考虑数字的整体值。NumPy 提供了一系列位运算函数,允许对数组中的元素进行逐位操作,这些操作与 Python 的位运算符类似…

网络安全技术简介

网络安全技术简介 随着信息技术的迅猛发展,互联网已经成为人们日常生活和工作中不可或缺的一部分。与此同时,网络安全问题也日益凸显,成为全球关注的焦点。无论是个人隐私泄露、企业数据被盗取还是国家信息安全受到威胁,都与网络…

[Python学习日记-80] 用 socket 实现文件传输功能(上传下载)

[Python学习日记-80] 用 socket 实现文件传输功能(上传下载) 简介 简单版本 函数版本 面向对象版本 简介 到此为止网络编程基础的介绍已经接近尾声了,而在本篇当中我们会基于上一篇博客代码的基础上来实现文件传输功能。文件传输其实与远…

单片机串口打印printf函数显示内容(固件库开发)

1.hal_usart.c 文件 #include <stdio.h> #include "hal_usart.h" #include "stm32F10x.h"//**要根据 使用的是哪个串口 对应修改 串口号 eg&#xff1a;USART1** void USART_PUTC(char ch) {/* 等待数据寄存器为空 */while((USART1->SR & …

LLMs之WebRAG:STORM/Co-STORM的简介、安装和使用方法、案例应用之详细攻略

LLMs之WebRAG&#xff1a;STORM/Co-STORM的简介、安装和使用方法、案例应用之详细攻略 目录 STORM系统简介 1、Co-STORM 2、更新新闻 STORM系统安装和使用方法 1、安装 pip安装 直接克隆GitHub仓库 2、模型和数据集 两个数据集 FreshWiki数据集 WildSeek数据集 支持…

java求职学习day20

1 在线考试系统 1.1 软件开发的流程 需求分析文档、概要设计文档、详细设计文档、编码和测试、安装和调试、维护和升级 1.2 软件的需求分析 在线考试系统的主要功能分析如下&#xff1a; &#xff08; 1 &#xff09;学员系统 &#xff08;1.1&#xff09;用户模块&…