# -*- coding: utf-8 -*-
import os
import time
import logging
import numpy as np
from rich.table import Table
from torch import Tensor
from . import count_parameters
from rich.console import Console
from rich.logging import RichHandler
def create_logger(logger_name, log_mode='both', log_file_path=None, encoding='utf-8'):
    """Configure the root logging handlers and return a project ``Logger``.

    Any handlers already attached to the root logger are closed and removed,
    then ``logging.basicConfig`` installs fresh ones at INFO level with a bare
    ``%(message)s`` format.

    Args:
        logger_name (str): Name passed to ``Logger`` (and from there to
            ``logging.getLogger``).
        log_mode (str, optional): Where to send output. One of ``'file'``,
            ``'stdout'`` or ``'both'``. Defaults to ``'both'``.
        log_file_path (str, optional): Path of the log file. Required when
            ``log_mode`` is ``'file'`` or ``'both'``. Whenever it is given, its
            parent directory is created if missing. Defaults to None.
        encoding (str, optional): Encoding used for the log file.
            Defaults to ``'utf-8'``.

    Returns:
        Logger: Wrapper around ``logging.getLogger(logger_name)``.

    Raises:
        ValueError: If ``log_mode`` is not one of the supported options, or if
            a file mode is requested without ``log_file_path``.
    """
    # Validate everything before touching the filesystem or global logging state.
    if log_mode not in ('file', 'stdout', 'both'):
        raise ValueError('Print option not supported, options: file, stdout, both')
    to_file = log_mode in ('file', 'both')
    to_stdout = log_mode in ('stdout', 'both')
    if to_file and log_file_path is None:
        raise ValueError("log_file_path must be specified when print output to log file!")

    if log_file_path is not None:
        log_dir = os.path.dirname(log_file_path)
        # dirname() is '' for a bare file name, and os.makedirs('') would raise.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

    # Close (not merely drop) previous root handlers so that a FileHandler
    # installed by an earlier call releases its open file.
    for old_handler in list(logging.root.handlers):
        logging.root.removeHandler(old_handler)
        old_handler.close()

    handlers = []
    if to_file:
        handlers.append(logging.FileHandler(log_file_path, encoding=encoding))
    if to_stdout:
        # Words the rich handler highlights in console output.
        handlers.append(RichHandler(show_path=False,
                                    keywords=["Random seed",
                                              "Round", "Epoch", "Loss",
                                              "ACC", "NMI", "ARI", "F1", "HOM", "COM", "PUR", "SC", "GRE",
                                              "Time cost"]))

    logging.basicConfig(level=logging.INFO, format='%(message)s', handlers=handlers)
    return Logger(name=logger_name)
class Logger(object):
    """Thin wrapper around a named ``logging`` logger with project-specific helpers.

    Args:
        name (str): Name of the underlying logger (passed to ``logging.getLogger``).
    """

    def __init__(self, name):
        self.logger = logging.getLogger(name)

    def info(self, message):
        """Log ``message`` at INFO level.

        Args:
            message (str): Log message.
        """
        self.logger.info(message)

    def error(self, message):
        """Log ``message`` at ERROR level.

        Args:
            message (str): Log message.
        """
        self.logger.error(message)

    def debug(self, message):
        """Log ``message`` at DEBUG level.

        Args:
            message (str): Log message.
        """
        self.logger.debug(message)

    def warning(self, message):
        """Log ``message`` at WARNING level.

        Args:
            message (str): Log message.
        """
        self.logger.warning(message)

    def flag(self, message):
        """Log ``message`` at INFO level framed by 40 asterisks on each side.

        Used as a visual separator between parts of the log.

        Args:
            message (str): Log message.
        """
        self.logger.info(f"{'*' * 40}{message}{'*' * 40}")

    @staticmethod
    def table(results_dir: str, dataset_name: str, results_dict: dict, decimal: int = 4):
        """Render clustering results as a rich table, append it to a report file and print it.

        The layout is chosen from the first value of ``results_dict``:

        * a numeric scalar (Python ``int``/``float``/``bool`` or any numpy
          scalar), or an empty dict -> two columns, ``Metric`` and ``Value``;
        * otherwise each value is treated as a per-round sequence -> one column
          per round, then ``Avg.`` (``np.mean``) and ``Std.`` (``np.std``,
          population std with numpy's default ``ddof=0``).

        The table is appended, preceded by a timestamp line, to
        ``<results_dir>/results.txt`` (UTF-8) and also printed to stdout.

        Args:
            results_dir (str): Directory holding ``results.txt``; created if missing.
            dataset_name (str): Dataset name shown in the table title.
            results_dict (dict): Metric name -> value, or metric name -> sequence
                of per-round values (all sequences assumed to have the length of
                the first one).
            decimal (int, optional): Digits kept by ``round``. Defaults to 4.
        """
        table = Table(title=f"Clustering Results on Dataset {dataset_name}")
        table.add_column("Metric", justify="right", style="cyan", no_wrap=True)

        first_value = next(iter(results_dict.values()), None)
        # isinstance (rather than an exact type list) also accepts bool,
        # np.float16, unsigned numpy ints, etc.
        if first_value is None or isinstance(first_value, (int, float, np.number)):
            table.add_column("Value", justify="right", style="green", no_wrap=True)
            for key, value in results_dict.items():
                table.add_row(str(key), str(round(value, decimal)))
        else:
            for round_idx in range(len(first_value)):
                table.add_column(f"{round_idx + 1}", justify="right", style="green", no_wrap=True)
            table.add_column("Avg.", justify="right", style="green", no_wrap=True)
            table.add_column("Std.", justify="right", style="green", no_wrap=True)
            for key, values in results_dict.items():
                table.add_row(str(key),
                              *[str(round(value, decimal)) for value in values],
                              str(round(np.mean(values), decimal)),
                              str(round(np.std(values), decimal)))

        if results_dir:
            os.makedirs(results_dir, exist_ok=True)
        # rich draws tables with Unicode box characters; pin the encoding so the
        # write does not depend on the platform's default (e.g. cp1252).
        report_path = os.path.join(results_dir, "results.txt")
        with open(report_path, "a", encoding="utf-8") as report_file:
            file_console = Console(file=report_file)
            # NOTE(review): get_formatted_time is not among this module's visible
            # imports - confirm it is in scope, otherwise this raises NameError.
            file_console.print(get_formatted_time())
            file_console.print(table)
        Console().print(table)

    def loss(self, epoch, loss, decimal: int = 6):
        """Log an ``Epoch: NNNN, Loss: x.xxxxxx`` line at INFO level.

        Args:
            epoch (int): Epoch number, zero-padded to 4 digits.
            loss (float | Tensor): Loss value; a ``Tensor`` is converted with ``.item()``.
            decimal (int, optional): Digits after the decimal point. Defaults to 6.
        """
        if isinstance(loss, Tensor):
            loss = loss.item()
        self.logger.info(f"Epoch: {epoch:0>4d}, Loss: {round(loss, decimal):0>.{decimal}f}")

    def model_info(self, model):
        """Log the model itself and its parameter count at INFO level.

        Args:
            model (nn.Module): Model.
        """
        self.logger.info(model)
        # NOTE(review): the "MB" label assumes count_parameters reports its
        # result in that unit - confirm against its implementation.
        self.logger.info(f"Parameters: {count_parameters(model)} MB")