sklearn中MLPClassifier源码解析
2022/9/13 1:23:08
本文主要是介绍sklearn中MLPClassifier源码解析,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
神经网络
.fit()
首先传入类私用方法._fit()
- 确定hidden_layer_size是可迭代的
- 调用_validate_hyperparameters验证超参数是否合法
- 验证输入的x和y是否合法并且获取one-hot-label
- 从x、y中获取输入参数的信息,并且添加输入层和输出层
(隐藏层作为参数,输入层和输出层可以从x、y中获取) - 将随机种子(seed)变成np.random.RandomState实例
- 看一看是不是第一次训练该模型,如果是则进入_initialize方法
- 初始化weight和bias
- 初始化loss和score
# factor作为边界计算的分子 # 6.0说明是分类任务 factor = 6.0 if self.activation == "logistic": factor = 2.0 # fan_in、fan_out分别是这一层的输入大小和输出大小 init_bound = np.sqrt(factor / (fan_in + fan_out)) # _random_state就是第五点中根据seed实例出的对象 # uniform代表在随机生成,参数分别为下限,上限,size # Generate weights and bias coef_init = self._random_state.uniform( -init_bound, init_bound, (fan_in, fan_out) ) intercept_init = self._random_state.uniform(-init_bound, init_bound, fan_out)
- 初始化权重和偏值的梯度(用numpy.empty)
- 训练(根据solver决定模型 _fit_stochastic 和 _fit_lbfgs)# 下次再读,到饭点了
- 验证权重是否合法,用np.isfinite(),检查是否出现INF,-INF和NAN
这篇关于sklearn中MLPClassifier源码解析的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-05-01巧用 TiCDC Syncpoint 构建银行实时交易和准实时计算一体化架构
- 2024-05-01银行核心背后的落地工程体系丨Oracle - TiDB 数据迁移详解
- 2024-04-26高性能表格工具VTable总体构成-icode9专业技术文章分享
- 2024-04-16软路由代理问题, tg 无法代理问题-icode9专业技术文章分享
- 2024-04-16程序猿用什么锅-icode9专业技术文章分享
- 2024-04-16自建 NAS 的方案-icode9专业技术文章分享
- 2024-04-14ansible 在远程主机上执行脚本,并传入参数-icode9专业技术文章分享
- 2024-04-14ansible 在远程主机上执行脚本,并传入参数, 加上remote_src: yes 配置-icode9专业技术文章分享
- 2024-04-14ansible 检测远程主机的8080端口,如果关闭,则echo 进程已关闭-icode9专业技术文章分享
- 2024-04-14result 成功怎么写-icode9专业技术文章分享