libcity.executor.traffic_state_executor¶
-
class
libcity.executor.traffic_state_executor.
TrafficStateExecutor
(config, model, data_feature)[source]¶ Bases:
libcity.executor.abstract_executor.AbstractExecutor
-
_build_train_loss
()[source]¶ 根据全局参数`train_loss`选择训练过程的loss函数 如果该参数为none,则需要使用模型自定义的loss函数 注意,loss函数应该接收`Batch`对象作为输入,返回对应的loss(torch.tensor)
-
_train_epoch
(train_dataloader, epoch_idx, loss_func=None)[source]¶ 完成模型一个轮次的训练
- Parameters
train_dataloader – 训练数据
epoch_idx – 轮次数
loss_func – 损失函数
- Returns
每个batch的损失的数组
- Return type
list
-
_valid_epoch
(eval_dataloader, epoch_idx, loss_func=None)[source]¶ 完成模型一个轮次的评估
- Parameters
eval_dataloader – 评估数据
epoch_idx – 轮次数
loss_func – 损失函数
- Returns
评估数据的平均损失值
- Return type
float
-