Source code for libcity.executor.map_matching_executor

from logging import getLogger
from libcity.executor.abstract_tradition_executor import AbstractTraditionExecutor
from libcity.utils import get_evaluator


class MapMatchingExecutor(AbstractTraditionExecutor):
    """Executor that runs a map-matching model and hands its output to an evaluator.

    The model is not trained here: ``train`` is a no-op and ``evaluate``
    calls ``model.run`` once on the supplied test data.
    """

    def __init__(self, config, model, data_feature):
        """Store the collaborators and derive the cache directory paths.

        Args:
            config: configuration object; must support ``.get(key, default)``
                and is also passed to ``get_evaluator``.
            model: object exposing ``run(test_data)``.
            data_feature: accepted for interface compatibility; not used
                by this constructor.
        """
        self.model = model
        self.config = config
        self.evaluator = get_evaluator(config)
        # A missing 'exp_id' yields None, which is formatted into the paths as-is.
        self.exp_id = self.config.get('exp_id', None)
        self.cache_dir = './libcity/cache/{}/model_cache'.format(self.exp_id)
        self.evaluate_res_dir = './libcity/cache/{}/evaluate_cache'.format(self.exp_id)
        self._logger = getLogger()

    def evaluate(self, test_data):
        """Run the model on ``test_data`` and save the evaluation result.

        Args:
            test_data: mapping that is passed whole to ``model.run`` and
                must also provide the keys ``'route'`` and ``'rd_nwk'``.
        """
        matched = self.model.run(test_data)
        # The evaluator receives the reference route, the model output and
        # the road network together as a single batch.
        self.evaluator.collect({
            'route': test_data['route'],
            'result': matched,
            'rd_nwk': test_data['rd_nwk'],
        })
        self.evaluator.save_result(self.evaluate_res_dir)

    def train(self, train_dataloader, eval_dataloader):
        """Do nothing: traditional models do not need training.

        Args:
            train_dataloader(torch.Dataloader): Dataloader; ignored.
            eval_dataloader(torch.Dataloader): Dataloader; ignored.
        """
        return None