Source code for libcity.executor.abstract_tradition_executor

from libcity.executor.abstract_executor import AbstractExecutor
from logging import getLogger
from libcity.utils import get_evaluator, ensure_dir
import numpy as np
import torch
import time
import os


class AbstractTraditionExecutor(AbstractExecutor):
    """Executor base class for traditional models.

    Evaluation is implemented here. Training is left to subclasses, and
    model saving/loading are deliberate no-ops.
    """

    def __init__(self, config, model, data_feature):
        """Store the collaborators and prepare the cache directories.

        Args:
            config: configuration mapping; read via ``get`` for ``device``,
                ``exp_id`` and ``output_dim`` here, and via item access for
                ``model`` and ``dataset`` in :meth:`evaluate`.
            model: object exposing ``run(batch)``, used by :meth:`evaluate`.
            data_feature: mapping that provides the ``'scaler'`` entry.
        """
        # NOTE(review): AbstractExecutor.__init__ is not invoked here --
        # confirm that skipping the base initializer is intentional.
        self.evaluator = get_evaluator(config)
        self.config = config
        self.data_feature = data_feature
        self.device = self.config.get('device', torch.device('cpu'))
        self.model = model
        self.exp_id = self.config.get('exp_id', None)
        self.cache_dir = './libcity/cache/{}/model_cache'.format(self.exp_id)
        self.evaluate_res_dir = './libcity/cache/{}/evaluate_cache'.format(self.exp_id)
        # Both directories are created up front, before any evaluation runs.
        ensure_dir(self.cache_dir)
        ensure_dir(self.evaluate_res_dir)
        self._logger = getLogger()
        # May be None if data_feature has no 'scaler' entry.
        self._scaler = self.data_feature.get('scaler')
        self.output_dim = self.config.get('output_dim', 1)

    def evaluate(self, test_dataloader):
        """Run the model over the test data, store predictions and score them.

        For every batch the first ``output_dim`` entries of the last axis of
        both the ground truth (``batch['y']``) and the model output are passed
        through the scaler's ``inverse_transform``. The per-batch results are
        concatenated along axis 0, written to a compressed ``.npz`` file in
        ``evaluate_res_dir`` (keys ``'prediction'`` and ``'truth'``), and then
        handed to the evaluator.

        Args:
            test_dataloader(torch.Dataloader): Dataloader

        Returns:
            Whatever ``self.evaluator.save_result`` returns.

        Raises:
            ValueError: if ``test_dataloader`` yields no batches.
        """
        self._logger.info('Start evaluating ...')
        y_truths = []
        y_preds = []
        for batch in test_dataloader:
            batch.to_ndarray()
            output = self.model.run(batch)
            y_true = self._scaler.inverse_transform(batch['y'][..., :self.output_dim])
            y_pred = self._scaler.inverse_transform(output[..., :self.output_dim])
            y_truths.append(y_true)
            y_preds.append(y_pred)
        if not y_truths:
            # np.concatenate would raise an opaque ValueError on an empty
            # list; fail with the same exception type but a clear reason.
            raise ValueError('test_dataloader yielded no batches, nothing to evaluate')
        # Concatenate along the batch axis.
        y_preds = np.concatenate(y_preds, axis=0)
        y_truths = np.concatenate(y_truths, axis=0)
        outputs = {'prediction': y_preds, 'truth': y_truths}
        # time.strftime uses the current local time when no tuple is given.
        filename = '{}_{}_{}_predictions.npz'.format(
            time.strftime("%Y_%m_%d_%H_%M_%S"),
            self.config['model'],
            self.config['dataset'],
        )
        np.savez_compressed(os.path.join(self.evaluate_res_dir, filename), **outputs)
        # Reset the evaluator so only this run's data is collected.
        self.evaluator.clear()
        self.evaluator.collect({'y_true': torch.tensor(y_truths),
                                'y_pred': torch.tensor(y_preds)})
        test_result = self.evaluator.save_result(self.evaluate_res_dir)
        return test_result

    def train(self, train_dataloader, eval_dataloader):
        """Train the model; must be provided by subclasses.

        Args:
            train_dataloader(torch.Dataloader): Dataloader
            eval_dataloader(torch.Dataloader): Dataloader

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def save_model(self, cache_name):
        """Do nothing: traditional models do not need to be saved.

        Args:
            cache_name(str): name of the file to save to (unused).
        """

    def load_model(self, cache_name):
        """Do nothing: traditional models do not need to be loaded.

        Args:
            cache_name(str): name of the saved file (unused).
        """