libcity.executor.traffic_state_executor

class libcity.executor.traffic_state_executor.TrafficStateExecutor(config, model, data_feature)[source]

Bases: libcity.executor.abstract_executor.AbstractExecutor

_build_lr_scheduler()[source]

根据全局参数`lr_scheduler`选择对应的lr_scheduler

_build_optimizer()[source]

根据全局参数`learner`选择optimizer

_build_train_loss()[source]

根据全局参数`train_loss`选择训练过程的loss函数 如果该参数为none,则需要使用模型自定义的loss函数 注意,loss函数应该接收`Batch`对象作为输入,返回对应的loss(torch.tensor)

_train_epoch(train_dataloader, epoch_idx, loss_func=None)[source]

完成模型一个轮次的训练

Parameters
  • train_dataloader – 训练数据

  • epoch_idx – 轮次数

  • loss_func – 损失函数

Returns

每个batch的损失的数组

Return type

list

_valid_epoch(eval_dataloader, epoch_idx, loss_func=None)[source]

完成模型一个轮次的评估

Parameters
  • eval_dataloader – 评估数据

  • epoch_idx – 轮次数

  • loss_func – 损失函数

Returns

评估数据的平均损失值

Return type

float

evaluate(test_dataloader)[source]

use model to test data

Parameters

test_dataloader (torch.Dataloader) – Dataloader

load_model(cache_name)[source]

加载对应模型的 cache

Parameters

cache_name (str) – 保存的文件名

load_model_with_epoch(epoch)[source]

加载某个epoch的模型

Parameters

epoch (int) – 轮数

save_model(cache_name)[source]

将当前的模型保存到文件

Parameters

cache_name (str) – 保存的文件名

save_model_with_epoch(epoch)[source]

保存某个epoch的模型

Parameters

epoch (int) – 轮数

train(train_dataloader, eval_dataloader)[source]

use data to train model with config

Parameters
  • train_dataloader (torch.Dataloader) – Dataloader

  • eval_dataloader (torch.Dataloader) – Dataloader