您现在的位置是:网站首页> 编程资料编程资料
python目标检测SSD算法训练部分源码详解_python_
2023-05-26
402人已围观
简介 python目标检测SSD算法训练部分源码详解_python_
学习前言
……又看了很久的SSD算法,今天讲解一下训练部分的代码。
预测部分的代码可以参照https://www.jb51.net/article/246905.htm
讲解构架
本次教程的讲解主要是对训练部分的代码进行讲解,该部分讲解主要是对训练函数的执行过程与执行思路进行详解。
训练函数的执行过程大体上分为:
1、设定训练参数。
2、读取数据集。
3、建立ssd网络。
4、预处理数据集。
5、对ground truth实际框进行编码,使其格式符合神经网络的预测结果,便于比较。
6、计算loss值。
7、利用优化器完成梯度下降并保存模型。
在看本次算法前,建议先下载我简化过的源码,配合观看,具体运行方法在开始训练部分
下载链接 https://pan.baidu.com/s/1K4RAJvLj11blywuX2CrLSA
提取码:4wbi
模型训练的流程
本文使用的ssd_vgg_300的源码点击下载,本文对其进行了简化,保留了上一次筛选出的预测部分,还加入了训练部分,便于理顺整个SSD的框架。
1、设置参数
在载入数据库前,首先要设定一系列的参数,这些参数可以分为几个部分。第一部分是SSD网络中的一些标志参数:
# =========================================================================== # # SSD Network flags. # =========================================================================== # # localization框的衰减比率 tf.app.flags.DEFINE_float( 'loss_alpha', 1., 'Alpha parameter in the loss function.') # 正负样本比率 tf.app.flags.DEFINE_float( 'negative_ratio', 3., 'Negative ratio in the loss function.') # ground truth处理后,匹配得分高于match_threshold属于正样本 tf.app.flags.DEFINE_float( 'match_threshold', 0.5, 'Matching threshold in the loss function.')
第二部分是训练时的参数(包括训练效果输出、保存方案等):
# =========================================================================== # # General Flags. # =========================================================================== # # train_dir用于保存训练后的模型和日志 tf.app.flags.DEFINE_string( 'train_dir', '/tmp/tfmodel/', 'Directory where checkpoints and event logs are written to.') # num_readers是在对数据集进行读取时所用的平行读取器个数 tf.app.flags.DEFINE_integer( 'num_readers', 4, 'The number of parallel readers that read data from the dataset.') # 在进行训练batch的构建时,所用的线程数 tf.app.flags.DEFINE_integer( 'num_preprocessing_threads', 4, 'The number of threads used to create the batches.') # 每十步进行一次log输出,在窗口上 tf.app.flags.DEFINE_integer( 'log_every_n_steps', 10, 'The frequency with which logs are print.') # 每600秒存储一次记录 tf.app.flags.DEFINE_integer( 'save_summaries_secs', 600, 'The frequency with which summaries are saved, in seconds.') # 每600秒存储一次模型 tf.app.flags.DEFINE_integer( 'save_interval_secs', 600, 'The frequency with which the model is saved, in seconds.') # 可以使用的gpu内存数量 tf.app.flags.DEFINE_float( 'gpu_memory_fraction', 0.7, 'GPU memory fraction to use.')
第三部分是优化器参数:
# =========================================================================== # # Optimization Flags. # =========================================================================== # # 优化器参数 # weight_decay参数 tf.app.flags.DEFINE_float( 'weight_decay', 0.00004, 'The weight decay on the model weights.') # 使用什么优化器 tf.app.flags.DEFINE_string( 'optimizer', 'rmsprop', 'The name of the optimizer, one of "adadelta", "adagrad", "adam",' '"ftrl", "momentum", "sgd" or "rmsprop".') tf.app.flags.DEFINE_float( 'adadelta_rho', 0.95, 'The decay rate for adadelta.') tf.app.flags.DEFINE_float( 'adagrad_initial_accumulator_value', 0.1, 'Starting value for the AdaGrad accumulators.') tf.app.flags.DEFINE_float( 'adam_beta1', 0.9, 'The exponential decay rate for the 1st moment estimates.') tf.app.flags.DEFINE_float( 'adam_beta2', 0.999, 'The exponential decay rate for the 2nd moment estimates.') tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.') tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5, 'The learning rate power.') tf.app.flags.DEFINE_float( 'ftrl_initial_accumulator_value', 0.1, 'Starting value for the FTRL accumulators.') tf.app.flags.DEFINE_float( 'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.') tf.app.flags.DEFINE_float( 'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.') tf.app.flags.DEFINE_float( 'momentum', 0.9, 'The momentum for the MomentumOptimizer and RMSPropOptimizer.') tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.') tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.') 第四部分是学习率参数:
# =========================================================================== # # Learning Rate Flags. # =========================================================================== # # 学习率衰减的方式,有固定、指数衰减等 tf.app.flags.DEFINE_string( 'learning_rate_decay_type', 'exponential', 'Specifies how the learning rate is decayed. One of "fixed", "exponential",' ' or "polynomial"') # 初始学习率 tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.') # 结束时的学习率 tf.app.flags.DEFINE_float( 'end_learning_rate', 0.0001, 'The minimal end learning rate used by a polynomial decay learning rate.') tf.app.flags.DEFINE_float( 'label_smoothing', 0.0, 'The amount of label smoothing.') # 学习率衰减因素 tf.app.flags.DEFINE_float( 'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.') tf.app.flags.DEFINE_float( 'num_epochs_per_decay', 2.0, 'Number of epochs after which learning rate decays.') tf.app.flags.DEFINE_float( 'moving_average_decay', None, 'The decay to use for the moving average.' 'If left as None, then moving averages are not used.') 第五部分是数据集参数:
# =========================================================================== # # Dataset Flags. # =========================================================================== # # 数据集名称 tf.app.flags.DEFINE_string( 'dataset_name', 'imagenet', 'The name of the dataset to load.') # 数据集种类个数 tf.app.flags.DEFINE_integer( 'num_classes', 21, 'Number of classes to use in the dataset.') # 训练还是测试 tf.app.flags.DEFINE_string( 'dataset_split_name', 'train', 'The name of the train/test split.') # 数据集目录 tf.app.flags.DEFINE_string( 'dataset_dir', None, 'The directory where the dataset files are stored.') tf.app.flags.DEFINE_integer( 'labels_offset', 0, 'An offset for the labels in the dataset. This flag is primarily used to ' 'evaluate the VGG and ResNet architectures which do not use a background ' 'class for the ImageNet dataset.') tf.app.flags.DEFINE_string( 'model_name', 'ssd_300_vgg', 'The name of the architecture to train.') tf.app.flags.DEFINE_string( 'preprocessing_name', None, 'The name of the preprocessing to use. If left ' 'as `None`, then the model_name flag is used.') # 每一次训练batch的大小 tf.app.flags.DEFINE_integer( 'batch_size', 32, 'The number of samples in each batch.') # 训练图片的大小 tf.app.flags.DEFINE_integer( 'train_image_size', None, 'Train image size') # 最大训练次数 tf.app.flags.DEFINE_integer('max_number_of_steps', 50000, 'The maximum number of training steps.') 第六部分是微修已有的模型所需的参数:
# =========================================================================== # # Fine-Tuning Flags. # =========================================================================== # # 该部分参数用于微修已有的模型 # 原模型的位置 tf.app.flags.DEFINE_string( 'checkpoint_path', None, 'The path to a checkpoint from which to fine-tune.') tf.app.flags.DEFINE_string( 'checkpoint_model_scope', None, 'Model scope in the checkpoint. None if the same as the trained model.') # 哪些变量不要 tf.app.flags.DEFINE_string( 'checkpoint_exclude_scopes', None, 'Comma-separated list of scopes of variables to exclude when restoring ' 'from a checkpoint.') # 那些变量不训练 tf.app.flags.DEFINE_string( 'trainable_scopes', None, 'Comma-separated list of scopes to filter the set of variables to train.' 'By default, None would train all the variables.') # 忽略丢失的变量 tf.app.flags.DEFINE_boolean( 'ignore_missing_vars', False, 'When restoring a checkpoint would ignore missing variables.') FLAGS = tf.app.flags.FLAGS
所有的参数的意义我都进行了标注,在实际训练的时候需要修改一些参数的内容,这些参数看起来多,其实只是包含了一个网络训练所有必须的部分:
- 网络主体参数;
- 训练时的普通参数(包括训练效果输出、保存方案等);
- 优化器参数;
- 学习率参数;
- 数据集参数;
- 微修已有的模型的参数设置。
2、读取数据集
在训练流程中,其通过如下函数读取数据集
##########################读取数据集部分############################# # 选择数据库 dataset = dataset_factory.get_dataset( FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
dataset_factory里面放的是数据集获取和处理的函数,这里面对应了4个数据集, 利用datasets_map存储了四个数据集的处理代码。
from __future__ import absolute_import from __future__ import division from __future__ import print_function from datasets import cifar10 from datasets import imagenet from datasets import pascalvoc_2007 from datasets import pascalvoc_2012 datasets_map = { 'cifar10': cifar10, 'imagenet': imagenet, 'pascalvoc_2007': pascalvoc_2007, 'pascalvoc_2012': pascalvoc_2012, } def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None): """ 给定一个数据集名和一个拆分名返回一个数据集。 参数: name: String, 数据集名称 split_name: 训练还是测试 dataset_dir: 存储数据集文件的目录。 file_pattern: 用于匹配数据集源文件的文件模式。 reader: tf.readerbase的子类。如果保留为“none”,则使用每个数据集定义的
相关内容
- python神经网络slim常用函数训练保存模型_python_
- Python解决非线性规划中经济调度问题_python_
- Python中图像算术运算的示例详解_python_
- Python+Pygame实现经典魂斗罗游戏_python_
- python人工智能tensorflow函数np.random模块使用_python_
- OpenCV NAO机器人辅助捡球丢球流程分析_python_
- python人工智能tensorflow函数tf.assign使用方法_python_
- Python爬虫获取基金基本信息_python_
- python人工智能tensorflow常见损失函数LOSS汇总_python_
- pytorch部署到jupyter中的问题及解决方案_python_
点击排行
本栏推荐
