import os
from typing import List, Optional

from pydantic import BaseModel, Field
from sqlalchemy import (
    create_engine,
    select,
    Column,
    Integer,
    String,
    DateTime,
    Enum,
    JSON,
    Sequence,
)
from sqlalchemy.engine import URL
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import func

from shared.db_logger_interface import IDBLogger
from shared.models import LogEntry, LogsFilter, Paging, LogSeverity, CommandInfo

_Base = declarative_base()

# Name of the database sequence that backs LogEntryModel.id and that
# DBLogger.get_next_entry_id() draws from.
_LOG_ENTRY_ID_SEQ = "log_entry_id_seq"


class LogEntryModel(_Base):
    """ORM mapping for one row of the ``log_entries`` table."""

    __tablename__ = "log_entries"

    # Primary key, generated from the shared id sequence.
    id = Column(Integer, Sequence(_LOG_ENTRY_ID_SEQ), primary_key=True)
    # Timezone-aware; the database fills in now() when no value is supplied.
    timestamp = Column(DateTime(timezone=True), server_default=func.now())
    severity = Column(Enum(LogSeverity))
    message = Column(String)
    trace_id = Column(String)
    # JSON payload produced from CommandInfo.model_dump(); NULL when the
    # entry carries no command info.
    command_info = Column(JSON)


class DbConnectionModel(BaseModel):
    """Validated database connection settings.

    Every field is required; strings must be non-empty and the port must be
    in the range 1..65535.
    """

    DB_HOST: str = Field(..., min_length=1)
    DB_PORT: int = Field(..., gt=0, lt=65536)
    DB_NAME: str = Field(..., min_length=1)
    DB_USER: str = Field(..., min_length=1)
    DB_PASSWORD: str = Field(..., min_length=1)


class DBLogger(IDBLogger):
    """
    DBLogger class for managing database operations related to logging using SQLAlchemy.

    Required environment variables:
    - DB_HOST: The hostname of the database server
    - DB_PORT: The port number of the database server
    - DB_NAME: The name of the database
    - DB_USER: The username for database authentication
    - DB_PASSWORD: The password for database authentication

    Make sure to set these variables in your .env file or environment.
    Note that this class only reads them through ``os.getenv``; loading a
    ``.env`` file into the process environment has to happen elsewhere.
    """

    def __init__(self, is_test: bool = False):
        """Connect to PostgreSQL and make sure the log table exists.

        Args:
            is_test: When true, return immediately without reading the
                environment or creating ``engine`` / ``Session``; the caller
                is then responsible for providing ``Session`` before any
                other method is used.

        Raises:
            pydantic.ValidationError: If a required environment variable is
                missing, empty, or (for DB_PORT) not a valid port number.
        """
        if is_test:
            return

        connection = DbConnectionModel(
            DB_HOST=os.getenv("DB_HOST"),
            DB_PORT=os.getenv("DB_PORT"),
            DB_NAME=os.getenv("DB_NAME"),
            DB_USER=os.getenv("DB_USER"),
            DB_PASSWORD=os.getenv("DB_PASSWORD"),
        )
        # Build the URL from its components instead of concatenating strings:
        # credentials containing characters such as "@", ":" or "/" would
        # otherwise corrupt the URL or be parsed as part of the host.
        url = URL.create(
            drivername="postgresql",
            username=connection.DB_USER,
            password=connection.DB_PASSWORD,
            host=connection.DB_HOST,
            port=connection.DB_PORT,
            database=connection.DB_NAME,
        )
        self.engine = create_engine(url)
        self.Session = sessionmaker(bind=self.engine)
        # Creates the tables registered on _Base if they do not exist yet.
        _Base.metadata.create_all(self.engine)

    @staticmethod
    def _to_log_entry(row: LogEntryModel) -> LogEntry:
        """Convert a ``log_entries`` row into the shared ``LogEntry`` model."""
        return LogEntry(
            entry_id=row.id,
            timestamp=row.timestamp,
            severity=row.severity,
            message=row.message,
            trace_id=row.trace_id,
            command_info=(
                CommandInfo.model_validate(row.command_info)
                if row.command_info
                else None
            ),
        )

    def insert_log(self, log_entry: LogEntry) -> None:
        """Persist one log entry and commit.

        Args:
            log_entry: Entry to store. ``trace_id`` is stored as its string
                form and ``command_info`` (if any) as its ``model_dump()``.
        """
        with self.Session() as session:
            db_log_entry = LogEntryModel(
                timestamp=log_entry.timestamp,
                severity=log_entry.severity,
                message=log_entry.message,
                trace_id=str(log_entry.trace_id),
                command_info=(
                    log_entry.command_info.model_dump()
                    if log_entry.command_info
                    else None
                ),
            )
            session.add(db_log_entry)
            session.commit()

    def get_logs(self, filters: LogsFilter, paging: Paging) -> List[LogEntry]:
        """Return one page of log entries, newest first.

        Args:
            filters: ``trace_id`` and ``commands_only`` are applied only when
                truthy; ``severity`` is always applied as an IN filter.
            paging: ``page`` is 1-based; ``page_size`` is the row limit.

        Returns:
            The matching entries ordered by timestamp descending.
        """
        with self.Session() as session:
            query = session.query(LogEntryModel)
            if filters.trace_id:
                query = query.filter(LogEntryModel.trace_id == filters.trace_id)
            if filters.commands_only:
                query = query.filter(LogEntryModel.command_info.isnot(None))
            # Compare against the enum members themselves: the column is an
            # Enum(LogSeverity), which persists member *names*, so passing
            # ``.value`` only works when every value equals its name.
            query = query.filter(LogEntryModel.severity.in_(list(filters.severity)))
            query = query.order_by(LogEntryModel.timestamp.desc())
            query = query.limit(paging.page_size).offset(
                (paging.page - 1) * paging.page_size
            )
            # Convert while the session is still open.
            return [self._to_log_entry(row) for row in query.all()]

    def get_log_entry(self, log_id: int) -> Optional[LogEntry]:
        """Fetch a single log entry by primary key.

        Args:
            log_id: Value of ``log_entries.id`` to look up.

        Returns:
            The entry, or ``None`` when no row has that id.
        """
        with self.Session() as session:
            row = (
                session.query(LogEntryModel).filter(LogEntryModel.id == log_id).first()
            )
            if row is None:
                return None
            return self._to_log_entry(row)

    def get_next_entry_id(self) -> int:
        """Advance the log-entry id sequence and return the new value."""
        with self.Session() as session:
            # TODO: test in real postgres
            # Select the sequence's next value explicitly. Passing a bare
            # Sequence to execute() is deprecated in SQLAlchemy 2.0 and hands
            # back the plain value rather than a result object, on which a
            # chained ``.scalar()`` call would fail.
            next_id = Sequence(_LOG_ENTRY_ID_SEQ).next_value()
            return session.execute(select(next_id)).scalar()