diff --git a/shared/db_handler.py b/shared/db_logger.py similarity index 92% rename from shared/db_handler.py rename to shared/db_logger.py index 3339d37..b515f6d 100644 --- a/shared/db_handler.py +++ b/shared/db_logger.py @@ -1,14 +1,16 @@ import os -from typing import List, Optional, Dict, Any +from typing import List, Optional + import psycopg2 from psycopg2.extras import RealDictCursor -from .models import LogEntry, LogSeverity, CommandInfo, LogsFilter, Paging +from .logger import IDBLogger +from .models import LogEntry, LogsFilter, Paging -class DBHandler: +class DBLogger(IDBLogger): """ - DBHandler class for managing database operations related to logging. + DBLogger class for managing database operations related to logging. Required environment variables: - DB_HOST: The hostname of the database server diff --git a/shared/logger.py b/shared/logger.py index 359aa9a..89412c4 100644 --- a/shared/logger.py +++ b/shared/logger.py @@ -2,11 +2,10 @@ import logging import sys from datetime import datetime from logging.handlers import RotatingFileHandler -from typing import List, Optional, Callable +from typing import List, Optional, Callable, Protocol from uuid import UUID from .models import LogEntry, LogSeverity, CommandInfo, LogsFilter, Paging -from .db_handler import DBHandler class TraceIdProvider: @@ -17,6 +16,20 @@ class TraceIdProvider: return self.get_trace_id() +class IDBLogger(Protocol): + def insert_log(self, log_entry: LogEntry) -> None: + ... + + def get_logs(self, filters: LogsFilter, paging: Paging) -> List[LogEntry]: + ... + + def get_log_entry(self, log_id: int) -> Optional[LogEntry]: + ... + + def get_next_entry_id(self) -> int: + ... + + class Logger: _UNSCOPED_TRACE_ID = UUID("00000000-0000-0000-0000-000000000000") @@ -25,8 +38,8 @@ class Logger: BACKUP_COUNT = 5 LOGLEVEL = logging.DEBUG - def __init__(self, trace_id_provider: TraceIdProvider): - self.db_handler = DBHandler() + def __init__(self, trace_id_provider: TraceIdProvider, db_handler: IDBLogger) -> None: + self.db_logger = db_handler self.trace_id_provider = trace_id_provider self.logger = logging.getLogger(__name__) self.logger.setLevel(self.LOGLEVEL) @@ -48,7 +61,7 @@ class Logger: command_info: Optional[CommandInfo] = None ) -> int: log_entry = LogEntry( - entry_id=self.db_handler.get_next_entry_id(), + entry_id=self.db_logger.get_next_entry_id(), timestamp=datetime.now(), severity=severity, message=message, @@ -63,12 +76,12 @@ class Logger: self.logger.log(logging.getLevelName(severity.name), log_message) # Log to database - self.db_handler.insert_log(log_entry) + self.db_logger.insert_log(log_entry) return log_entry.entry_id def get_logs(self, filters: LogsFilter, paging: Paging) -> List[LogEntry]: - return self.db_handler.get_logs(filters, paging) + return self.db_logger.get_logs(filters, paging) def get_log_entry(self, log_id: int) -> Optional[LogEntry]: - return self.db_handler.get_log_entry(log_id) + return self.db_logger.get_log_entry(log_id)