Files
openlan-cgw/utils/psql_client/psql_client.py
2024-11-25 17:06:27 +02:00

127 lines
4.1 KiB
Python

from typing import List, Optional, Tuple

import psycopg2
from psycopg2 import OperationalError, sql
class PostgreSQLClient:
    """Thin wrapper around one psycopg2 connection and its cursor.

    Query and fetch helpers report errors with ``print`` and return ``None``
    instead of raising. The client can also be used as a context manager:
    ``with PostgreSQLClient(...) as db: ...`` connects on entry and
    disconnects on exit.
    """

    def __init__(self, host: str, port: int, database: str, user: str, password: str):
        """Store the connection parameters; nothing is opened until connect()."""
        self.host = host
        self.database = database
        self.user = user
        self.password = password
        self.port = port
        self.connection = None
        self.cursor = None

    def __enter__(self) -> "PostgreSQLClient":
        """Connect and return this client for use in a ``with`` block."""
        self.connect()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        """Disconnect on leaving the ``with`` block; never suppress exceptions."""
        self.disconnect()
        return False

    def connect(self):
        """Open a connection and a cursor using the stored parameters.

        If connecting raises ``OperationalError``, the error is printed and
        not re-raised.
        """
        try:
            self.connection = psycopg2.connect(
                host=self.host,
                database=self.database,
                user=self.user,
                password=self.password,
                port=self.port,
            )
            self.cursor = self.connection.cursor()
            print("Connection successful")
        except OperationalError as e:
            print(f"Error: Unable to connect to the database. {e}")

    def execute_query(self, query: str, params=None):
        """Execute one statement and commit it.

        Args:
            query: SQL text. Put ``%s`` placeholders where values go and pass
                the values in ``params``; never format values into ``query``.
            params: Values bound to the placeholders by the driver, or None.

        Returns:
            Always None. On failure the error is printed and the transaction
            is rolled back. Read results with fetchone()/fetchall().
        """
        if self.cursor is None:
            print("Error: No database connection established.")
            return None
        try:
            # Injection safety comes from the driver binding ``params``,
            # not from wrapping the query text.
            if params:
                self.cursor.execute(query, params)
            else:
                self.cursor.execute(query)
            self.connection.commit()
            print("Query executed successfully")
        except Exception as e:
            print(f"Error executing query: {e}")
            try:
                self.connection.rollback()
            except psycopg2.Error as rollback_error:
                # Rollback can fail too (e.g. the connection is already
                # closed); report it rather than escaping this method.
                print(f"Error rolling back transaction: {rollback_error}")
        return None

    def _fetch(self, method_name: str):
        """Call ``self.cursor.<method_name>()``; print and return None on error."""
        if self.cursor is None:
            print("Error: No database connection or query executed.")
            return None
        try:
            return getattr(self.cursor, method_name)()
        except Exception as e:
            print(f"Error executing {method_name}: {e}")
            return None

    def fetchone(self):
        """Return the next row of the last executed query, or None on error."""
        return self._fetch("fetchone")

    def fetchall(self):
        """Return all remaining rows of the last executed query, or None on error."""
        return self._fetch("fetchall")

    def get_infrastructure_group(self, group_id: int) -> Optional[Tuple[int, int, int]]:
        """Return the ``infrastructure_groups`` row whose ``id`` is *group_id*.

        Returns None when nothing is fetched or the query fails.
        """
        self.execute_query(
            "select * from infrastructure_groups WHERE id = %s;", (group_id,)
        )
        return self.fetchone()

    def get_all_infrastructure_groups(self) -> Optional[List[Tuple[int, int, int]]]:
        """Return every row of ``infrastructure_groups``, or None on error."""
        self.execute_query("select * from infrastructure_groups;")
        return self.fetchall()

    def get_infra(self, mac: str) -> Optional[Tuple[str, int]]:
        """Return the ``infras`` row whose ``mac`` equals *mac*.

        The first column of the returned row is rewritten from
        "XX:XX:XX:XX:XX:XX" to "XX-XX-XX-XX-XX-XX". Returns None when
        nothing is fetched or the query fails.
        """
        self.execute_query("select * from infras WHERE mac = %s;", (mac,))
        infra_info = self.fetchone()
        if infra_info:
            # NOTE(review): assumes the MAC string is the first column of infras.
            mac_value, *rest = infra_info
            infra_info = (mac_value.replace(":", "-", 5), *rest)
        return infra_info

    def get_infras_by_group_id(self, group_id) -> Optional[List[Tuple[str, int]]]:
        """Return all ``infras`` rows whose ``infra_group_id`` is *group_id*.

        Returns None on error.
        """
        self.execute_query(
            "select * from infras WHERE infra_group_id = %s;", (group_id,)
        )
        return self.fetchall()

    def disconnect(self):
        """Close the cursor and connection, then drop the references.

        Resetting both attributes to None makes later queries hit the
        "No database connection" guard instead of a closed cursor, and makes
        repeated calls harmless.
        """
        if self.cursor is not None:
            self.cursor.close()
            self.cursor = None
        if self.connection is not None:
            self.connection.close()
            self.connection = None
        print("Connection closed.")