libcity.executor.mtgnn_executor¶
-
class
libcity.executor.mtgnn_executor.
MTGNNExecutor
(config, model, data_feature)[source]¶ Bases:
libcity.executor.traffic_state_executor.TrafficStateExecutor
-
_build_train_loss
()[source]¶ 根据全局参数`train_loss`选择训练过程的loss函数 如果该参数为none,则需要使用模型自定义的loss函数 注意,loss函数应该接收`Batch`对象作为输入,返回对应的loss(torch.tensor)
-
_train_epoch
(train_dataloader, epoch_idx, batches_seen=None, loss_func=None)[source]¶ 完成模型一个轮次的训练
- Parameters
train_dataloader – 训练数据
epoch_idx – 轮次数
batches_seen – 全局batch数
loss_func – 损失函数
- Returns
- tuple contains
losses(list): 每个batch的损失的数组
batches_seen(int): 全局batch数
- Return type
tuple
-