Source code for libcity.executor.geml_executor

import time
from functools import partial

import numpy as np
import torch
import os

from libcity.executor.traffic_state_executor import TrafficStateExecutor
from libcity.model import loss


class GEMLExecutor(TrafficStateExecutor):
    """Executor for models whose ``predict`` returns three tensors.

    Compared with ``TrafficStateExecutor`` this class only changes how the
    output of ``self.model.predict`` is consumed: the model returns a tuple
    ``(prediction, y_in, y_out)``.  Evaluation uses the first element only;
    the training loss is a weighted sum over all three.

    Config keys read here (with defaults):
        loss_p0 (0.5): weight of the loss on the main prediction.
        loss_p1 (0.25): weight of the loss on ``y_in``.
        loss_p2 (0.25): weight of the loss on ``y_out``.
    """

    # Single source of truth for the supported ``train_loss`` names.
    # name -> (attribute name in ``libcity.model.loss``, extra keyword args)
    _TRAIN_LOSSES = {
        'mae': ('masked_mae_torch', {}),
        'mse': ('masked_mse_torch', {}),
        'rmse': ('masked_rmse_torch', {}),
        'mape': ('masked_mape_torch', {}),
        'logcosh': ('log_cosh_loss', {}),
        'huber': ('huber_loss', {}),
        'quantile': ('quantile_loss', {}),
        'masked_mae': ('masked_mae_torch', {'null_val': 0}),
        'masked_mse': ('masked_mse_torch', {'null_val': 0}),
        'masked_rmse': ('masked_rmse_torch', {'null_val': 0}),
        'masked_mape': ('masked_mape_torch', {'null_val': 0}),
        'r2': ('r2_score_torch', {}),
        'evar': ('explained_variance_score_torch', {}),
    }
    # Loss used when ``train_loss`` is not one of the names above.
    _DEFAULT_TRAIN_LOSS = 'mae'

    def __init__(self, config, model, data_feature):
        TrafficStateExecutor.__init__(self, config, model, data_feature)
        self.loss_p0 = config.get("loss_p0", 0.5)
        self.loss_p1 = config.get("loss_p1", 0.25)
        self.loss_p2 = config.get("loss_p2", 0.25)

    # Only overrides how the output of `predict` is read.
    def evaluate(self, test_dataloader):
        """
        Use the model to test data.

        Predictions and ground truth are inverse-transformed with
        ``self._scaler``, concatenated over batches, saved to a compressed
        ``.npz`` file in ``self.evaluate_res_dir`` and passed to
        ``self.evaluator``.

        Args:
            test_dataloader(torch.Dataloader): Dataloader

        Returns:
            Whatever ``self.evaluator.save_result`` returns.
        """
        self._logger.info('Start evaluating ...')
        with torch.no_grad():
            self.model.eval()
            y_truths = []
            y_preds = []
            for batch in test_dataloader:
                batch.to_tensor(self.device)
                # The model returns three tensors; only the first one is evaluated.
                output, _, _ = self.model.predict(batch)
                y_true = self._scaler.inverse_transform(batch['y'][..., :self.output_dim])
                y_pred = self._scaler.inverse_transform(output[..., :self.output_dim])
                y_truths.append(y_true.cpu().numpy())
                y_preds.append(y_pred.cpu().numpy())
            # Concatenate along the batch axis.
            y_preds = np.concatenate(y_preds, axis=0)
            y_truths = np.concatenate(y_truths, axis=0)
            outputs = {'prediction': y_preds, 'truth': y_truths}
            filename = \
                time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime(time.time())) + '_' \
                + self.config['model'] + '_' + self.config['dataset'] + '_predictions.npz'
            np.savez_compressed(os.path.join(self.evaluate_res_dir, filename), **outputs)
            self.evaluator.clear()
            self.evaluator.collect({'y_true': torch.tensor(y_truths), 'y_pred': torch.tensor(y_preds)})
            test_result = self.evaluator.save_result(self.evaluate_res_dir)
            return test_result

    # Only overrides how the output of `predict` is read.
    def _build_train_loss(self):
        """
        Select the training loss function according to the global
        parameter `train_loss`.

        If the parameter is ``'none'`` the loss defined by the model itself
        must be used and ``None`` is returned.  An unrecognized name falls
        back to the masked MAE loss after logging a warning.

        Note that the returned loss function receives a `Batch` object as
        input and returns the corresponding loss (torch.tensor).

        Returns:
            callable or None: ``func(batch) -> torch.tensor`` computing
            ``loss_p0 * lf(pred, true) + loss_p1 * lf(y_in, y_in_true)
            + loss_p2 * lf(y_out, y_out_true)``, or ``None``.
        """
        loss_name = self.train_loss.lower()
        if loss_name == 'none':
            self._logger.warning('Received none train loss func and will use the loss func defined in the model.')
            return None
        if loss_name not in self._TRAIN_LOSSES:
            self._logger.warning('Received unrecognized train loss function, set default mae loss func.')
            loss_name = self._DEFAULT_TRAIN_LOSS
        else:
            self._logger.info('You select `{}` as train loss function.'.format(loss_name))

        # Resolve the loss callable once here instead of on every batch.
        attr_name, extra_kwargs = self._TRAIN_LOSSES[loss_name]
        base_loss = getattr(loss, attr_name)
        lf = partial(base_loss, **extra_kwargs) if extra_kwargs else base_loss

        def func(batch):
            y_true = batch['y']
            # The 5-index permute below requires `y_true` to be 5-D;
            # presumably (batch, time, origin, destination, feature) -- TODO confirm.
            # Marginal targets: sum over the second-to-last axis, once as is
            # and once with the two node axes swapped.
            y_in_true = torch.sum(y_true, dim=-2, keepdim=True)
            y_out_true = torch.sum(y_true.permute(0, 1, 3, 2, 4), dim=-2, keepdim=True)
            y_predicted, y_in, y_out = self.model.predict(batch)
            # NOTE(review): only the main prediction/target pair is
            # inverse-transformed; the marginal terms are compared as
            # produced above.  Kept as in the original -- confirm intended.
            y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
            y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim])
            return self.loss_p0 * lf(y_predicted, y_true) \
                + self.loss_p1 * lf(y_in, y_in_true) \
                + self.loss_p2 * lf(y_out, y_out_true)

        return func