logging implemented

This commit is contained in:
stone-w4tch3r
2024-08-21 19:20:00 +05:00
parent 7d5c8a6aa2
commit 17c90463fc
5 changed files with 161 additions and 79 deletions

View File

@@ -7,16 +7,16 @@ from fastapi.responses import PlainTextResponse
from core.certificate_manager import CertificateManager
from core.trace_id_handler import TraceIdHandler
from shared.api_models import (
Certificate,
CertificateDTO,
CertificateGenerateRequest,
CertificateGenerateResult,
CertificateRenewResult,
CertificateRevokeResult,
CommandPreview,
LogEntry,
CommandPreviewDTO,
LogEntryDTO,
LogsRequest
)
from shared.logger import Logger
from shared.logger import Logger, LogsFilter, Paging
from shared.models import LogSeverity
@@ -45,15 +45,15 @@ class APIServer:
uvicorn.run(self._app, host="0.0.0.0", port=self._port)
def _setup_routes(self):
@self._app.get("/certificates", response_model=Union[List[Certificate], CommandPreview])
async def list_certificates(preview: bool = Query(...)) -> Union[List[Certificate], CommandPreview]:
@self._app.get("/certificates", response_model=Union[List[CertificateDTO], CommandPreviewDTO])
async def list_certificates(preview: bool = Query(...)) -> Union[List[CertificateDTO], CommandPreviewDTO]:
if preview:
command = self._cert_manager.preview_list_certificates()
return CommandPreview(command=command)
return CommandPreviewDTO(command=command)
certs = self._cert_manager.list_certificates()
return [
Certificate(
CertificateDTO(
id=cert.id,
name=cert.name,
status=cert.status,
@@ -61,14 +61,14 @@ class APIServer:
) for cert in certs
]
@self._app.post("/certificates/generate", response_model=Union[CertificateGenerateResult, CommandPreview])
@self._app.post("/certificates/generate", response_model=Union[CertificateGenerateResult, CommandPreviewDTO])
async def generate_certificate(
cert_request: CertificateGenerateRequest,
preview: bool = Query(...)
):
if preview:
command = self._cert_manager.preview_generate_certificate(cert_request.keyName, cert_request.keyType, cert_request.duration)
return CommandPreview(command=command)
return CommandPreviewDTO(command=command)
cert = self._cert_manager.generate_certificate(cert_request.keyName, cert_request.keyType, cert_request.duration)
return CertificateGenerateResult(
@@ -80,7 +80,7 @@ class APIServer:
expirationDate=cert.expiration_date
)
@self._app.post("/certificates/renew", response_model=Union[CertificateRenewResult, CommandPreview])
@self._app.post("/certificates/renew", response_model=Union[CertificateRenewResult, CommandPreviewDTO])
async def renew_certificate(
certId: str = Query(...),
duration: int = Query(..., description="Duration in seconds"),
@@ -88,7 +88,7 @@ class APIServer:
):
if preview:
command = self._cert_manager.preview_renew_certificate(certId, duration)
return CommandPreview(command=command)
return CommandPreviewDTO(command=command)
cert = self._cert_manager.renew_certificate(certId, duration)
return CertificateRenewResult(
@@ -99,14 +99,14 @@ class APIServer:
newExpirationDate=cert.new_expiration_date
)
@self._app.post("/certificates/revoke", response_model=Union[CertificateRevokeResult, CommandPreview])
@self._app.post("/certificates/revoke", response_model=Union[CertificateRevokeResult, CommandPreviewDTO])
async def revoke_certificate(
certId: str = Query(...),
preview: bool = Query(...)
):
if preview:
command = self._cert_manager.preview_revoke_certificate(certId)
return CommandPreview(command=command)
return CommandPreviewDTO(command=command)
cert = self._cert_manager.revoke_certificate(certId)
return CertificateRevokeResult(
@@ -117,13 +117,13 @@ class APIServer:
revocationDate=cert.revocation_date
)
@self._app.get("/logs/single", response_model=LogEntry)
@self._app.get("/logs/single", response_model=LogEntryDTO)
async def get_log_entry(logId: int = Query(..., gt=0)):
log_entry = self._logger.get_log_entry(logId)
if not log_entry:
raise HTTPException(status_code=404, detail="Log entry not found")
return LogEntry(
return LogEntryDTO(
entryId=log_entry.entry_id,
timestamp=log_entry.timestamp,
severity=log_entry.severity,
@@ -132,12 +132,15 @@ class APIServer:
commandInfo=log_entry.command_info
)
@self._app.post("/logs", response_model=List[LogEntry])
async def get_logs(logs_request: LogsRequest) -> List[LogEntry]:
logs = self._logger.get_logs(logs_request.model_dump())
@self._app.post("/logs", response_model=List[LogEntryDTO])
async def get_logs(logs_request: LogsRequest) -> List[LogEntryDTO]:
logs = self._logger.get_logs(
LogsFilter(trace_id=logs_request.traceId, commands_only=logs_request.commandsOnly, severity=logs_request.severity),
Paging(page=logs_request.page, page_size=logs_request.pageSize)
)
return [
LogEntry(
LogEntryDTO(
entryId=log.entry_id,
timestamp=log.timestamp,
severity=log.severity,

View File

@@ -3,7 +3,7 @@ from certificate_manager import CertificateManager
from shared.logger import Logger
if __name__ == "__main__":
logger = Logger("step-ca-webui.log")
logger = Logger()
certificate_manager = CertificateManager(logger)
api_server = APIServer(certificate_manager, logger, "0.0.1", 5000)
api_server.run()

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
from shared.models import LogSeverity, KeyType
class Certificate(BaseModel):
class CertificateDTO(BaseModel):
id: str
name: str
status: str
@@ -20,7 +20,7 @@ class CertificateGenerateRequest(BaseModel):
duration: int = Field(..., gt=0, description="Duration in seconds")
class CommandPreview(BaseModel):
class CommandPreviewDTO(BaseModel):
command: str
@@ -49,20 +49,20 @@ class CertificateRevokeResult(BaseModel):
revocationDate: datetime
class CommandInfo(BaseModel):
class CommandInfoDTO(BaseModel):
command: str
output: str
exitCode: int
action: str
class LogEntry(BaseModel):
class LogEntryDTO(BaseModel):
entryId: int = Field(..., gt=0)
timestamp: datetime
severity: LogSeverity
message: str
traceId: uuid.UUID
commandInfo: Optional[CommandInfo]
commandInfo: Optional[CommandInfoDTO]
class LogsRequest(BaseModel):

View File

@@ -1,5 +1,8 @@
import json
import logging
import re
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional
from uuid import UUID
@@ -8,10 +11,24 @@ from core.trace_id_handler import TraceIdHandler
from .models import LogEntry, LogSeverity, CommandInfo
@dataclass
class LogsFilter:
trace_id: Optional[uuid.UUID]
commands_only: bool
severity: List[LogSeverity]
@dataclass
class Paging:
page: int
page_size: int
class Logger:
def __init__(self, log_file: str):
self.log_file = log_file
logging.basicConfig(filename=log_file, level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
def __init__(self):
self._log_file = f"step-ca-webui-{datetime.now().strftime('%Y-%m-%d')}.log"
self._json_file = f"step-ca-webui-{datetime.now().strftime('%Y-%m-%d')}.json"
logging.basicConfig(filename=self._log_file, level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
def log_scoped(
self,
@@ -19,8 +36,7 @@ class Logger:
message: str,
command_info: Optional[CommandInfo] = None
) -> int:
# if trace_id not found/no context
trace_id = TraceIdHandler.get_current_trace_id() or uuid.UUID("00000000-0000-0000-0000-000000000000")
trace_id = TraceIdHandler.get_current_trace_id() or uuid.UUID("00000000-0000-0000-0000-000000000000") # if trace_id not found/no context
log_entry = LogEntry(
entry_id=self._get_next_entry_id(),
@@ -51,56 +67,111 @@ class Logger:
self._write_log_entry(log_entry)
return log_entry.entry_id
def get_logs(self, filters: Dict) -> List[LogEntry]:
# Implementation for retrieving logs based on filters
# This is a placeholder and would need to be implemented based on your storage mechanism
return [
LogEntry(
entry_id=1,
timestamp=datetime.now(),
severity=LogSeverity.INFO,
message="This is a log entry",
trace_id=UUID("123e4567-e89b-12d3-a456-426614174000")
),
LogEntry(
entry_id=2,
timestamp=datetime.now(),
severity=LogSeverity.ERROR,
message="This is an error log entry",
trace_id=UUID("123e4567-e89b-12d3-a456-426614174000")
)
]
raise NotImplementedError
def get_logs(self, filters: LogsFilter, paging: Paging) -> List[LogEntry]:
log_entries = []
try:
with open(self._json_file, "r") as f:
for line in f:
log_data = json.loads(line)
log_entry = self._create_log_entry_from_data(log_data)
if log_entry and _match_filters(log_entry, filters):
log_entries.append(log_entry)
except FileNotFoundError:
self.log_scoped(LogSeverity.WARNING, f"File not found when trying to read logs: {self._json_file}")
log_entries.sort(key=lambda x: x.timestamp, reverse=True)
start = (paging.page - 1) * paging.page_size
end = start + paging.page_size
return log_entries[start:end]
def get_log_entry(self, log_id: int) -> Optional[LogEntry]:
# Implementation for retrieving a single log entry
# This is a placeholder and would need to be implemented based on your storage mechanism
return LogEntry(
entry_id=log_id,
timestamp=datetime.now(),
severity=LogSeverity.INFO,
message="This is a log entry",
trace_id=UUID("123e4567-e89b-12d3-a456-426614174000")
)
raise NotImplementedError
try:
with open(self._json_file, "r") as f:
for line in f:
log_data = json.loads(line)
if log_data["entry_id"] == log_id:
return self._create_log_entry_from_data(log_data)
except FileNotFoundError:
self.log_scoped(LogSeverity.WARNING, f"File not found when trying to read log entry: {self._json_file}")
return None
def _write_log_entry(self, log_entry: LogEntry) -> None:
log_message = f"{log_entry.timestamp} - {log_entry.severity.name} - {log_entry.message} - Trace ID: {log_entry.trace_id}"
log_message = f"{log_entry.timestamp} - {log_entry.severity.name} - [TraceID: {log_entry.trace_id}] - {log_entry.message}"
if log_entry.command_info:
log_message += f" - Command: {log_entry.command_info.command}"
logging.log(self._get_logging_level(log_entry.severity), log_message)
logging.log(logging.getLevelName(log_entry.severity.name), log_message)
self._write_json_log_entry(log_entry)
def _write_json_log_entry(self, log_entry: LogEntry) -> None:
json_entry = {
"timestamp": log_entry.timestamp.isoformat(),
"entry_id": _with_leading_zeros(log_entry.entry_id),
"severity": log_entry.severity.name,
"trace_id": str(log_entry.trace_id),
"message": _escape_quotes(log_entry.message),
"command_info": {
"command": _escape_quotes(log_entry.command_info.command),
"output": _escape_quotes(log_entry.command_info.output),
"exit_code": log_entry.command_info.exit_code,
"action": log_entry.command_info.action,
} if log_entry.command_info else None
}
with open(self._json_file, "a") as f:
f.write(json.dumps(json_entry) + "\n")
def _get_next_entry_id(self) -> int:
# Implementation for generating the next entry ID
# This is a placeholder and would need to be implemented based on your storage mechanism
return 1
raise NotImplementedError
try:
with open(self._json_file, "r") as f:
last_line = f.readlines()[-1]
last_entry = json.loads(last_line)
return last_entry["entry_id"] + 1
except (IndexError, FileNotFoundError, KeyError) as e:
self.log_scoped(LogSeverity.WARNING, f"Failed to get next entry ID: {e}")
return 1
@staticmethod
def _get_logging_level(severity: LogSeverity) -> int: # TODO simplify
return {
LogSeverity.DEBUG: logging.DEBUG,
LogSeverity.INFO: logging.INFO,
LogSeverity.WARN: logging.WARNING,
LogSeverity.ERROR: logging.ERROR
}[severity]
def _create_log_entry_from_data(self, log_data: Dict) -> LogEntry | None:
try:
return LogEntry(
entry_id=log_data["entry_id"],
timestamp=_parse_datetime(log_data["timestamp"]),
severity=LogSeverity[log_data["severity"]],
message=log_data["message"],
trace_id=UUID(log_data["trace_id"]),
command_info=CommandInfo(
command=log_data["command_info"]["command"],
output=log_data["command_info"]["output"],
exit_code=log_data["command_info"]["exit_code"],
action=log_data["command_info"]["action"]
) if log_data["command_info"] else None
)
except KeyError:
self.log_scoped(LogSeverity.WARNING, f"Failed to parse log entry: {log_data['entry_id']}", )
return None
def _match_filters(log_entry: LogEntry, filters: LogsFilter) -> bool:
if filters.trace_id and log_entry.trace_id != filters.trace_id:
return False
if filters.commands_only and log_entry.command_info is None:
return False
if filters.severity and log_entry.severity not in filters.severity:
return False
return True
def _parse_datetime(timestamp_str: str) -> datetime:
return datetime.fromisoformat(timestamp_str)
def _escape_quotes(text: str) -> str:
return re.sub(r'"', r'\\"', text)
def _with_leading_zeros(number: int) -> str:
length = 10 \
if str(number).__len__() < 10 \
else str(number).__len__()
return str(number).zfill(length)

View File

@@ -1,21 +1,29 @@
import enum
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from typing import Optional, List
from uuid import UUID
class LogSeverity(enum.StrEnum):
DEBUG = "DEBUG"
INFO = "INFO"
WARN = "WARN"
WARNING = "WARN"
ERROR = "ERROR"
@staticmethod
def as_list() -> List[str]: # TODO use
return [s.upper() for s in LogSeverity]
class KeyType(enum.StrEnum):
RSA = "RSA"
ECDSA = "ECDSA"
@staticmethod
def as_list() -> List[str]: # TODO use
return [s.upper() for s in KeyType]
@dataclass
class CommandInfo:
@@ -27,9 +35,9 @@ class CommandInfo:
@dataclass
class LogEntry:
entry_id: int
timestamp: datetime
entry_id: int
severity: LogSeverity
message: str
trace_id: UUID
message: str
command_info: Optional[CommandInfo] = None