libcity.executor.dcrnn_executor

class libcity.executor.dcrnn_executor.DCRNNExecutor(config, model, data_feature)[source]

Bases: libcity.executor.traffic_state_executor.TrafficStateExecutor

_build_train_loss()[source]

根据全局参数`train_loss`选择训练过程的loss函数 如果该参数为none,则需要使用模型自定义的loss函数 注意,loss函数应该接收`Batch`对象作为输入,返回对应的loss(torch.tensor)

_train_epoch(train_dataloader, epoch_idx, batches_seen=None, loss_func=None)[source]

完成模型一个轮次的训练

Parameters
  • train_dataloader – 训练数据

  • epoch_idx – 轮次数

  • batches_seen – 全局batch数

  • loss_func – 损失函数

Returns

tuple contains

losses(list): 每个batch的损失的数组

batches_seen(int): 全局batch数

Return type

tuple

_valid_epoch(eval_dataloader, epoch_idx, batches_seen=None, loss_func=None)[source]

完成模型一个轮次的评估

Parameters
  • eval_dataloader – 评估数据

  • epoch_idx – 轮次数

  • batches_seen – 全局batch数

  • loss_func – 损失函数

Returns

评估数据的平均损失值

Return type

float

train(train_dataloader, eval_dataloader)[source]

use data to train model with config

Parameters
  • train_dataloader (torch.Dataloader) – Dataloader

  • eval_dataloader (torch.Dataloader) – Dataloader