libcity.executor.eta_executor

class libcity.executor.eta_executor.ETAExecutor(config, model, data_feature)[source]

Bases: libcity.executor.traffic_state_executor.TrafficStateExecutor

_build_train_loss()[source]

根据全局参数`train_loss`选择训练过程的loss函数 如果该参数为none,则需要使用模型自定义的loss函数 注意,loss函数应该接收`Batch`对象作为输入,返回对应的loss(torch.tensor)

evaluate(test_dataloader)[source]

use model to test data

Parameters

test_dataloader (torch.Dataloader) – Dataloader