Source code for libcity.evaluator.cara_loc_pred_evaluator

import os
import json
import time

import numpy as np

from libcity.evaluator.abstract_evaluator import AbstractEvaluator
from libcity.evaluator.eval_funcs import top_k
allowed_metrics = ['Precision', 'Recall', 'F1', 'MRR', 'MAP', 'NDCG']


class CARALocPredEvaluator(AbstractEvaluator):

    def __init__(self, config):
        self.metrics = config['metrics']  # evaluation metrics, given as a list
        self.config = config
        self.topk = config['topk']
        self.result = {}
        self.intermediate_result = {
            'total': 0,
            'hit': 0,
            'rank': 0.0,
            'dcg': 0.0
        }
        self._check_config()

    def _check_config(self):
        if not isinstance(self.metrics, list):
            raise TypeError('Evaluator metrics is not a list')
        for i in self.metrics:
            if i not in allowed_metrics:
                raise ValueError(
                    'the metric {} is not allowed in CARALocPredEvaluator'.format(i))

    def collect(self, batch):
        """
        Args:
            batch (dict): contains three keys: uid, loc_true, and loc_pred.
                uid (list): the uid field of the batch; by index it tells which
                    user each row (element) of loc_true and loc_pred came from.
                loc_true (list): the expected locations (targets), taken from
                    the batch's target field.
                loc_pred (matrix): the model's actual output,
                    batch_size * output_dim.
        """
        if not isinstance(batch, dict):
            raise TypeError('evaluator.collect input is not a dict of user')
        # score loc_pred against an all-zero target, i.e. the ground-truth
        # location is expected to sit at index 0 of each prediction row
        my_true = np.zeros(np.array(batch['loc_true']).shape)
        hit, rank, dcg = top_k(batch['loc_pred'], my_true, self.topk)
        total = len(batch['loc_true'])
        self.intermediate_result['total'] += total
        self.intermediate_result['hit'] += hit
        self.intermediate_result['rank'] += rank
        self.intermediate_result['dcg'] += dcg

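    # Illustration (not part of the original module): a batch accepted by
    # collect() could look like the sketch below; the user ids, scores and
    # output dimension are made up, only the keys and shapes matter.
    #
    #     batch = {
    #         'uid': [3, 7],                       # which user produced each row
    #         'loc_true': [0, 0],                  # targets; only the length is used here
    #         'loc_pred': np.random.rand(2, 100),  # candidate scores, batch_size x output_dim
    #     }
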
    def evaluate(self):
        precision_key = 'Precision@{}'.format(self.topk)
        precision = self.intermediate_result['hit'] / (
            self.intermediate_result['total'] * self.topk)
        if 'Precision' in self.metrics:
            self.result[precision_key] = precision
        # Recall is used for validation during training, so it is always computed.
        recall_key = 'Recall@{}'.format(self.topk)
        recall = self.intermediate_result['hit'] \
            / self.intermediate_result['total']
        self.result[recall_key] = recall
        if 'F1' in self.metrics:
            f1_key = 'F1@{}'.format(self.topk)
            if precision + recall == 0:
                # no hits at all: avoid a division by zero
                self.result[f1_key] = 0.0
            else:
                self.result[f1_key] = (2 * precision * recall) / (precision + recall)
        if 'MRR' in self.metrics:
            mrr_key = 'MRR@{}'.format(self.topk)
            self.result[mrr_key] = self.intermediate_result['rank'] \
                / self.intermediate_result['total']
        if 'MAP' in self.metrics:
            map_key = 'MAP@{}'.format(self.topk)
            self.result[map_key] = self.intermediate_result['rank'] \
                / self.intermediate_result['total']
        if 'NDCG' in self.metrics:
            ndcg_key = 'NDCG@{}'.format(self.topk)
            self.result[ndcg_key] = self.intermediate_result['dcg'] \
                / self.intermediate_result['total']
        return self.result

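    # Worked example of the formulas above (numbers are made up): with
    # total=4, hit=3, rank=1.5, dcg=2.0 and topk=5, evaluate() reports
    #     Precision@5 = 3 / (4 * 5)                     = 0.15
    #     Recall@5    = 3 / 4                           = 0.75
    #     F1@5        = 2 * 0.15 * 0.75 / (0.15 + 0.75) = 0.25
    #     MRR@5       = 1.5 / 4                         = 0.375  (MAP@5 reuses the same 'rank' sum)
    #     NDCG@5      = 2.0 / 4                         = 0.5
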
    def save_result(self, save_path, filename=None):
        self.evaluate()
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        if filename is None:
            # use the current timestamp as the filename
            filename = time.strftime(
                "%Y_%m_%d_%H_%M_%S", time.localtime(time.time()))
        print('evaluate result is ', json.dumps(self.result, indent=1))
        with open(os.path.join(save_path, '{}.json'.format(filename)), 'w') \
                as f:
            json.dump(self.result, f)

    def clear(self):
        self.result = {}
        self.intermediate_result = {
            'total': 0,
            'hit': 0,
            'rank': 0.0,
            'dcg': 0.0
        }

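
A minimal usage sketch (not part of the module above): it assumes LibCity is installed and feeds the evaluator randomly generated scores, so the printed numbers are meaningless; a real pipeline would pass model outputs instead. Only the 'metrics' and 'topk' config keys are read by this class.

import numpy as np
from libcity.evaluator.cara_loc_pred_evaluator import CARALocPredEvaluator

config = {'metrics': ['Precision', 'Recall', 'F1', 'MRR', 'NDCG'], 'topk': 5}
evaluator = CARALocPredEvaluator(config)

# two fake test batches with 100 candidate locations per sample; the target is
# expected at column 0 of loc_pred, which is why collect() scores against zeros
for _ in range(2):
    batch = {
        'uid': [0, 1],
        'loc_true': [0, 0],
        'loc_pred': np.random.rand(2, 100),
    }
    evaluator.collect(batch)

print(evaluator.evaluate())                # e.g. {'Precision@5': ..., 'Recall@5': ...}
evaluator.save_result('./cache/evaluate')  # writes <timestamp>.json under that folder
evaluator.clear()                          # reset the counters for the next round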