diff --git a/shared/logger.py b/shared/logger.py index c47a13d..8ffc78f 100644 --- a/shared/logger.py +++ b/shared/logger.py @@ -1,10 +1,9 @@ import json import logging -import re -import uuid +import sys from datetime import datetime -from typing import Callable -from typing import Dict, List, Optional +from logging.handlers import RotatingFileHandler +from typing import List, Optional, Dict, Any from uuid import UUID from pydantic import BaseModel @@ -13,7 +12,7 @@ from .models import LogEntry, LogSeverity, CommandInfo class LogsFilter(BaseModel): - trace_id: Optional[uuid.UUID] + trace_id: Optional[UUID] commands_only: bool severity: List[LogSeverity] @@ -23,171 +22,70 @@ class Paging(BaseModel): page_size: int -class TraceIdProvider: - def __init__(self, get_trace_id: Callable[[], uuid.UUID | None]): - self.get_trace_id: Callable[[], uuid.UUID | None] = get_trace_id +class DBHandler: + def __init__(self, db_config: Dict[str, Any]): + # Initialize database connection here + self.db_config = db_config - def get_current(self) -> uuid.UUID | None: - return self.get_trace_id() + def insert_log(self, log_entry: LogEntry) -> None: + # Insert log entry into the database + pass + + def get_logs(self, filters: LogsFilter, paging: Paging) -> List[LogEntry]: + # Retrieve logs from the database based on filters and paging + pass + + def get_log_entry(self, log_id: int) -> Optional[LogEntry]: + # Retrieve a specific log entry from the database + pass class Logger: - _UNSCOPED_TRACE_ID = uuid.UUID("00000000-0000-0000-0000-000000000000") + _UNSCOPED_TRACE_ID = UUID("00000000-0000-0000-0000-000000000000") - def __init__(self, trace_id_provider: TraceIdProvider): - self._trace_id_provider = trace_id_provider - self._log_file = f"step-ca-webui-{datetime.now().strftime('%Y-%m-%d')}.log" - self._json_file = f"step-ca-webui-{datetime.now().strftime('%Y-%m-%d')}.json" - logging.basicConfig(filename=self._log_file, level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') + def __init__(self, db_handler: DBHandler, log_file: str, max_file_size: int, backup_count: int): + self.db_handler = db_handler + self.logger = logging.getLogger(__name__) + self.logger.setLevel(logging.DEBUG) - def log_scoped( - self, - severity: LogSeverity, - message: str, - command_info: Optional[CommandInfo] = None - ) -> int: - """ - This method logs a message with a trace_id if it exists in the current context. - Context is managed externally, encapsulated in the TraceIdProvider. - """ - trace_id = self._trace_id_provider.get_current() or self._UNSCOPED_TRACE_ID # if trace_id not found/no context + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.DEBUG) + console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console_handler.setFormatter(console_formatter) + self.logger.addHandler(console_handler) + # File handler with rotation + file_handler = RotatingFileHandler(log_file, maxBytes=max_file_size, backupCount=backup_count) + file_handler.setLevel(logging.DEBUG) + file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + file_handler.setFormatter(file_formatter) + self.logger.addHandler(file_handler) + + def log(self, severity: LogSeverity, message: str, trace_id: Optional[UUID] = None, + command_info: Optional[CommandInfo] = None) -> int: log_entry = LogEntry( - entry_id=self._get_next_entry_id(), + entry_id=0, # This will be set by the database timestamp=datetime.now(), severity=severity, message=message, - trace_id=trace_id, + trace_id=trace_id or self._UNSCOPED_TRACE_ID, command_info=command_info ) - self._write(log_entry) - return log_entry.entry_id - def log_unscoped( - self, - severity: LogSeverity, - message: str, - command_info: Optional[CommandInfo] = None - ) -> int: - log_entry = LogEntry( - entry_id=self._get_next_entry_id(), - timestamp=datetime.now(), - severity=severity, - message=message, - trace_id=self._UNSCOPED_TRACE_ID, - command_info=command_info - ) - self._write(log_entry) + # Log to console and file + log_message = f"[TraceID: {log_entry.trace_id}] - {log_entry.message}" + if log_entry.command_info: + log_message += f" - Command: {log_entry.command_info.command}" + self.logger.log(logging.getLevelName(severity.name), log_message) + + # Log to database + self.db_handler.insert_log(log_entry) + return log_entry.entry_id def get_logs(self, filters: LogsFilter, paging: Paging) -> List[LogEntry]: - log_entries = [] - try: - with open(self._json_file, "r") as f: - for line in f: - log_data = json.loads(line) - log_entry = self._create_log_entry_from_data(log_data) - if log_entry and _match_filters(log_entry, filters): - log_entries.append(log_entry) - except FileNotFoundError: - self.log_scoped(LogSeverity.WARNING, f"File not found when trying to read logs: {self._json_file}") - - log_entries.sort(key=lambda x: x.timestamp, reverse=True) - start = (paging.page - 1) * paging.page_size - end = start + paging.page_size - return log_entries[start:end] + return self.db_handler.get_logs(filters, paging) def get_log_entry(self, log_id: int) -> Optional[LogEntry]: - try: - with open(self._json_file, "r") as f: - for line in f: - log_data = json.loads(line) - if log_data["entry_id"] == log_id: - return self._create_log_entry_from_data(log_data) - except FileNotFoundError: - self.log_scoped(LogSeverity.WARNING, f"File not found when trying to read log entry: {self._json_file}") - return None - - def _write(self, log_entry: LogEntry) -> None: - self._write_native(log_entry) - self._write_json(log_entry) - - @staticmethod - def _write_native(log_entry: LogEntry) -> None: - log_message = f"{log_entry.timestamp} - {log_entry.severity.name} - [TraceID: {log_entry.trace_id}] - {log_entry.message}" - if log_entry.command_info: - log_message += f" - Command: {log_entry.command_info.command}" - - logging.log(logging.getLevelName(log_entry.severity.name), log_message) - - def _write_json(self, log_entry: LogEntry) -> None: - json_entry = { - "timestamp": log_entry.timestamp.isoformat(), - "entry_id": log_entry.entry_id, - "severity": log_entry.severity.name, - "trace_id": str(log_entry.trace_id), - "message": _escape_quotes(log_entry.message), - "command_info": { - "command": _escape_quotes(log_entry.command_info.command), - "output": _escape_quotes(log_entry.command_info.output), - "exit_code": log_entry.command_info.exit_code, - "action": log_entry.command_info.action, - } if log_entry.command_info else None - } - - with open(self._json_file, "a") as f: - f.write(json.dumps(json_entry) + "\n") - - def _get_next_entry_id(self) -> int: - try: - with open(self._json_file, "r") as f: - last_line = f.readlines()[-1] - last_entry = json.loads(last_line) - return last_entry["entry_id"] + 1 - except (IndexError, FileNotFoundError, KeyError) as e: - # self.log_scoped(LogSeverity.WARNING, f"Failed to get next entry ID: {e}") - self._write_native(LogEntry( - entry_id=1, - timestamp=datetime.now(), - severity=LogSeverity.WARNING, - message=f"Failed to get next entry ID: {e}", - trace_id=uuid.UUID("00000000-0000-0000-0000-000000000000") - )) - return 1 - - def _create_log_entry_from_data(self, log_data: Dict) -> LogEntry | None: - try: - return LogEntry( - entry_id=log_data["entry_id"], - timestamp=_parse_datetime(log_data["timestamp"]), - severity=LogSeverity[log_data["severity"]], - message=log_data["message"], - trace_id=UUID(log_data["trace_id"]), - command_info=CommandInfo( - command=log_data["command_info"]["command"], - output=log_data["command_info"]["output"], - exit_code=log_data["command_info"]["exit_code"], - action=log_data["command_info"]["action"] - ) if log_data["command_info"] else None - ) - except KeyError: - self.log_scoped(LogSeverity.WARNING, f"Failed to parse log entry: {log_data['entry_id']}", ) - return None - - -def _match_filters(log_entry: LogEntry, filters: LogsFilter) -> bool: - if filters.trace_id and log_entry.trace_id != filters.trace_id: - return False - if filters.commands_only and log_entry.command_info is None: - return False - if filters.severity and log_entry.severity not in filters.severity: - return False - return True - - -def _parse_datetime(timestamp_str: str) -> datetime: - return datetime.fromisoformat(timestamp_str) - - -def _escape_quotes(text: str) -> str: - return re.sub(r'"', r'\\"', text) + return self.db_handler.get_log_entry(log_id)