diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..d6b5bd3 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "ucentral-cgw" +version = "0.0.1" +edition = "2021" + +[dependencies] +serde = { version = "1.0.144", features = ["derive"] } +serde_json = "1.0.85" +env_logger = "0.5.0" +log = "0.4.20" +tokio = { version = "1.34.0", features = ["full"] } +tokio-stream = { version = "*", features = ["full"] } +tokio-tungstenite = { version = "*", features = ["native-tls"] } +tokio-native-tls = "*" +tokio-rustls = "*" +tokio-postgres = { version = "0.7.10", features = ["with-eui48-1"]} +tokio-pg-mapper = "0.2.0" +tungstenite = { version = "*"} +mio = "0.6.10" +native-tls = "*" +futures-util = { version = "0.3.0", default-features = false } +futures-channel = "0.3.0" +futures-executor = { version = "0.3.0", optional = true } +futures = "0.3.0" +rlimit = "0.10.1" +tonic = "0.10.2" +prost = "0.12" +rdkafka = "0.35.0" +clap = { version = "4.4.8", features = ["derive"] } +eui48 = "1.1.0" +uuid = { version = "1.6.1", features = ["serde"] } +redis-async = "0.16.1" + +[build-dependencies] +tonic-build = "0.10" +prost-build = "0.12" diff --git a/README.md b/README.md new file mode 100644 index 0000000..035453e --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# openlan-cgw diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..9cef20e --- /dev/null +++ b/build.rs @@ -0,0 +1,4 @@ +fn main() -> Result<(), Box> { + tonic_build::compile_protos("src/proto/cgw.proto")?; + Ok(()) +} diff --git a/src/cgw_connection_processor.rs b/src/cgw_connection_processor.rs new file mode 100644 index 0000000..821b276 --- /dev/null +++ b/src/cgw_connection_processor.rs @@ -0,0 +1,531 @@ +use crate::cgw_connection_server::{ + CGWConnectionServer, + CGWConnectionServerReqMsg, +}; + +use tokio::{ + sync::{ + mpsc::{ + UnboundedReceiver, + unbounded_channel, + }, + }, + net::{ + TcpStream, + }, + time::{ + sleep, + Duration, + Instant, + }, +}; +use 
tokio_native_tls::TlsStream; +use tokio_tungstenite::{ + WebSocketStream, + tungstenite::{ + protocol::{ + Message, + }, + }, +}; +use tungstenite::Message::{ + Close, + Text, + Ping, +}; +use futures_util::{ + StreamExt, + SinkExt, + FutureExt, + stream::{ + SplitSink, + SplitStream + }, +}; +use std::{ + net::SocketAddr, + sync::{ + Arc, + }, +}; +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; + +type CGWUcentralJRPCMessage = Map; +type SStream = SplitStream>>; +type SSink = SplitSink>, Message>; + +#[derive(Debug)] +pub enum CGWConnectionProcessorReqMsg { + // We got green light from server to process this connection on + AddNewConnectionAck, + AddNewConnectionShouldClose, + SinkRequestToDevice(String, String), +} + +#[derive(Debug)] +enum CGWConnectionState { + IsActive, + IsForcedToClose, + IsDead, + IsStale, + ClosedGracefully, +} + +#[derive(Deserialize, Serialize, Debug, Default)] +struct CGWEventLogParams { + serial: String, + log: String, + severity: i64, +} + +#[derive(Deserialize, Serialize, Debug, Default)] +struct CGWEventLog { + params: CGWEventLogParams, +} + +#[derive(Deserialize, Serialize, Debug, Default)] +struct CGWEventConnectParamsCaps { + compatible: String, + model: String, + platform: String, + label_macaddr: String, +} + +#[derive(Deserialize, Serialize, Debug, Default)] +struct CGWEventConnectParams { + serial: String, + firmware: String, + uuid: u64, + capabilities: CGWEventConnectParamsCaps, +} + +#[derive(Deserialize, Serialize, Debug, Default)] +struct CGWEventConnect { + params: CGWEventConnectParams, +} + +#[derive(Deserialize, Serialize, Debug)] +enum CGWEvent { + Connect(CGWEventConnect), + Log(CGWEventLog), + Empty, +} + +fn cgw_parse_jrpc_event(map: &Map, method: String) -> CGWEvent { + if method == "log" { + let params = map.get("params").expect("Params are missing"); + return CGWEvent::Log(CGWEventLog { + params: CGWEventLogParams { + serial: params["serial"].to_string(), + log: params["log"].to_string(), 
+ severity: serde_json::from_value(params["severity"].clone()).unwrap(), + }, + }); + } else if method == "connect" { + let params = map.get("params").expect("Params are missing"); + return CGWEvent::Connect(CGWEventConnect { + params: CGWEventConnectParams { + serial: params["serial"].to_string(), + firmware: params["firmware"].to_string(), + uuid: 1, + capabilities: CGWEventConnectParamsCaps { + compatible: params["capabilities"]["compatible"].to_string(), + model: params["capabilities"]["model"].to_string(), + platform: params["capabilities"]["platform"].to_string(), + label_macaddr: params["capabilities"]["label_macaddr"].to_string(), + }, + }, + }); + } + + CGWEvent::Empty +} + +async fn cgw_process_jrpc_event(event: &CGWEvent) -> Result<(), String> { + // TODO + if let CGWEvent::Connect(c) = event { + /* + info!( + "Requesting {} to reboot (immediate request)", + c.params.serial + ); + let req = json!({ + "jsonrpc": "2.0", + "method": "reboot", + "params": { + "serial": c.params.serial, + "when": 0 + }, + "id": 1 + }); + info!("Received connect msg {}", c.params.serial); + sender.send(Message::text(req.to_string())).await.ok(); + */ + } + + Ok(()) +} + +// TODO: heavy rework to enum-variant struct-based +async fn cgw_process_jrpc_message(message: Message) -> Result { + //let rpcmsg: CGWMethodConnect = CGWMethodConnect::default(); + //serde_json::from_str(method).unwrap(); + // + let msg = if let Ok(s) = message.into_text() { + s + } else { + return Err("Message to string cast failed".to_string()); + }; + + let map: CGWUcentralJRPCMessage = match serde_json::from_str(&msg) { + Ok(m) => m, + Err(e) => { + error!("Failed to parse input json {e}"); + return Err("Failed to parse input json".to_string()); + } + }; + //.expect("Failed to parse input json"); + + if !map.contains_key("jsonrpc") { + warn!("Received malformed JSONRPC msg"); + return Err("JSONRPC field is missing in message".to_string()); + } + + if map.contains_key("method") { + if 
!map.contains_key("params") { + warn!("Received JRPC without params."); + return Err("Received JRPC without params".to_string()); + } + + // unwrap can panic + let method = map["method"].as_str().unwrap(); + + let event: CGWEvent = cgw_parse_jrpc_event(&map, method.to_string()); + + match &event { + CGWEvent::Log(l) => { + debug!( + "Received LOG evt from device {}: {}", + l.params.serial, l.params.log + ); + } + CGWEvent::Connect(c) => { + debug!( + "Received connect evt from device {}: type {}, fw {}", + c.params.serial, c.params.capabilities.platform, c.params.firmware + ); + } + _ => { + warn!("received not yet implemented method {}", method); + return Err(format!("received not yet implemented method {}", method)); + } + }; + + if let Err(e) = cgw_process_jrpc_event(&event).await { + warn!( + "Failed to process jrpc event (unmatched) {}", + method.to_string() + ); + return Err(e); + } + // TODO + } else if map.contains_key("result") { + info!("Processing JSONRPC msg"); + info!("{:?}", map); + return Err("Result handling is not yet implemented".to_string()); + } + + /* + match map.get_mut("jsonrpc") { + Some(value) => info!("Got value {:?}", value), + None => info!("Got no value"), + } + */ + /* + if let CGWMethod::Connect { ref someint, .. 
} = &rpcmsg { + info!("secondmatch {}", *someint); + return Some(rpcmsg); + } else { + return None; + } + */ + //return Some(CGWMethodConnect::default()); + + Ok(map) +} + +pub struct CGWConnectionProcessor { + cgw_server: Arc, + pub serial: Option, + pub addr: SocketAddr, + pub idx: i64, +} + +impl CGWConnectionProcessor { + pub fn new(server: Arc, conn_idx: i64, addr: SocketAddr) -> Self { + let conn_processor : CGWConnectionProcessor = CGWConnectionProcessor { + cgw_server: server, + serial: None, + addr: addr, + idx: conn_idx, + }; + + conn_processor + } + + pub async fn start(mut self, tls_stream: TlsStream) { + let ws_stream = tokio_tungstenite::accept_async(tls_stream) + .await + .expect("error during the websocket handshake occurred"); + + let (sink, mut stream) = ws_stream.split(); + + // check if we have any pending msgs (we expect connect at this point, protocol-wise) + // TODO: rework to ignore any WS-related frames untill we get a connect message, + // however there's a caveat: we can miss some events logs etc from underlying device + // rework should consider all the options + let msg = tokio::select! 
{ + _val = stream.next() => { + match _val { + Some(m) => m, + None => { + error!("no connect message received from {}, closing connection", self.addr); + return; + } + } + } + // TODO: configurable duration (upon server creation) + _val = sleep(Duration::from_millis(30000)) => { + error!("no message received from {}, closing connection", self.addr); + return; + } + }; + + // we have a next() result, but it still may be undelying io error: check for it + // break connection if we can't work with underlying ws connection (pror err etc) + let message = match msg { + Ok(m) => m, + Err(e) => { + error!("established connection with device, but failed to receive any messages\n{e}"); + return; + } + }; + + let map = match cgw_process_jrpc_message(message).await { + Ok(val) => val, + Err(e) => { + error!("failed to recv connect message from {}, closing connection", self.addr); + return; + } + }; + + let serial = map["params"]["serial"].as_str().unwrap(); + self.serial = Some(serial.to_string()); + + // TODO: we accepted tls stream and split the WS into RX TX part, + // now we have to ASK cgw_connection_server's permission whether + // we can proceed on with this underlying connection. + // cgw_connection_server has an authorative decision whether + // we can proceed. 
+ let (mbox_tx, mut mbox_rx) = unbounded_channel::(); + let msg = CGWConnectionServerReqMsg::AddNewConnection(serial.to_string(), mbox_tx); + self.cgw_server.enqueue_mbox_message_to_cgw_server(msg).await; + + let ack = mbox_rx.recv().await; + if let Some(m) = ack { + match m { + CGWConnectionProcessorReqMsg::AddNewConnectionAck => { + debug!("websocket connection established: {} {}", self.addr, serial); + }, + _ => panic!("Unexpected response from server, expected ACK/NOT ACK)"), + } + } else { + info!("connection server declined connection, websocket connection {} {} cannot be established", + self.addr, serial); + return; + } + + self.process_connection(stream, sink, mbox_rx).await; + } + + async fn process_wss_rx_msg(&self, msg: Result) -> Result { + match msg { + Ok(msg) => { + match msg { + Close(_t) => { + return Ok(CGWConnectionState::ClosedGracefully); + }, + Text(payload) => { + self.cgw_server.enqueue_mbox_message_from_device_to_nb_api_c(self.serial.clone().unwrap(), payload); + return Ok(CGWConnectionState::IsActive); + }, + Ping(_t) => { + return Ok(CGWConnectionState::IsActive); + }, + _ => { + } + } + } + Err(e) => { + match e { + tungstenite::error::Error::AlreadyClosed => { + return Err("Underlying connection's been closed"); + }, + _ => { + } + } + } + } + + Ok(CGWConnectionState::IsActive) + } + + async fn process_sink_mbox_rx_msg(&self,sink: &mut SSink, val: Option) -> Result { + if let Some(msg) = val { + let processor_mac = self.serial.clone().unwrap(); + match msg { + CGWConnectionProcessorReqMsg::AddNewConnectionShouldClose => { + debug!("MBOX_IN: AddNewConnectionShouldClose, processor (mac:{processor_mac}) (ACK OK)"); + return Ok(CGWConnectionState::IsForcedToClose); + }, + CGWConnectionProcessorReqMsg::SinkRequestToDevice(mac, pload) => { + debug!("MBOX_IN: SinkRequestToDevice, processor (mac:{processor_mac}) req for (mac:{mac}) payload:{pload}"); + sink.send(Message::text(pload)).await.ok(); + }, + _ => panic!("Unexpected message received 
{:?}", msg), + } + } + Ok(CGWConnectionState::IsActive) + } + + async fn process_stale_connection_msg(&self, last_contact: Instant) -> Result { + // TODO: configurable duration (upon server creation) + if Instant::now().duration_since(last_contact) > Duration::from_secs(70) { + warn!("Closing connection {} (idle for too long, stale)", self.addr); + Ok(CGWConnectionState::IsStale) + } else { + Ok(CGWConnectionState::IsActive) + } + } + + async fn process_connection( + self, + mut stream: SStream, + mut sink: SSink, + mut mbox_rx: UnboundedReceiver) { + + #[derive(Debug)] + enum WakeupReason { + Unspecified, + WSSRxMsg(Result), + MboxRx(Option), + Stale, + } + + let mut last_contact = Instant::now(); + let mut poll_wss_first = true; + + // Get underlying wakeup reason and do initial parsion, like: + // - check if WSS stream has a message or an (recv) error + // - check if sinkmbox has a message or an (recv) error + // - check if connection's been stale for X time + // + // TODO: try_next intead of sync .next? could potentially + // skyrocket CPU usage. 
+ loop { + let mut wakeup_reason: WakeupReason = WakeupReason::Unspecified; + + // TODO: refactor + // Round-robin selection of stream to process: + // first, get single message from WSS, then get a single msg from RX MBOX + // It's done to ensure we process WSS and RX MBOX equally with same priority + // Also, we have to make sure we don't sleep-wait for any of the streams to + // make sure we don't cancel futures that are used for stream processing, + // especially TCP stream, which is not cancel-safe + if poll_wss_first { + if let Some(val) = stream.next().now_or_never() { + if let Some(res) = val { + if let Ok(msg) = res { + wakeup_reason = WakeupReason::WSSRxMsg(Ok(msg)); + } else if let Err(msg) = res { + wakeup_reason = WakeupReason::WSSRxMsg(Result::Err(msg)); + } + } else if let None = val { + wakeup_reason = WakeupReason::WSSRxMsg(Result::Err(tungstenite::error::Error::AlreadyClosed)); + } + } else if let Some(val) = mbox_rx.recv().now_or_never() { + wakeup_reason = WakeupReason::MboxRx(val) + } + + poll_wss_first = !poll_wss_first; + } else { + if let Some(val) = mbox_rx.recv().now_or_never() { + wakeup_reason = WakeupReason::MboxRx(val) + } else if let Some(val) = stream.next().now_or_never() { + if let Some(res) = val { + if let Ok(msg) = res { + wakeup_reason = WakeupReason::WSSRxMsg(Ok(msg)); + } else if let Err(msg) = res { + wakeup_reason = WakeupReason::WSSRxMsg(Result::Err(msg)); + } + } else if let None = val { + wakeup_reason = WakeupReason::WSSRxMsg(Result::Err(tungstenite::error::Error::AlreadyClosed)); + } + } + poll_wss_first = !poll_wss_first; + } + + // TODO: somehow workaround the sleeping? 
+ // Both WSS and RX MBOX are empty: chill for a while + if let WakeupReason::Unspecified = wakeup_reason { + sleep(Duration::from_millis(1000)).await; + wakeup_reason = WakeupReason::Stale; + } + + let rc = match wakeup_reason { + WakeupReason::WSSRxMsg(res) => { + last_contact = Instant::now(); + self.process_wss_rx_msg(res).await + }, + WakeupReason::MboxRx(mbox_message) => { + self.process_sink_mbox_rx_msg(&mut sink, mbox_message).await + }, + WakeupReason::Stale => { + self.process_stale_connection_msg(last_contact.clone()).await + }, + _ => { + panic!("Failed to get wakeup reason for {} conn", self.addr); + }, + }; + + match rc { + Err(e) => { + warn!("{}", e); + break; + }, + Ok(state) => { + if let CGWConnectionState::IsActive = state { + continue; + } else if let CGWConnectionState::IsForcedToClose = state { + // Return, because server already closed our mbox tx counterpart (rx), + // hence we don't need to send ConnectionClosed message. Server + // already knows we're closed. + return; + } else if let CGWConnectionState::ClosedGracefully = state { + warn!("Remote client {} closed connection gracefully", self.serial.clone().unwrap()); + break; + } else if let CGWConnectionState::IsStale = state { + warn!("Remote client {} closed due to inactivity", self.serial.clone().unwrap()); + break; + } + }, + } + } + + let mac = self.serial.clone().unwrap(); + let msg = CGWConnectionServerReqMsg::ConnectionClosed(self.serial.unwrap()); + self.cgw_server.enqueue_mbox_message_to_cgw_server(msg).await; + debug!("MBOX_OUT: ConnectionClosed, processor (mac:{})", mac); + } +} diff --git a/src/cgw_connection_server.rs b/src/cgw_connection_server.rs new file mode 100644 index 0000000..ce14bad --- /dev/null +++ b/src/cgw_connection_server.rs @@ -0,0 +1,878 @@ +use crate::{ + AppArgs, +}; + +use crate::cgw_nb_api_listener::{ + CGWNBApiClient, +}; + +use crate::cgw_connection_processor::{ + CGWConnectionProcessor, + CGWConnectionProcessorReqMsg, +}; + +use 
crate::cgw_db_accessor::{ + CGWDBInfrastructureGroup, +}; + +use crate::cgw_remote_discovery::{ + CGWRemoteDiscovery, +}; + +use tokio::{ + sync::{ + RwLock, + mpsc::{ + UnboundedSender, + UnboundedReceiver, + unbounded_channel, + }, + }, + time::{ + sleep, + Duration, + }, + net::{ + TcpStream, + }, + runtime::{ + Builder, + Runtime, + }, +}; +use std::{ + collections::HashMap, + net::SocketAddr, + sync::{ + Arc, + }, +}; + +use std::sync::atomic::{ + AtomicUsize, + Ordering, +}; + +use serde_json::{ + Map, + Value, +}; + +use serde::{ + Deserialize, + Serialize, +}; + +use uuid::{ + Uuid, +}; + +type DeviceSerial = String; +type CGWConnmapType = Arc>>>; + +#[derive(Debug)] +struct CGWConnMap { + map: CGWConnmapType, +} + +impl CGWConnMap { + pub fn new() -> Self { + let hash_map: HashMap> = HashMap::new(); + let map: Arc>>> = Arc::new(RwLock::new(hash_map)); + let connmap = CGWConnMap { + map: map, + }; + connmap + } +} + +type CGWConnectionServerMboxRx = UnboundedReceiver; +type CGWConnectionServerMboxTx = UnboundedSender; +type CGWConnectionServerNBAPIMboxTx = UnboundedSender; +type CGWConnectionServerNBAPIMboxRx = UnboundedReceiver; + +// The following pair used internally by server itself to bind +// Processor's Req/Res +#[derive(Debug)] +pub enum CGWConnectionServerReqMsg { + // Connection-related messages + AddNewConnection(DeviceSerial, UnboundedSender), + ConnectionClosed(DeviceSerial), +} + +#[derive(Debug)] +pub enum CGWConnectionNBAPIReqMsgOrigin { + FromNBAPI, + FromRemoteCGW, +} + +#[derive(Debug)] +pub enum CGWConnectionNBAPIReqMsg { + // Enqueue Key, Request, bool = isMessageRelayed + EnqueueNewMessageFromNBAPIListener(String, String, CGWConnectionNBAPIReqMsgOrigin), +} + +pub struct CGWConnectionServer { + local_cgw_id: i32, + // CGWConnectionServer write into this mailbox, + // and other correspondig Server task Reads RX counterpart + mbox_internal_tx: CGWConnectionServerMboxTx, + + // Object that owns underlying mac:connection map + connmap: 
CGWConnMap, + + // Runtime that schedules all the WSS-messages related tasks + wss_rx_tx_runtime: Arc, + + // Dedicated runtime (threadpool) for handling internal mbox: + // ACK/nACK connection, handle duplicates (clone/open) etc. + mbox_internal_runtime_handle: Arc, + + // Dedicated runtime (threadpool) for handling NB-API mbox: + // RX NB-API requests, parse, relay (if needed) + mbox_nb_api_runtime_handle: Arc, + + // Dedicated runtime (threadpool) for handling NB-API TX routine: + // TX NB-API requests (if async send is needed) + mbox_nb_api_tx_runtime_handle: Arc, + + // Dedicated runtime (threadpool) for handling (relaying) msgs: + // relay-task is spawned inside it, and the produced stream of + // remote-cgw messages is being relayed inside this context + mbox_relay_msg_runtime_handle: Arc, + + // CGWConnectionServer write into this mailbox, + // and other correspondig NB API client is responsible for doing an RX over + // receive handle counterpart + nb_api_client: Arc, + + // Interface used to access all discovered CGW instances + // (used for relaying non-local CGW requests from NB-API to target CGW) + cgw_remote_discovery: Arc, + + // Handler that helps this object to wrap relayed NB-API messages + // dedicated for this particular local CGW instance + mbox_relayed_messages_handle: CGWConnectionServerNBAPIMboxTx, +} + +/* + * TODO: split into base struct + enum type field + * this requires alot of refactoring in places where msg is used + * this is needed to always have uuid of malformed / discarded msg. + * e.g. + * struct CGWNBApiParsedMsg { + * uuid: Uuid, + * gid: i32, + * type: enum CGWNBApiParsedMsgType, + * } + * + * enum CGWNBApiParsedMsgType { + * InfrastructureGroupCreate, + * ... 
+ * InfrastructureGroupInfraMsg(DeviceSerial, String), + * } + */ +enum CGWNBApiParsedMsg { + // TODO: fix kafka_simulator to provide reserved_size + InfrastructureGroupCreate(Uuid, i32), + InfrastructureGroupDelete(Uuid, i32), + InfrastructureGroupInfraAdd(Uuid, i32, Vec), + InfrastructureGroupInfraDel(Uuid, i32, Vec), + InfrastructureGroupInfraMsg(Uuid, i32, DeviceSerial, String), +} + +impl CGWConnectionServer { + pub async fn new(app_args: &AppArgs) -> Arc { + let wss_runtime_handle = Arc::new(Builder::new_multi_thread() + .worker_threads(app_args.wss_t_num) + .thread_name_fn(|| { + static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + format!("cgw-wss-t-{}", id) + }) + .thread_stack_size(3 * 1024 * 1024) + .enable_all() + .build() + .unwrap()); + let internal_mbox_runtime_handle = Arc::new(Builder::new_multi_thread() + .worker_threads(1) + .thread_name("cgw-mbox") + .thread_stack_size(1 * 1024 * 1024) + .enable_all() + .build() + .unwrap()); + let nb_api_mbox_runtime_handle = Arc::new(Builder::new_multi_thread() + .worker_threads(1) + .thread_name("cgw-mbox-nbapi") + .thread_stack_size(1 * 1024 * 1024) + .enable_all() + .build() + .unwrap()); + let relay_msg_mbox_runtime_handle = Arc::new(Builder::new_multi_thread() + .worker_threads(1) + .thread_name("cgw-relay-mbox-nbapi") + .thread_stack_size(1 * 1024 * 1024) + .enable_all() + .build() + .unwrap()); + let nb_api_mbox_tx_runtime_handle = Arc::new(Builder::new_multi_thread() + .worker_threads(1) + .thread_name("cgw-mbox-nbapi-tx") + .thread_stack_size(1 * 1024 * 1024) + .enable_all() + .build() + .unwrap()); + + let (internal_tx, internal_rx) = unbounded_channel::(); + let (nb_api_tx, nb_api_rx) = unbounded_channel::(); + + // Give NB API client a handle where it can do a TX (CLIENT -> CGW_SERVER) + // RX is handled in internal_mbox of CGW_Server + let nb_api_c = CGWNBApiClient::new(app_args, &nb_api_tx); + + let server = Arc::new(CGWConnectionServer { + 
local_cgw_id: app_args.cgw_id, + connmap: CGWConnMap::new(), + wss_rx_tx_runtime: wss_runtime_handle, + mbox_internal_runtime_handle: internal_mbox_runtime_handle, + mbox_nb_api_runtime_handle: nb_api_mbox_runtime_handle, + mbox_nb_api_tx_runtime_handle: nb_api_mbox_tx_runtime_handle, + mbox_internal_tx: internal_tx, + nb_api_client: nb_api_c, + cgw_remote_discovery: Arc::new(CGWRemoteDiscovery::new(app_args).await), + mbox_relayed_messages_handle: nb_api_tx, + mbox_relay_msg_runtime_handle: relay_msg_mbox_runtime_handle, + }); + + let server_clone = server.clone(); + // Task for processing mbox_internal_rx, task owns the RX part + server.mbox_internal_runtime_handle.spawn(async move { + server_clone.process_internal_mbox(internal_rx).await; + }); + + let server_clone = server.clone(); + server.mbox_nb_api_runtime_handle.spawn(async move { + server_clone.process_internal_nb_api_mbox(nb_api_rx).await; + }); + + server + } + + pub async fn enqueue_mbox_message_to_cgw_server(&self, req: CGWConnectionServerReqMsg) { + let _ = self.mbox_internal_tx.send(req); + } + + pub fn enqueue_mbox_message_from_device_to_nb_api_c(&self, mac: DeviceSerial, req: String) { + // TODO: device (mac) -> group id matching + let key = String::from("TBD_INFRA_GROUP"); + let nb_api_client_clone = self.nb_api_client.clone(); + tokio::spawn(async move { + let _ = nb_api_client_clone.enqueue_mbox_message_from_cgw_server(key, req).await; + }); + } + + pub fn enqueue_mbox_message_from_cgw_to_nb_api(&self, gid: i32, req: String) { + let nb_api_client_clone = self.nb_api_client.clone(); + self.mbox_nb_api_tx_runtime_handle.spawn(async move { + let _ = nb_api_client_clone.enqueue_mbox_message_from_cgw_server(gid.to_string(), req).await; + }); + } + + pub async fn enqueue_mbox_relayed_message_to_cgw_server(&self, key: String, req: String) { + let msg = CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, req, CGWConnectionNBAPIReqMsgOrigin::FromRemoteCGW); + let _ = 
self.mbox_relayed_messages_handle.send(msg); + } + + fn parse_nbapi_msg(&self, pload: &String) -> Option { + #[derive(Debug, Serialize, Deserialize)] + struct InfraGroupCreate { + r#type: String, + infra_group_id: String, + infra_name: String, + infra_shard_id: i32, + uuid: Uuid, + } + #[derive(Debug, Serialize, Deserialize)] + struct InfraGroupDelete { + r#type: String, + infra_group_id: String, + uuid: Uuid, + } + + #[derive(Debug, Serialize, Deserialize)] + struct InfraGroupInfraAdd { + r#type: String, + infra_group_id: String, + infra_group_infra_devices: Vec, + uuid: Uuid, + } + + #[derive(Debug, Serialize, Deserialize)] + struct InfraGroupInfraDel { + r#type: String, + infra_group_id: String, + infra_group_infra_devices: Vec, + uuid: Uuid, + } + + #[derive(Debug, Serialize, Deserialize)] + struct InfraGroupMsgJSON { + r#type: String, + infra_group_id: String, + mac: String, + msg: Map, + uuid: Uuid, + } + + let rc = serde_json::from_str(pload); + if let Err(e) = rc { + error!("{e}\n{pload}"); + return None; + } + + let map: Map = rc.unwrap(); + + let rc = map.get(&String::from("type")); + if let None = rc { + error!("No msg_type found in\n{pload}"); + return None; + } + let rc = rc.unwrap(); + + let msg_type = rc.as_str().unwrap(); + let rc = map.get(&String::from("infra_group_id")); + if let None = rc { + error!("No infra_group_id found in\n{pload}"); + return None; + } + let rc = rc.unwrap(); + let group_id: i32 = rc.as_str().unwrap().parse().unwrap(); + + //debug!("Got msg {msg_type}, infra {group_id}"); + + match msg_type { + "infrastructure_group_create" => { + let json_msg: InfraGroupCreate = serde_json::from_str(&pload).unwrap(); + //debug!("{:?}", json_msg); + return Some(CGWNBApiParsedMsg::InfrastructureGroupCreate(json_msg.uuid, group_id)); + }, + "infrastructure_group_delete" => { + let json_msg: InfraGroupDelete = serde_json::from_str(&pload).unwrap(); + //debug!("{:?}", json_msg); + return 
Some(CGWNBApiParsedMsg::InfrastructureGroupDelete(json_msg.uuid, group_id)); + }, + "infrastructure_group_device_add" => { + let json_msg: InfraGroupInfraAdd = serde_json::from_str(&pload).unwrap(); + //debug!("{:?}", json_msg); + return Some(CGWNBApiParsedMsg::InfrastructureGroupInfraAdd(json_msg.uuid, group_id, json_msg.infra_group_infra_devices)); + } + "infrastructure_group_device_del" => { + let json_msg: InfraGroupInfraDel = serde_json::from_str(&pload).unwrap(); + //debug!("{:?}", json_msg); + return Some(CGWNBApiParsedMsg::InfrastructureGroupInfraDel(json_msg.uuid, group_id, json_msg.infra_group_infra_devices)); + } + "infrastructure_group_device_message" => { + let json_msg: InfraGroupMsgJSON = serde_json::from_str(&pload).unwrap(); + //debug!("{:?}", json_msg); + return Some(CGWNBApiParsedMsg::InfrastructureGroupInfraMsg(json_msg.uuid, group_id, json_msg.mac, serde_json::to_string(&json_msg.msg).unwrap())); + }, + &_ => { + debug!("Unknown type {msg_type} received"); + } + } + + None + } + + async fn process_internal_nb_api_mbox(self: Arc, mut rx_mbox: CGWConnectionServerNBAPIMboxRx) { + debug!("process_nb_api_mbox entry"); + + let buf_capacity = 2000; + let mut buf: Vec = Vec::with_capacity(buf_capacity); + let mut num_of_msg_read = 0; + // As of now, expect at max 100 CGWS remote instances without buffers realloc + // This only means that original capacity of all buffers is allocated to <100>, + // it can still increase on demand or need automatically (upon insert, push_back etc) + let cgw_buf_prealloc_size = 100; + + let mut local_parsed_cgw_msg_buf: Vec = Vec::with_capacity(buf_capacity); + + loop { + if num_of_msg_read < buf_capacity { + // Try to recv_many, but don't sleep too much + // in case if no messaged pending and we have + // TODO: rework to pull model (pull on demand), + // compared to curr impl: push model (nb api listener forcefully + // pushed all fetched data from kafka). 
+ // Currently recv_many may sleep if previous read >= 1, + // but no new messages pending + // + // It's also possible that this logic staggers the processing, + // in case when every new message is received <=9 ms for example: + // Single message received, waiting for new up to 10 ms. + // New received on 9th ms. Repeat. + // And this could repeat up untill buffer is full, or no new messages + // appear on the 10ms scale. + // Highly unlikly scenario, but still possible. + let rd_num = tokio::select! { + v = rx_mbox.recv_many(&mut buf, buf_capacity - num_of_msg_read) => { + v + } + v = sleep(Duration::from_millis(10)) => { + 0 + } + }; + num_of_msg_read += rd_num; + + // We read some messages, try to continue and read more + // If none read - break from recv, process all buffers that've + // been filled-up so far (both local and remote). + // Upon done - repeat. + if rd_num >= 1 { + continue; + } else { + if num_of_msg_read == 0 { + continue; + } + } + } + + debug!("Received {num_of_msg_read} messages from NB API, processing..."); + + // We rely on this map only for a single iteration of received messages: + // say, we receive 10 messages but 20 in queue, this means that gid->cgw_id + // cache is clear at first, the filled up when processing first 10 messages, + // the clear/reassigned again for next 10 msgs (10->20). + // This is done to ensure that we don't fallback for redis too much, + // but still somewhat fully rely on it. + // + self.cgw_remote_discovery.sync_gid_to_cgw_map().await; + + local_parsed_cgw_msg_buf.clear(); + + // TODO: rework to avoid re-allocating these buffers on each loop iteration + // (get mut slice of vec / clear when done?) + let mut relayed_cgw_msg_buf: Vec<(i32, CGWConnectionNBAPIReqMsg)> = Vec::with_capacity(num_of_msg_read + 1); + let mut local_cgw_msg_buf: Vec = Vec::with_capacity(num_of_msg_read + 1); + + while ! 
buf.is_empty() { + let msg = buf.remove(0); + + if let CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, origin) = msg { + let gid_numeric = match key.parse::() { + Err(e) => { + warn!("Invalid KEY received from KAFKA bus message, ignoring\n{e}"); + continue; + }, + Ok(v) => v + }; + + let parsed_msg = match self.parse_nbapi_msg(&payload) { + Some(val) => val, + None => { + warn!("Failed to parse recv msg with key {key}, discarded"); + continue; + } + }; + + // The one shard that received add/del is responsible for + // handling it at place. + // Any other msg is either relayed / handled locally later. + // The reason for this is following: current shard is responsible + // for assignment of GID to shard, thus it has to make + // assignment as soon as possible to deduce relaying action in + // the following message pool that is being handled. + // Same for delete. + if let CGWNBApiParsedMsg::InfrastructureGroupCreate(uuid, gid) = parsed_msg { + // DB stuff - create group for remote shards to be aware of change + let group = CGWDBInfrastructureGroup { + id: gid, + reserved_size: 1000i32, + actual_size: 0i32, + }; + match self.cgw_remote_discovery.create_infra_group(&group).await { + Ok(dst_cgw_id) => { + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Group has been created successfully gid {gid}")); + }, + Err(e) => { + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Failed to create new group (duplicate create?), gid {gid}")); + warn!("Create group gid {gid} received, but it already exists"); + } + } + // This type of msg is handled in place, not added to buf + // for later processing. 
+ continue; + } else if let CGWNBApiParsedMsg::InfrastructureGroupDelete(uuid, gid) = parsed_msg { + match self.cgw_remote_discovery.destroy_infra_group(gid).await { + Ok(()) => { + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Group has been destroyed successfully gid {gid}, uuid {uuid}")); + }, + Err(e) => { + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Failed to destroy group (doesn't exist?), gid {gid}, uuid {uuid}")); + warn!("Destroy group gid {gid} received, but it does not exist"); + } + } + // This type of msg is handled in place, not added to buf + // for later processing. + continue; + } + + // We received NB API msg, check origin: + // If it's a relayed message, we must not relay it further + // If msg originated from Kafka originally, it's safe to relay it (if needed) + if let CGWConnectionNBAPIReqMsgOrigin::FromRemoteCGW = origin { + local_parsed_cgw_msg_buf.push(parsed_msg); + continue; + } + + match self.cgw_remote_discovery.get_infra_group_owner_id(key.parse::().unwrap()).await { + Some(dst_cgw_id) => { + if dst_cgw_id == self.local_cgw_id { + local_cgw_msg_buf.push(CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, origin)); + } else { + relayed_cgw_msg_buf.push((dst_cgw_id, CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, origin))); + } + }, + None => { + warn!("Received msg for gid {gid_numeric}, while this group is unassigned to any of CGWs: rejecting"); + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid_numeric, + format!("Received message for unknown group {gid_numeric} - unassigned?")); + } + } + } + } + + let discovery_clone = self.cgw_remote_discovery.clone(); + let self_clone = self.clone(); + + // Future to Handle (relay) messages for remote CGW + let relay_task_hdl = self.mbox_relay_msg_runtime_handle.spawn(async move { + let mut remote_cgws_map: HashMap)> = HashMap::with_capacity(cgw_buf_prealloc_size); + + while ! 
relayed_cgw_msg_buf.is_empty() { + let msg = relayed_cgw_msg_buf.remove(0); + if let (dst_cgw_id, CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, _origin)) = msg { + debug!("Received MSG for remote CGW k:{}, local id {} relaying msg to remote...", key, self_clone.local_cgw_id); + if let Some(v) = remote_cgws_map.get_mut(&key) { + v.1.push((key, payload)); + } else { + let mut tmp_vec: Vec<(String, String)> = Vec::with_capacity(num_of_msg_read); + tmp_vec.push((key.clone(), payload)); + remote_cgws_map.insert(key, (dst_cgw_id, tmp_vec)); + } + } + } + + for value in remote_cgws_map.into_values() { + let discovery_clone = discovery_clone.clone(); + let cgw_id = value.0; + let msg_stream = value.1; + let self_clone = self_clone.clone(); + tokio::spawn(async move { + if let Err(()) = discovery_clone.relay_request_stream_to_remote_cgw(cgw_id, msg_stream).await { + self_clone.enqueue_mbox_message_from_cgw_to_nb_api( + -1, + format!("Failed to relay MSG stream to remote CGW{cgw_id}, UUIDs: not implemented (TODO)")); + } + }); + } + }); + + // Handle messages for local CGW + // Parse all messages first, then process + // TODO: try to parallelize at least parsing of msg: + // iterate each msg, get index, spawn task that would + // write indexed parsed msg into output parsed msg buf. + let connmap_clone = self.connmap.map.clone(); + while ! 
local_cgw_msg_buf.is_empty() { + let msg = local_cgw_msg_buf.remove(0); + if let CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, _origin) = msg { + let gid_numeric: i32 = key.parse::().unwrap(); + debug!("Received message for local CGW k:{key}, local id {}", self.local_cgw_id); + let msg = self.parse_nbapi_msg(&payload); + if let None = msg { + error!("Failed to parse msg from NBAPI (malformed?)"); + continue; + } + + match msg.unwrap() { + CGWNBApiParsedMsg::InfrastructureGroupInfraAdd(uuid, gid, mac_list) => { + if let None = self.cgw_remote_discovery.get_infra_group_owner_id(gid_numeric).await { + warn!("Unexpected: tried to add infra list to nonexisting group (gid {gid}, uuid {uuid}"); + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Failed to insert MACs from infra list, gid {gid}, uuid {uuid}: group does not exist.")); + } + + match self.cgw_remote_discovery.create_ifras_list(gid, mac_list).await { + Ok(()) => { + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Infra list has been created successfully gid {gid}, uuid {uuid}")); + }, + Err(macs) => { + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Failed to insert few MACs from infra list, gid {gid}, uuid {uuid}; List of failed MACs:{}", + macs.iter().map(|x| x.to_string() + ",").collect::())); + warn!("Failed to create few MACs from infras list (partial create)"); + continue; + } + } + }, + CGWNBApiParsedMsg::InfrastructureGroupInfraDel(uuid, gid, mac_list) => { + if let None = self.cgw_remote_discovery.get_infra_group_owner_id(gid_numeric).await { + warn!("Unexpected: tried to delete infra list from nonexisting group (gid {gid}, uuid {uuid}"); + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Failed to delete MACs from infra list, gid {gid}, uuid {uuid}: group does not exist.")); + } + + match self.cgw_remote_discovery.destroy_ifras_list(gid, mac_list).await { + Ok(()) => { + 
self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Infra list has been destroyed successfully gid {gid}, uuid {uuid}")); + }, + Err(macs) => { + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Failed to destroy few MACs from infra list (not created?), gid {gid}, uuid {uuid}; List of failed MACs:{}", + macs.iter().map(|x| x.to_string() + ",").collect::())); + warn!("Failed to destroy few MACs from infras list (partial delete)"); + continue; + } + } + }, + CGWNBApiParsedMsg::InfrastructureGroupInfraMsg(uuid, gid, mac, msg) => { + if let None = self.cgw_remote_discovery.get_infra_group_owner_id(gid_numeric).await { + warn!("Unexpected: tried to sink down msg to device of nonexisting group (gid {gid}, uuid {uuid}"); + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Failed to sink down msg to device of nonexisting group, gid {gid}, uuid {uuid}: group does not exist.")); + } + + debug!("Sending msg to device {mac}"); + let rd_lock = connmap_clone.read().await; + let rc = rd_lock.get(&mac); + if let None = rc { + error!("Cannot find suitable connection for {mac}"); + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Failed to send msg (device not connected?), gid {gid}, uuid {uuid}")); + continue; + } + + let proc_mbox_tx = rc.unwrap(); + if let Err(e) = proc_mbox_tx.send(CGWConnectionProcessorReqMsg::SinkRequestToDevice(mac, msg)) { + error!("Failed to send message to remote device (msg uuid({uuid}))"); + self.enqueue_mbox_message_from_cgw_to_nb_api( + gid, + format!("Failed to send msg, gid {gid}, uuid {uuid}")); + } + }, + _ => { + debug!("Received unimplemented/unexpected group create/del msg, ignoring"); + } + } + } + } + + // Do not proceed parsing local / remote msgs untill previous relaying has been + // finished + tokio::join!(relay_task_hdl); + + buf.clear(); + num_of_msg_read = 0; + } + panic!("RX or TX counterpart of nb_api channel part destroyed, while processing task is still active"); + } + + async 
fn process_internal_mbox(self: Arc, mut rx_mbox: CGWConnectionServerMboxRx) { + debug!("process_internal_mbox entry"); + + let buf_capacity = 1000; + let mut buf: Vec = Vec::with_capacity(buf_capacity); + let mut num_of_msg_read = 0; + + loop { + if num_of_msg_read < buf_capacity { + // Try to recv_many, but don't sleep too much + // in case if no messaged pending and we have + // TODO: rework? + // Currently recv_many may sleep if previous read >= 1, + // but no new messages pending + let rd_num = tokio::select! { + v = rx_mbox.recv_many(&mut buf, buf_capacity - num_of_msg_read) => { + v + } + v = sleep(Duration::from_millis(10)) => { + 0 + } + }; + num_of_msg_read += rd_num; + + // We read some messages, try to continue and read more + // If none read - break from recv, process all buffers that've + // been filled-up so far (both local and remote). + // Upon done - repeat. + if rd_num >= 1 { + if num_of_msg_read < 100 { + continue; + } + } else { + if num_of_msg_read == 0 { + continue; + } + } + } + + let mut connmap_w_lock = self.connmap.map.write().await; + + while ! buf.is_empty() { + let msg = buf.remove(0); + + if let CGWConnectionServerReqMsg::AddNewConnection(serial, conn_processor_mbox_tx) = msg { + // if connection is unique: simply insert new conn + // + // if duplicate exists: notify server about such incident. + // it's up to server to notify underlying task that it should + // drop the connection. + // from now on simply insert new connection into hashmap and proceed on + // processing it. 
+ let serial_clone: DeviceSerial = serial.clone(); + if let Some(c) = connmap_w_lock.remove(&serial_clone) { + tokio::spawn(async move { + warn!("Duplicate connection (mac:{}) detected, closing OLD connection in favor of NEW", serial_clone); + let msg: CGWConnectionProcessorReqMsg = CGWConnectionProcessorReqMsg::AddNewConnectionShouldClose; + c.send(msg).unwrap(); + }); + } + + // clone a sender handle, as we still have to send ACK back using underlying + // tx mbox handle + let conn_processor_mbox_tx_clone = conn_processor_mbox_tx.clone(); + + info!("connmap: connection with {} established, new num_of_connections:{}", serial, connmap_w_lock.len() + 1); + connmap_w_lock.insert(serial, conn_processor_mbox_tx); + + tokio::spawn(async move { + let msg: CGWConnectionProcessorReqMsg = CGWConnectionProcessorReqMsg::AddNewConnectionAck; + conn_processor_mbox_tx_clone.send(msg).unwrap(); + }); + } else if let CGWConnectionServerReqMsg::ConnectionClosed(serial) = msg { + info!("connmap: removed {} serial from connmap, new num_of_connections:{}", serial, connmap_w_lock.len() - 1); + connmap_w_lock.remove(&serial); + } + } + + buf.clear(); + num_of_msg_read = 0; + } + + panic!("RX or TX counterpart of mbox_internal channel part destroyed, while processing task is still active"); + } + + pub async fn ack_connection(self: Arc, socket: TcpStream, tls_acceptor: tokio_native_tls::TlsAcceptor, addr: SocketAddr, conn_idx: i64) { + // Only ACK connection. We will either drop it or accept it once processor starts + // (we'll handle it via "mailbox" notify handle in process_internal_mbox) + let server_clone = self.clone(); + + self.wss_rx_tx_runtime.spawn(async move { + // Accept the TLS connection. 
+ let tls_stream = match tls_acceptor.accept(socket).await { + Ok(a) => a, + Err(e) => { + warn!("Err {e}"); + return; + } + }; + let conn_processor = CGWConnectionProcessor::new(server_clone, conn_idx, addr); + conn_processor.start(tls_stream).await; + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn get_connect_json_msg() -> &'static str { + r#" + { + "jsonrpc": "2.0", + "method": "connect", + "params": { + "serial": "00000000ca4b", + "firmware": "SONiC-OS-4.1.0_vs_daily_221213_1931_422-campus", + "uuid": 1, + "capabilities": { + "compatible": "+++x86_64-kvm_x86_64-r0", + "model": "DellEMC-S5248f-P-25G-DPB", + "platform": "switch", + "label_macaddr": "00:00:00:00:ca:4b" + } + } + }"# + } + + fn get_log_json_msg() -> &'static str { + r#" + { + "jsonrpc": "2.0", + "method": "log", + "params": { + "serial": "00000000ca4b", + "log": "uc-client: connection error: Unable to connect", + "severity": 3 + } + }"# + } + + #[test] + fn can_parse_connect_event() { + let msg = get_connect_json_msg(); + + let map: Map = + serde_json::from_str(msg).expect("Failed to parse input json"); + let method = map["method"].as_str().unwrap(); + let event: CGWEvent = cgw_parse_jrpc_event(&map, method.to_string()); + + match event { + CGWEvent::Connect(_) => { + assert!(true); + } + _ => { + assert!(false, "Expected event to be of type"); + } + } + } + + #[test] + fn can_parse_log_event() { + let msg = get_log_json_msg(); + + let map: Map = + serde_json::from_str(msg).expect("Failed to parse input json"); + let method = map["method"].as_str().unwrap(); + let event: CGWEvent = cgw_parse_jrpc_event(&map, method.to_string()); + + match event { + CGWEvent::Log(_) => { + assert!(true); + } + _ => { + assert!(false, "Expected event to be of type"); + } + } + } +} diff --git a/src/cgw_db_accessor.rs b/src/cgw_db_accessor.rs new file mode 100644 index 0000000..33c8601 --- /dev/null +++ b/src/cgw_db_accessor.rs @@ -0,0 +1,205 @@ +use crate::{ + AppArgs, +}; + +use eui48::{ + 
MacAddress, +}; + +use tokio_postgres::{ + Client, + NoTls, + row::{ + Row, + }, +}; +

/// One row of the `infras` table: a device MAC and the group it belongs to.
#[derive(Clone)]
pub struct CGWDBInfra {
    pub mac: MacAddress,
    pub infra_group_id: i32,
}

/// One row of the `infrastructure_groups` table.
#[derive(Clone)]
pub struct CGWDBInfrastructureGroup {
    pub id: i32,
    pub reserved_size: i32,
    pub actual_size: i32,
}

impl From<Row> for CGWDBInfra {
    /// Reads the `mac` and `infra_group_id` columns of `row`.
    fn from(row: Row) -> Self {
        Self {
            mac: row.get("mac"),
            infra_group_id: row.get("infra_group_id"),
        }
    }
}

impl From<Row> for CGWDBInfrastructureGroup {
    /// Reads the `id`, `reserved_size` and `actual_size` columns of `row`.
    fn from(row: Row) -> Self {
        Self {
            id: row.get("id"),
            reserved_size: row.get("reserved_size"),
            actual_size: row.get("actual_size"),
        }
    }
}

/// Thin wrapper around a `tokio_postgres` client exposing the infra /
/// infrastructure-group queries used by the gateway.
pub struct CGWDBAccessor {
    cl: Client,
}

impl CGWDBAccessor {
    /// Connects (without TLS) to the database described by the `db_*` fields
    /// of `app_args` and spawns the task that drives the connection.
    ///
    /// # Panics
    /// Panics if the connection cannot be established.
    pub async fn new(app_args: &AppArgs) -> Self {
        let conn_str = format!("host={host} port={port} user={user} dbname={db} password={pass}",
                               host = app_args.db_ip,
                               port = app_args.db_port,
                               user = app_args.db_username,
                               db = app_args.db_name,
                               pass = app_args.db_password);
        debug!("Trying to connect to remote db ({}:{})...",
               app_args.db_ip,
               app_args.db_port);
        // The connection string is deliberately NOT logged: it embeds the
        // database password in clear text.
        let (client, connection) = tokio_postgres::connect(&conn_str, NoTls)
            .await
            .expect("Failed to connect to remote DB");

        // The connection object performs the actual socket I/O; it has to be
        // polled for the client to make progress.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                error!("connection error: {}", e);
            }
        });

        info!("Connected to remote DB");

        CGWDBAccessor {
            cl: client,
        }
    }

    /*
     * INFRA_GROUP db API uses the following table decl
     * TODO: id = int, not varchar; requires kafka simulator changes
       CREATE TABLE infrastructure_groups (
         id VARCHAR (50) PRIMARY KEY,
         reserved_size INT,
         actual_size INT
       );
     *
     */

    /// Inserts `g` into `infrastructure_groups`.
    ///
    /// # Errors
    /// Returns a static description if the statement fails (including a
    /// failure to prepare it); details are written to the error log.
    pub async fn insert_new_infra_group(&self, g: &CGWDBInfrastructureGroup) -> Result<(), &'static str> {
        // Passing the SQL text directly lets `execute` prepare the statement
        // itself, so a prepare failure is reported as Err instead of a panic.
        let res = self.cl.execute(
            "INSERT INTO infrastructure_groups (id, reserved_size, actual_size) VALUES ($1, $2, $3)",
            &[&g.id, &g.reserved_size, &g.actual_size]).await;

        match res {
            Ok(_) => Ok(()),
            Err(e) => {
                error!("Failed to insert a new infra group {}: {:?}", g.id, e.to_string());
                Err("Insert new infra group failed")
            }
        }
    }

    /// Deletes the group with id `gid`.
    ///
    /// # Errors
    /// Fails if no row was deleted or if the statement itself failed.
    pub async fn delete_infra_group(&self, gid: i32) -> Result<(), &'static str> {
        // TODO: query-base approach instead of static string
        let res = self.cl.execute(
            "DELETE FROM infrastructure_groups WHERE id = $1",
            &[&gid]).await;

        match res {
            Ok(0) => Err("Failed to delete group from DB: gid does not exist"),
            Ok(_) => Ok(()),
            Err(e) => {
                error!("Failed to delete an infra group {gid}: {:?}", e.to_string());
                Err("Delete infra group failed")
            }
        }
    }

    /// Fetches every infrastructure group; `None` if the query failed.
    pub async fn get_all_infra_groups(&self) -> Option<Vec<CGWDBInfrastructureGroup>> {
        match self.cl.query("SELECT * from infrastructure_groups", &[]).await {
            Ok(rows) => Some(rows.into_iter().map(CGWDBInfrastructureGroup::from).collect()),
            Err(e) => {
                error!("Failed to fetch infra groups: {:?}", e.to_string());
                None
            }
        }
    }

    /// Fetches the group with id `gid`; `None` if it does not exist or the
    /// query failed (the latter is logged).
    pub async fn get_infra_group(&self, gid: i32) -> Option<CGWDBInfrastructureGroup> {
        match self.cl.query_opt("SELECT * from infrastructure_groups WHERE id = $1", &[&gid]).await {
            Ok(row) => row.map(CGWDBInfrastructureGroup::from),
            Err(e) => {
                error!("Failed to fetch infra group {gid}: {:?}", e.to_string());
                None
            }
        }
    }

    /*
     * INFRA db API uses the following table decl
       CREATE TABLE infras (
         mac MACADDR PRIMARY KEY,
         infra_group_id INT,
         FOREIGN KEY(infra_group_id) REFERENCES infrastructure_groups(id) ON DELETE CASCADE
       );
     */

    /// Inserts `infra` into `infras`.
    ///
    /// # Errors
    /// Returns a static description if the statement fails; details are
    /// written to the error log.
    pub async fn insert_new_infra(&self, infra: &CGWDBInfra) -> Result<(), &'static str> {
        let res = self.cl.execute(
            "INSERT INTO infras (mac, infra_group_id) VALUES ($1, $2)",
            &[&infra.mac, &infra.infra_group_id]).await;

        match res {
            Ok(_) => Ok(()),
            Err(e) => {
                error!("Failed to insert a new infra: {:?}", e.to_string());
                Err("Insert new infra failed")
            }
        }
    }

    /// Deletes the infra whose MAC is `serial`.
    ///
    /// # Errors
    /// Fails if no row was deleted or if the statement itself failed.
    pub async fn delete_infra(&self, serial: MacAddress) -> Result<(), &'static str> {
        let res = self.cl.execute(
            "DELETE FROM infras WHERE mac = $1",
            &[&serial]).await;

        match res {
            Ok(0) => Err("Failed to delete infra from DB: MAC does not exist"),
            Ok(_) => Ok(()),
            Err(e) => {
                error!("Failed to delete infra: {:?}", e.to_string());
                Err("Delete infra failed")
            }
        }
    }
}
diff --git a/src/cgw_nb_api_listener.rs b/src/cgw_nb_api_listener.rs new file mode 100644 index 0000000..9f5e5cd --- /dev/null +++ b/src/cgw_nb_api_listener.rs @@ -0,0 +1,260 @@ +use crate::{ + AppArgs, +}; + +use crate::cgw_connection_server::{ + CGWConnectionNBAPIReqMsg, + CGWConnectionNBAPIReqMsgOrigin, +}; + +use std::{ + sync::{ + Arc, + }, +}; +use tokio::{ + sync::{ + mpsc::{ + UnboundedSender, + }, + }, + time::{ + Duration, + }, + runtime::{ + Builder, + Runtime, + }, +}; +use futures::stream::{TryStreamExt}; +use rdkafka::client::ClientContext; +use rdkafka::config::{ClientConfig, RDKafkaLogLevel}; +use rdkafka::{ + consumer::{ + Consumer, + ConsumerContext, + Rebalance, + stream_consumer::{ + StreamConsumer, + }, + }, + producer::{ + FutureProducer, + FutureRecord, + }, +}; +use rdkafka::error::KafkaResult; +use rdkafka::message::{Message}; +use rdkafka::topic_partition_list::TopicPartitionList; + +type CGWConnectionServerMboxTx = UnboundedSender; +type CGWCNCConsumerType = StreamConsumer; +type CGWCNCProducerType = FutureProducer; + +struct CustomContext; +impl ClientContext for CustomContext {} + +impl ConsumerContext for CustomContext { + fn
pre_rebalance(&self, rebalance: &Rebalance) {
        // Called before a rebalance takes effect; only logs the partitions.
        log_rebalance("pre_rebalance", rebalance);
    }

    /// Called after a rebalance has been applied; only logs the partitions.
    fn post_rebalance(&self, rebalance: &Rebalance) {
        log_rebalance("post_rebalance", rebalance);
    }

    /// Logs the outcome of an offset commit together with the partitions it
    /// covered.
    fn commit_callback(&self, result: KafkaResult<()>, offsets: &TopicPartitionList) {
        let part_list = partitions_to_string(offsets);
        debug!("commit_callback callback, partition(s): {part_list}");
        match result {
            Ok(()) => debug!("Consumer callback: commited offset"),
            Err(e) => warn!("Consumer callback: failed to commit offset: {e}"),
        }
    }
}

/// Renders the partition numbers of `tpl` as a space-terminated list
/// (e.g. `"0 1 2 "`).
fn partitions_to_string(tpl: &TopicPartitionList) -> String {
    tpl.elements()
        .into_iter()
        .map(|elem| elem.partition().to_string() + " ")
        .collect()
}

/// Logs the partitions assigned / revoked by `rebalance`; `stage` is the name
/// of the calling callback. Other rebalance events are not logged.
fn log_rebalance(stage: &str, rebalance: &Rebalance) {
    match rebalance {
        Rebalance::Assign(tpl) => {
            debug!("{stage} callback, assigned partition(s): {}", partitions_to_string(tpl));
        }
        Rebalance::Revoke(tpl) => {
            debug!("{stage} callback, revoked partition(s): {}", partitions_to_string(tpl));
        }
        _ => {}
    }
}

/// Builds the `bootstrap.servers` value (`ip:port`) from the configured Kafka
/// address, shared by the consumer and the producer.
fn kafka_bootstrap_servers(app_args: &AppArgs) -> String {
    format!("{}:{}", app_args.kafka_ip, app_args.kafka_port)
}

static GROUP_ID: &'static str = "CGW";
const CONSUMER_TOPICS: [&'static str; 1] = ["CnC"];
const PRODUCER_TOPICS: &'static str = "CnC_Res";

/// Kafka producer used to publish NB API responses.
struct CGWCNCProducer {
    p: CGWCNCProducerType,
}

/// Kafka consumer subscribed to the NB API request topics.
struct CGWCNCConsumer {
    c: CGWCNCConsumerType,
}

impl CGWCNCConsumer {
    /// Creates a consumer connected to the broker configured in `app_args`.
    pub fn new(app_args: &AppArgs) -> Self {
        CGWCNCConsumer {
            c: Self::create_consumer(app_args),
        }
    }

    /// Builds the stream consumer and subscribes it to `CONSUMER_TOPICS`.
    ///
    /// # Panics
    /// Panics if the consumer cannot be created or the subscription fails.
    fn create_consumer(app_args: &AppArgs) -> CGWCNCConsumerType {
        let context = CustomContext;

        debug!("Trying to connect to kafka broker ({}:{})...",
               app_args.kafka_ip.to_string(),
               app_args.kafka_port.to_string());

        let consumer: CGWCNCConsumerType = ClientConfig::new()
            .set("group.id", GROUP_ID)
            .set("client.id", GROUP_ID.to_string() + &app_args.cgw_id.to_string())
            .set("group.instance.id", app_args.cgw_id.to_string())
            .set("bootstrap.servers", kafka_bootstrap_servers(app_args))
            .set("enable.partition.eof", "false")
            .set("session.timeout.ms", "6000")
            .set("enable.auto.commit", "true")
            //.set("statistics.interval.ms", "30000")
            //.set("auto.offset.reset", "smallest")
            .set_log_level(RDKafkaLogLevel::Debug)
            .create_with_context(context)
            .expect("Consumer creation failed");

        // `expect` takes a plain &str, so a `{CONSUMER_TOPICS}` placeholder
        // there would be printed literally; format the message explicitly.
        consumer
            .subscribe(&CONSUMER_TOPICS)
            .unwrap_or_else(|e| panic!("Failed to subscribe to {CONSUMER_TOPICS:?} topics: {e}"));

        info!("Connected to kafka broker");

        consumer
    }
}

impl CGWCNCProducer {
    /// Creates a producer connected to the broker configured in `app_args`.
    pub fn new(app_args: &AppArgs) -> Self {
        CGWCNCProducer {
            p: Self::create_producer(app_args),
        }
    }

    /// Builds the future producer against the same broker the consumer uses.
    ///
    /// # Panics
    /// Panics if the producer cannot be created.
    fn create_producer(app_args: &AppArgs) -> CGWCNCProducerType {
        let producer: FutureProducer = ClientConfig::new()
            .set("bootstrap.servers", kafka_bootstrap_servers(app_args))
            .set("message.timeout.ms", "5000")
            .create()
            .expect("Producer creation error");

        producer
    }
}

/// NB API endpoint of the gateway: consumes requests from Kafka on a dedicated
/// runtime, forwards them to the connection server mailbox, and publishes
/// responses back to Kafka.
pub struct CGWNBApiClient {
    working_runtime_handle: Runtime,
    cgw_server_tx_mbox: CGWConnectionServerMboxTx,
    prod: CGWCNCProducer,
    // TBD: split different implementators through a defined trait,
    // that implements async R W operations?
}

impl CGWNBApiClient {
    /// Creates the client, its single-threaded worker runtime, and spawns the
    /// Kafka consume loop that feeds `cgw_tx`.
    ///
    /// # Panics
    /// Panics if the runtime, the consumer or the producer cannot be created.
    pub fn new(app_args: &AppArgs, cgw_tx: &CGWConnectionServerMboxTx) -> Arc<Self> {
        let working_runtime_h = Builder::new_multi_thread()
            .worker_threads(1)
            .thread_name("cgw-nb-api-l")
            .thread_stack_size(1024 * 1024)
            .enable_all()
            .build()
            .unwrap();
        let cl = Arc::new(CGWNBApiClient {
            working_runtime_handle: working_runtime_h,
            cgw_server_tx_mbox: cgw_tx.clone(),
            prod: CGWCNCProducer::new(app_args),
        });

        let cl_clone = cl.clone();
        let consumer: CGWCNCConsumer = CGWCNCConsumer::new(app_args);
        cl.working_runtime_handle.spawn(async move {
            loop {
                let cl_clone = cl_clone.clone();
                let stream_processor = consumer.c.stream().try_for_each(|borrowed_message| {
                    let cl_clone = cl_clone.clone();
                    async move {
                        // Process each message
                        // Borrowed messages can't outlive the consumer they are received from, so they need to
                        // be owned in order to be sent to a separate thread.
                        let owned = borrowed_message.detach();

                        let key = match owned.key_view::<str>() {
                            None => "",
                            Some(Ok(s)) => s,
                            Some(Err(e)) => {
                                warn!("Error while deserializing message key: {:?}", e);
                                ""
                            }
                        };

                        let payload = match owned.payload_view::<str>() {
                            None => "",
                            Some(Ok(s)) => s,
                            Some(Err(e)) => {
                                warn!("Error while deserializing message payload: {:?}", e);
                                ""
                            }
                        };
                        cl_clone.enqueue_mbox_message_to_cgw_server(key.to_string(), payload.to_string()).await;
                        Ok(())
                    }
                });
                // The stream ends on the first Kafka error. Log it and let
                // the surrounding loop restart consumption after a short
                // pause instead of killing the consumer task.
                if let Err(e) = stream_processor.await {
                    error!("NB API stream processing failed, restarting: {e}");
                    tokio::time::sleep(Duration::from_secs(1)).await;
                }
            }
        });

        cl
    }

    /// Publishes `payload` under `key` to the `PRODUCER_TOPICS` topic. With a
    /// zero queue timeout the message is not retried when the producer queue
    /// is full; failures are logged.
    pub async fn enqueue_mbox_message_from_cgw_server(&self, key: String, payload: String) {
        let produce_future = self.prod.p.send(
            FutureRecord::to(PRODUCER_TOPICS)
                .key(&key)
                .payload(&payload),
            Duration::from_secs(0),
        );
        if let Err((e, _)) = produce_future.await {
            error!("Error: {:?}", e);
        }
    }

+ async fn enqueue_mbox_message_to_cgw_server(&self, key: String, payload: String) { +
debug!("MBOX_OUT: EnqueueNewMessageFromNBAPIListener, k:{key}"); + let msg = CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, CGWConnectionNBAPIReqMsgOrigin::FromNBAPI); + let _ = self.cgw_server_tx_mbox.send(msg); + } +} diff --git a/src/cgw_remote_client.rs b/src/cgw_remote_client.rs new file mode 100644 index 0000000..9bd89a9 --- /dev/null +++ b/src/cgw_remote_client.rs @@ -0,0 +1,75 @@ +use crate::{ + AppArgs, +}; + +pub mod cgw_remote { + tonic::include_proto!("cgw.remote"); +} + +use tonic::{ + transport::{ + Uri, + channel::{ + Channel, + }, + }, +}; + +use cgw_remote::{ + EnqueueRequest, + remote_client::{ + RemoteClient, + }, +}; + +use tokio::{ + time::{ + Duration, + }, +}; +

/// gRPC client used to relay NB API requests to a remote CGW instance.
#[derive(Clone)]
pub struct CGWRemoteClient {
    remote_client: RemoteClient<Channel>,
}

impl CGWRemoteClient {
    /// Creates a client for the endpoint `hostname` (a full URI such as
    /// `http://ip:port`). The channel is lazy: no connection is attempted
    /// until the first request; request and connect timeouts are 20 seconds.
    ///
    /// # Panics
    /// Panics if `hostname` is not a valid URI.
    pub fn new(hostname: String) -> Self {
        let uri = Uri::from_maybe_shared(hostname).unwrap();
        let r_channel = Channel::builder(uri)
            .timeout(Duration::from_secs(20))
            .connect_timeout(Duration::from_secs(20))
            .connect_lazy();

        CGWRemoteClient {
            remote_client: RemoteClient::new(r_channel),
        }
    }

    /// Sends `stream` - a list of `(key, request)` pairs - to the remote CGW
    /// as a single client-streaming call.
    ///
    /// # Errors
    /// Returns `Err(())` if the call fails; the cause is logged.
    pub async fn relay_request_stream(&self, stream: Vec<(String, String)>) -> Result<(), ()> {
        let mut client = self.remote_client.clone();

        // The pairs are moved into the request messages; nothing is copied.
        let messages: Vec<EnqueueRequest> = stream
            .into_iter()
            .map(|(key, req)| EnqueueRequest { key, req })
            .collect();

        let rq = tonic::Request::new(tokio_stream::iter(messages));
        match client.enqueue_nbapi_request_stream(rq).await {
            Err(e) => {
                error!("Failed to relay req: {:?}", e);
                Err(())
            }
            Ok(_) => Ok(()),
        }
    }
}
diff --git a/src/cgw_remote_discovery.rs b/src/cgw_remote_discovery.rs new file mode 100644 index 0000000..8e78e80 --- /dev/null +++ b/src/cgw_remote_discovery.rs @@ -0,0 +1,559 @@ +use crate::{ + AppArgs, + cgw_remote_client::{ + CGWRemoteClient, + }, + cgw_db_accessor::{ +
CGWDBAccessor, + CGWDBInfrastructureGroup, + CGWDBInfra, + }, +}; + +use std::{ + collections::{ + HashMap, + }, + net::{ + Ipv4Addr, + IpAddr, + SocketAddr, + }, + sync::{ + Arc, + }, +}; + +use redis_async::{ + resp_array, +}; + +use eui48::{ + MacAddress, +}; + +use tokio::{ + sync::{ + RwLock, + } +}; + +// Used in remote lookup +static REDIS_KEY_SHARD_ID_PREFIX:&'static str = "shard_id_"; +static REDIS_KEY_SHARD_ID_FIELDS_NUM: usize = 12; +static REDIS_KEY_SHARD_VALUE_ASSIGNED_G_NUM:&'static str = "assigned_groups_num"; + +// Used in group assign / reassign +static REDIS_KEY_GID:&'static str = "group_id_"; +static REDIS_KEY_GID_VALUE_GID:&'static str = "gid"; +static REDIS_KEY_GID_VALUE_SHARD_ID:&'static str = "shard_id"; + +#[derive(Clone, Debug)] +pub struct CGWREDISDBShard { + id: i32, + server_ip: IpAddr, + server_port: u16, + assigned_groups_num: i32, + capacity: i32, + threshold: i32, +} + +impl From> for CGWREDISDBShard { + fn from(values: Vec) -> Self { + assert!(values.len() >= REDIS_KEY_SHARD_ID_FIELDS_NUM, + "Unexpected size of parsed vector: at least {REDIS_KEY_SHARD_ID_FIELDS_NUM} expected"); + assert!(values[0] == "id", "redis.res[0] != id, unexpected."); + assert!(values[2] == "server_ip", "redis.res[2] != server_ip, unexpected."); + assert!(values[4] == "server_port", "redis.res[4] != server_port, unexpected."); + assert!(values[6] == "assigned_groups_num", "redis.res[6] != assigned_groups_num, unexpected."); + assert!(values[8] == "capacity", "redis.res[8] != capacity, unexpected."); + assert!(values[10] == "threshold", "redis.res[10] != threshold, unexpected."); + + CGWREDISDBShard { + id: values[1].parse::().unwrap(), + server_ip: values[3].parse::().unwrap(), + server_port: values[5].parse::().unwrap(), + assigned_groups_num: values[7].parse::().unwrap(), + capacity: values[9].parse::().unwrap(), + threshold: values[11].parse::().unwrap(), + } + } +} + +impl Into> for CGWREDISDBShard { + fn into(self) -> Vec { + vec!["id".to_string(), 
self.id.to_string(), + "server_ip".to_string(), self.server_ip.to_string(), + "server_port".to_string(), self.server_port.to_string(), + "assigned_groups_num".to_string(), self.assigned_groups_num.to_string(), + "capacity".to_string(), self.capacity.to_string(), + "threshold".to_string(), self.threshold.to_string()] + } +} + +#[derive(Clone)] +pub struct CGWRemoteConfig { + pub remote_id: i32, + pub server_ip: Ipv4Addr, + pub server_port: u16, +} + +impl CGWRemoteConfig { + pub fn new(id: i32, ip_conf: Ipv4Addr, port: u16) -> Self { + CGWRemoteConfig { + remote_id: id, + server_ip: ip_conf, + server_port: port, + } + } + pub fn to_socket_addr(&self) -> SocketAddr { + SocketAddr::new(std::net::IpAddr::V4(self.server_ip), self.server_port) + } +} + +#[derive(Clone)] +pub struct CGWRemoteIface { + pub shard: CGWREDISDBShard, + client: CGWRemoteClient, +} + +#[derive(Clone)] +pub struct CGWRemoteDiscovery { + db_accessor: Arc, + redis_client: redis_async::client::paired::PairedConnection, + gid_to_cgw_cache: Arc>>, + remote_cgws_map: Arc>>, + local_shard_id: i32, +} + +impl CGWRemoteDiscovery { + pub async fn new(app_args: &AppArgs) -> Self { + let rc = CGWRemoteDiscovery { + db_accessor: Arc::new(CGWDBAccessor::new(app_args).await), + redis_client: redis_async::client::paired::paired_connect( + app_args.redis_db_ip.to_string(), + app_args.redis_db_port).await.unwrap(), + gid_to_cgw_cache: Arc::new(RwLock::new(HashMap::new())), + local_shard_id: app_args.cgw_id, + remote_cgws_map: Arc::new(RwLock::new(HashMap::new())), + }; + + let _ = rc.sync_gid_to_cgw_map().await; + let _ = rc.sync_remote_cgw_map().await; + + if let None = rc.remote_cgws_map.read().await.get(&rc.local_shard_id) { + let redisdb_shard_info = CGWREDISDBShard { + id: app_args.cgw_id, + server_ip: std::net::IpAddr::V4(app_args.grpc_ip.clone()), + server_port: u16::try_from(app_args.grpc_port).unwrap(), + assigned_groups_num: 0i32, + capacity: 1000i32, + threshold: 50i32, + }; + let redis_req_data: Vec = 
redisdb_shard_info.into(); + + let _ = rc.redis_client.send::( + resp_array!["DEL", + format!("{REDIS_KEY_SHARD_ID_PREFIX}{}", app_args.cgw_id)]).await; + + if let Err(e) = rc.redis_client.send::( + resp_array!["HSET", format!("{REDIS_KEY_SHARD_ID_PREFIX}{}", app_args.cgw_id)] + .append(redis_req_data)).await { + panic!("Failed to create record about shard in REDIS, e:{e}"); + } + + let _ = rc.sync_remote_cgw_map().await; + } + + debug!("Found {} remote CGWs:", rc.remote_cgws_map.read().await.len() - 1); + + for (key, val) in rc.remote_cgws_map.read().await.iter() { + if val.shard.id == rc.local_shard_id { + continue; + } + debug!("Shard #{}, IP {}:{}", val.shard.id, val.shard.server_ip, val.shard.server_port); + } + + rc + } + + pub async fn sync_gid_to_cgw_map(&self) { + let mut lock = self.gid_to_cgw_cache.write().await; + + // Clear hashmap + lock.clear(); + + let redis_keys: Vec = + match self.redis_client.send::>( + resp_array!["KEYS", format!("{}*", REDIS_KEY_GID)]).await { + Err(e) => { + panic!("Failed to get KEYS list from REDIS, e:{e}"); + }, + Ok(r) => r + }; + + for key in redis_keys { + let gid : i32 = match self.redis_client.send::( + resp_array!["HGET", &key, REDIS_KEY_GID_VALUE_GID]).await { + Ok(res) => { + match res.parse::() { + Ok(res) => res, + Err(e) => { + warn!("Found proper key '{key}' entry, but failed to parse GID from it:\n{e}"); + continue; + } + } + }, + Err(e) => { + warn!("Found proper key '{key}' entry, but failed to fetch GID from it:\n{e}"); + continue; + } + }; + let shard_id : i32 = match self.redis_client.send::( + resp_array!["HGET", &key, REDIS_KEY_GID_VALUE_SHARD_ID]).await { + Ok(res) => { + match res.parse::() { + Ok(res) => res, + Err(e) => { + warn!("Found proper key '{key}' entry, but failed to parse SHARD_ID from it:\n{e}"); + continue; + } + } + }, + Err(e) => { + warn!("Found proper key '{key}' entry, but failed to fetch SHARD_ID from it:\n{e}"); + continue; + } + }; + + debug!("Found group {key}, gid: {gid}, 
shard_id: {shard_id}"); + + match lock.insert(gid, shard_id) { + None => continue, + Some(v) => warn!("Populated gid_to_cgw_map with previous value being alerady set, unexpected") + } + } + debug!("Found total {} groups with their respective owners", lock.len()); + } + + async fn sync_remote_cgw_map(&self) -> Result<(), &'static str> { + let mut lock = self.remote_cgws_map.write().await; + + // Clear hashmap + lock.clear(); + + let redis_keys: Vec = + match self.redis_client.send::>( + resp_array!["KEYS", format!("{}*", REDIS_KEY_SHARD_ID_PREFIX)]).await { + Err(e) => { + warn!("Failed to get cgw shard KEYS list from REDIS, e:{e}"); + return Err("Remote CGW Shards list fetch from REDIS failed"); + }, + Ok(r) => r + }; + + for key in redis_keys { + match self.redis_client.send::>( + resp_array!["HGETALL", &key]).await { + Ok(res) => { + let shrd: CGWREDISDBShard = match CGWREDISDBShard::try_from(res) { + Ok(v) => v, + Err(e) => { + warn!("Failed to parse CGWREDISDBShard, {key}"); + continue; + } + }; + + let endpoint_str = + String::from("http://") + + &shrd.server_ip.to_string() + + ":" + + &shrd.server_port.to_string(); + let cgw_iface = CGWRemoteIface { + shard: shrd, + client: CGWRemoteClient::new(endpoint_str), + }; + lock.insert(cgw_iface.shard.id, cgw_iface); + }, + Err(e) => { + warn!("Found proper key '{key}' entry, but failed to fetch Shard info from it:\n{e}"); + continue; + } + } + } + + Ok(()) + } + + pub async fn get_infra_group_owner_id(&self, gid: i32) -> Option { + // try to use internal cache first + if let Some(id) = self.gid_to_cgw_cache.read().await.get(&gid) { + return Some(*id); + } + + // then try to use redis + self.sync_gid_to_cgw_map().await; + + if let Some(id) = self.gid_to_cgw_cache.read().await.get(&gid) { + return Some(*id); + } + + None + } + + async fn increment_cgw_assigned_groups_num(&self, id: i32) -> Result<(), &'static str> { + debug!("inc {id}"); + /* + if let Err(e) = self.redis_client.send::( + */ + match 
self.redis_client.send::( + resp_array!["HINCRBY", + format!("{}{id}", REDIS_KEY_SHARD_ID_PREFIX), + REDIS_KEY_SHARD_VALUE_ASSIGNED_G_NUM, + "1"]).await { + Ok(v) => { + debug!("ret {:?}", v); + }, + Err(e) => { + warn!("Failed to increment CGW{id} assigned group num count, e:{e}"); + return Err("Failed to increment assigned group num count"); + } + } + Ok(()) + } + + async fn decrement_cgw_assigned_groups_num(&self, id: i32) -> Result<(), &'static str> { + if let Err(e) = self.redis_client.send::( + resp_array!["HINCRBY", + format!("{}{id}", REDIS_KEY_SHARD_ID_PREFIX), + REDIS_KEY_SHARD_VALUE_ASSIGNED_G_NUM, + "-1"]).await { + warn!("Failed to decrement CGW{id} assigned group num count, e:{e}"); + return Err("Failed to decrement assigned group num count"); + } + Ok(()) + } + + async fn get_infra_group_cgw_assignee(&self) -> Result { + let lock = self.remote_cgws_map.read().await; + let mut hash_vec: Vec<(&i32, &CGWRemoteIface)> = lock.iter().collect(); + + hash_vec.sort_by(|a, b| b.1.shard.assigned_groups_num.cmp(&a.1.shard.assigned_groups_num)); + + for x in hash_vec { + debug!("id_{} capacity {} t {} assigned {}", + x.1.shard.id, x.1.shard.capacity, + x.1.shard.threshold, + x.1.shard.assigned_groups_num); + let max_capacity: i32 = x.1.shard.capacity + x.1.shard.threshold; + if x.1.shard.assigned_groups_num + 1 <= max_capacity { + debug!("Found CGW shard to assign group to (id {})", x.1.shard.id); + return Ok(x.1.shard.id); + } + + } + + warn!("Every available CGW is exceeding capacity+threshold limit, using least loaded one..."); + if let Some(least_loaded_cgw) = lock.iter() + .min_by(|a, b| a.1.shard.assigned_groups_num.cmp(&b.1.shard.assigned_groups_num)) + .map(|(k, _v)| _v) { + warn!("Found least loaded CGW id: {}", least_loaded_cgw.shard.id); + return Ok(least_loaded_cgw.shard.id); + } + + return Err("Unexpected: Failed to find the least loaded CGW shard"); + } + + async fn assign_infra_group_to_cgw(&self, gid: i32) -> Result { + // Delete key (if exists), 
recreate with new owner + let _ = self.deassign_infra_group_to_cgw(gid).await; + + let dst_cgw_id: i32 = match self.get_infra_group_cgw_assignee().await { + Ok(v) => v, + Err(e) => { + warn!("Failed to assign {gid} to any shard, reason:{e}"); + return Err(e); + } + }; + + if let Err(e) = self.redis_client.send::( + resp_array!["HSET", + format!("{REDIS_KEY_GID}{gid}"), + REDIS_KEY_GID_VALUE_GID, + gid.to_string(), + REDIS_KEY_GID_VALUE_SHARD_ID, + dst_cgw_id.to_string()]).await { + error!("Failed to update REDIS gid{gid} owner to shard{dst_cgw_id}, e:{e}"); + return Err("Hot-cache (REDIS DB) update owner failed"); + } + + let mut lock = self.gid_to_cgw_cache.write().await; + lock.insert(gid, dst_cgw_id); + + debug!("REDIS: assigned gid{gid} to shard{dst_cgw_id}"); + + Ok(dst_cgw_id) + } + + pub async fn deassign_infra_group_to_cgw(&self, gid: i32) -> Result<(), &'static str> { + if let Err(e) = self.redis_client.send::( + resp_array!["DEL", format!("{REDIS_KEY_GID}{gid}")]).await { + error!("Failed to deassigned REDIS gid{gid} owner, e:{e}"); + return Err("Hot-cache (REDIS DB) deassign owner failed"); + } + + debug!("REDIS: deassigned gid{gid} from controlled CGW"); + + let mut lock = self.gid_to_cgw_cache.write().await; + lock.remove(&gid); + + Ok(()) + } + + pub async fn create_infra_group(&self, g: &CGWDBInfrastructureGroup) -> Result { + //TODO: transaction-based insert/assigned_group_num update (DB) + let rc = self.db_accessor.insert_new_infra_group(g).await; + if let Err(e) = rc { + return Err(e); + } + + let shard_id: i32 = match self.assign_infra_group_to_cgw(g.id).await { + Ok(v) => v, + Err(e) => { + let _ = self.db_accessor.delete_infra_group(g.id).await; + return Err("Assign group to CGW shard failed"); + } + }; + + let rc = self.increment_cgw_assigned_groups_num(shard_id).await; + if let Err(e) = rc { + return Err(e); + } + + Ok(shard_id) + } + + pub async fn destroy_infra_group(&self, gid: i32) -> Result<(), &'static str> { + let cgw_id: Option = 
self.get_infra_group_owner_id(gid).await; + if let Some(id) = cgw_id { + let _ = self.deassign_infra_group_to_cgw(gid).await; + let _ = self.decrement_cgw_assigned_groups_num(id).await; + } + + //TODO: transaction-based insert/assigned_group_num update (DB) + let rc = self.db_accessor.delete_infra_group(gid).await; + if let Err(e) = rc { + return Err(e); + } + + Ok(()) + } + + pub async fn create_ifras_list(&self, gid: i32, infras: Vec) -> Result<(), Vec> { + // TODO: assign list to shards; currently - only created bulk, no assignment + let mut futures = Vec::with_capacity(infras.len()); + // Results store vec of MACs we failed to add + let mut failed_infras: Vec = Vec::with_capacity(futures.len()); + for x in infras.iter() { + let db_accessor_clone = self.db_accessor.clone(); + let infra = CGWDBInfra { + mac: MacAddress::parse_str(&x).unwrap(), + infra_group_id: gid, + }; + futures.push( + tokio::spawn( + async move { + if let Err(_) = db_accessor_clone.insert_new_infra(&infra).await { + Err(infra.mac.to_string(eui48::MacAddressFormat::HexString)) + } else { + Ok(()) + } + })); + } + + for (i, future) in futures.iter_mut().enumerate() { + match future.await { + Ok(res) => { + if let Err(mac) = res { + failed_infras.push(mac); + } + }, + Err(_) => { + failed_infras.push(infras[i].clone()); + } + } + } + + if failed_infras.len() > 0 { + return Err(failed_infras); + } + + Ok(()) + } + + pub async fn destroy_ifras_list(&self, gid: i32, infras: Vec) -> Result<(), Vec> { + let mut futures = Vec::with_capacity(infras.len()); + // Results store vec of MACs we failed to add + let mut failed_infras: Vec = Vec::with_capacity(futures.len()); + for x in infras.iter() { + let db_accessor_clone = self.db_accessor.clone(); + let mac = MacAddress::parse_str(&x).unwrap(); + futures.push( + tokio::spawn( + async move { + if let Err(_) = db_accessor_clone.delete_infra(mac).await { + Err(mac.to_string(eui48::MacAddressFormat::HexString)) + } else { + Ok(()) + } + })); + } + + for (i, 
future) in futures.iter_mut().enumerate() { + match future.await { + Ok(res) => { + if let Err(mac) = res { + failed_infras.push(mac); + } + }, + Err(_) => { + failed_infras.push(infras[i].clone()); + } + } + } + + if failed_infras.len() > 0 { + return Err(failed_infras); + } + + Ok(()) + } + + pub async fn relay_request_stream_to_remote_cgw( + &self, + shard_id: i32, + stream: Vec<(String, String)>) + -> Result<(), ()> { + // try to use internal cache first + if let Some(cl) = self.remote_cgws_map.read().await.get(&shard_id) { + if let Err(()) = cl.client.relay_request_stream(stream).await { + return Err(()) + } + + return Ok(()); + } + + // then try to use redis + let _ = self.sync_remote_cgw_map().await; + + if let Some(cl) = self.remote_cgws_map.read().await.get(&shard_id) { + if let Err(()) = cl.client.relay_request_stream(stream).await { + return Err(()) + } + return Ok(()); + } + + error!("No suitable CGW instance #{shard_id} was discovered, cannot relay msg"); + return Err(()); + } +} diff --git a/src/cgw_remote_server.rs b/src/cgw_remote_server.rs new file mode 100644 index 0000000..d25b27a --- /dev/null +++ b/src/cgw_remote_server.rs @@ -0,0 +1,93 @@ +use crate::{ + AppArgs, +}; + +pub mod cgw_remote { + tonic::include_proto!("cgw.remote"); +} + +use tonic::{transport::Server, Request, Response, Status}; + +use cgw_remote::{ + EnqueueRequest, + EnqueueResponse, + remote_server::{ + RemoteServer, + Remote, + }, +}; + +use tokio::{ + time::{ + Duration, + }, +}; + +use crate::cgw_remote_discovery::{ + CGWRemoteConfig, +}; + +use crate::cgw_connection_server:: { + CGWConnectionServer, +}; + +use tokio_stream::StreamExt; + +use std::{ + sync::{ + Arc, + }, +}; + +struct CGWRemote { + cgw_srv: Arc, +} + +#[tonic::async_trait] +impl Remote for CGWRemote { + async fn enqueue_nbapi_request_stream(&self, request: Request>, ) -> Result, Status> { + let mut rq_stream = request.into_inner(); + while let Some(rq) = rq_stream.next().await { + let rq = rq.unwrap(); + 
self.cgw_srv.enqueue_mbox_relayed_message_to_cgw_server(rq.key, rq.req).await; + } + + let reply = cgw_remote::EnqueueResponse { + ret: 0, + msg: "DONE".to_string(), + }; + Ok(Response::new(reply)) + } +} + +pub struct CGWRemoteServer { + cfg: CGWRemoteConfig, +} + +impl CGWRemoteServer { + pub fn new(app_args: &AppArgs) -> Self { + let remote_cfg = CGWRemoteConfig::new(app_args.cgw_id, app_args.grpc_ip, app_args.grpc_port); + let remote_server = CGWRemoteServer { + cfg: remote_cfg, + }; + remote_server + } + pub async fn start(&self, srv: Arc) { + // GRPC server + // TODO: use CGWRemoteServerConfig wrap + let icgw_serv = CGWRemote{ + cgw_srv: srv, + }; + let grpc_srv = Server::builder() + .tcp_keepalive(Some(Duration::from_secs(7200))) + .http2_keepalive_timeout(Some(Duration::from_secs(7200))) + .http2_keepalive_interval(Some(Duration::from_secs(7200))) + .http2_max_pending_accept_reset_streams(Some(1000)) + .add_service(RemoteServer::new(icgw_serv)); + + info!("Starting GRPC server id {} - listening at {}:{}", self.cfg.remote_id, self.cfg.server_ip, self.cfg.server_port); + let res = grpc_srv.serve(self.cfg.to_socket_addr()).await; + error!("grpc server returned {:?}", res); + // end of GRPC server build / start declaration + } +} diff --git a/src/localhost.crt b/src/localhost.crt new file mode 100755 index 0000000..03af12f --- /dev/null +++ b/src/localhost.crt @@ -0,0 +1,24 @@ +-----BEGIN CERTIFICATE----- +MIIEADCCAmigAwIBAgICAcgwDQYJKoZIhvcNAQELBQAwLDEqMCgGA1UEAwwhcG9u +eXRvd24gUlNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTE2MDgxMzE2MDcwNFoX +DTIyMDIwMzE2MDcwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCpVhh1/FNP2qvWenbZSghari/UThwe +dynfnHG7gc3JmygkEdErWBO/CHzHgsx7biVE5b8sZYNEDKFojyoPHGWK2bQM/FTy +niJCgNCLdn6hUqqxLAml3cxGW77hAWu94THDGB1qFe+eFiAUnDmob8gNZtAzT6Ky +b/JGJdrEU0wj+Rd7wUb4kpLInNH/Jc+oz2ii2AjNbGOZXnRz7h7Kv3sO9vABByYe +LcCj3qnhejHMqVhbAT1MD6zQ2+YKBjE52MsQKU/xhUpu9KkUyLh0cxkh3zrFiKh4 
+Vuvtc+n7aeOv2jJmOl1dr0XLlSHBlmoKqH6dCTSbddQLmlK7dms8vE01AgMBAAGj +gb4wgbswDAYDVR0TAQH/BAIwADALBgNVHQ8EBAMCBsAwHQYDVR0OBBYEFMeUzGYV +bXwJNQVbY1+A8YXYZY8pMEIGA1UdIwQ7MDmAFJvEsUi7+D8vp8xcWvnEdVBGkpoW +oR6kHDAaMRgwFgYDVQQDDA9wb255dG93biBSU0EgQ0GCAXswOwYDVR0RBDQwMoIO +dGVzdHNlcnZlci5jb22CFXNlY29uZC50ZXN0c2VydmVyLmNvbYIJbG9jYWxob3N0 +MA0GCSqGSIb3DQEBCwUAA4IBgQBsk5ivAaRAcNgjc7LEiWXFkMg703AqDDNx7kB1 +RDgLalLvrjOfOp2jsDfST7N1tKLBSQ9bMw9X4Jve+j7XXRUthcwuoYTeeo+Cy0/T +1Q78ctoX74E2nB958zwmtRykGrgE/6JAJDwGcgpY9kBPycGxTlCN926uGxHsDwVs +98cL6ZXptMLTR6T2XP36dAJZuOICSqmCSbFR8knc/gjUO36rXTxhwci8iDbmEVaf +BHpgBXGU5+SQ+QM++v6bHGf4LNQC5NZ4e4xvGax8ioYu/BRsB/T3Lx+RlItz4zdU +XuxCNcm3nhQV2ZHquRdbSdoyIxV5kJXel4wCmOhWIq7A2OBKdu5fQzIAzzLi65EN +RPAKsKB4h7hGgvciZQ7dsMrlGw0DLdJ6UrFyiR5Io7dXYT/+JP91lP5xsl6Lhg9O +FgALt7GSYRm2cZdgi9pO9rRr83Br1VjQT1vHz6yoZMXSqc4A2zcN2a2ZVq//rHvc +FZygs8miAhWPzqnpmgTj1cPiU1M= +-----END CERTIFICATE----- diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..9fa7834 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,275 @@ +#![warn(rust_2018_idioms)] +mod cgw_connection_server; +mod cgw_nb_api_listener; +mod cgw_connection_processor; +mod cgw_remote_discovery; +mod cgw_remote_server; +mod cgw_remote_client; +mod cgw_db_accessor; + +#[macro_use] +extern crate log; + +use tokio::{ + net::{ + TcpListener, + }, + time::{ + sleep, + Duration, + }, + runtime::{ + Builder, + Handle, + Runtime, + }, +}; + +use native_tls::Identity; +use std::{ + net::{ + Ipv4Addr, + SocketAddr, + }, + sync::{ + Arc, + }, +}; + +use rlimit::{ + setrlimit, + Resource, +}; + +use cgw_connection_server::{ + CGWConnectionServer, +}; + +use cgw_remote_server::{ + CGWRemoteServer, +}; + +use clap::{ + ValueEnum, + Parser, +}; + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +enum AppCoreLogLevel { + /// Print debug-level messages and above + Debug, + /// Print info-level messages and above + /// + /// + Info, +} + +/// CGW server +#[derive(Parser, Clone)] 
+#[command(version, about, long_about = None)] +pub struct AppArgs { + /// CGW unique identifier (u64) + #[arg(short, long, default_value_t = 0)] + cgw_id: i32, + + /// Number of thread in a threadpool dedicated for handling secure websocket connections + #[arg(short, long, default_value_t = 4)] + wss_t_num: usize, + /// Loglevel of application + #[arg(value_enum, default_value_t = AppCoreLogLevel::Debug)] + log_level: AppCoreLogLevel, + + /// IP to listen for incoming WSS connection + #[arg(long, default_value_t = Ipv4Addr::new(0, 0, 0, 0))] + wss_ip: Ipv4Addr, + /// PORT to listen for incoming WSS connection + #[arg(long, default_value_t = 15002)] + wss_port: u16, + + /// IP to listen for incoming GRPC connection + #[arg(long, default_value_t = Ipv4Addr::new(0, 0, 0, 0))] + grpc_ip: Ipv4Addr, + /// PORT to listen for incoming GRPC connection + #[arg(long, default_value_t = 50051)] + grpc_port: u16, + + /// IP to connect to KAFKA broker + #[arg(long, default_value_t = Ipv4Addr::new(127, 0, 0, 1))] + kafka_ip: Ipv4Addr, + /// PORT to connect to KAFKA broker + #[arg(long, default_value_t = 9092)] + kafka_port: u16, + /// KAFKA topic from where to consume messages + #[arg(long, default_value_t = String::from("CnC"))] + kafka_consume_topic: String, + /// KAFKA topic where to produce messages + #[arg(long, default_value_t = String::from("CnC_Res"))] + kafka_produce_topic: String, + + /// IP to connect to DB (PSQL) + #[arg(long, default_value_t = Ipv4Addr::new(127, 0, 0, 1))] + db_ip: Ipv4Addr, + /// PORT to connect to DB (PSQL) + #[arg(long, default_value_t = 5432)] + db_port: u16, + /// DB name to connect to in DB (PSQL) + #[arg(long, default_value_t = String::from("cgw"))] + db_name: String, + /// DB user name use with connection to in DB (PSQL) + #[arg(long, default_value_t = String::from("cgw"))] + db_username: String, + /// DB user password use with connection to in DB (PSQL) + #[arg(long, default_value_t = String::from("123"))] + db_password: String, + + /// IP 
to connect to DB (REDIS) + #[arg(long, default_value_t = Ipv4Addr::new(127, 0, 0, 1))] + redis_db_ip: Ipv4Addr, + /// PORT to connect to DB (REDIS) + #[arg(long, default_value_t = 6379)] + redis_db_port: u16, +} + +pub struct AppCore { + cgw_server: Arc, + main_runtime_handle: Arc, + grpc_server_runtime_handle: Arc, + conn_ack_runtime_handle: Arc, + args: AppArgs, +} + +impl AppCore { + async fn new(app_args: AppArgs) -> Self { + Self::setup_app(&app_args); + let current_runtime = Arc::new(Handle::current()); + + let c_ack_runtime_handle = Arc::new(Builder::new_multi_thread() + .worker_threads(1) + .thread_name("cgw-c-ack") + .thread_stack_size(1 * 1024 * 1024) + .enable_all() + .build() + .unwrap()); + let rpc_runtime_handle = Arc::new(Builder::new_multi_thread() + .worker_threads(1) + .thread_name("grpc-recv-t") + .thread_stack_size(1 * 1024 * 1024) + .enable_all() + .build() + .unwrap()); + let app_core = AppCore { + cgw_server: CGWConnectionServer::new(&app_args).await, + main_runtime_handle: current_runtime, + conn_ack_runtime_handle: c_ack_runtime_handle, + args: app_args, + grpc_server_runtime_handle: rpc_runtime_handle + }; + app_core + } + + fn setup_app(args: &AppArgs) { + let nofile_rlimit = Resource::NOFILE.get().unwrap(); + println!("{:?}", nofile_rlimit); + let nofile_hard_limit = nofile_rlimit.1; + assert!(setrlimit(Resource::NOFILE, nofile_hard_limit, nofile_hard_limit).is_ok()); + let nofile_rlimit = Resource::NOFILE.get().unwrap(); + println!("{:?}", nofile_rlimit); + + match args.log_level { + AppCoreLogLevel::Debug => ::std::env::set_var("RUST_LOG", "ucentral_cgw=debug"), + AppCoreLogLevel::Info => ::std::env::set_var("RUST_LOG", "ucentral_cgw=info"), + } + env_logger::init(); + } + + async fn run(self: Arc) { + let main_runtime_handle = self.main_runtime_handle.clone(); + let core_clone = self.clone(); + + let cgw_remote_server = CGWRemoteServer::new(&self.args); + let cgw_srv_clone = self.cgw_server.clone(); + 
self.grpc_server_runtime_handle.spawn(async move { + debug!("cgw_remote_server.start entry"); + cgw_remote_server.start(cgw_srv_clone).await; + debug!("cgw_remote_server.start exit"); + }); + + main_runtime_handle.spawn(async move { + server_loop(core_clone).await + }); + + // TODO: + // Add signal processing and etcetera app-related handlers. + loop { + sleep(Duration::from_millis(5000)).await; + } + } +} + +// TODO: a method of an object (TlsAcceptor? CGWConnectionServer?), not a plain function +async fn server_loop(app_core: Arc) -> () { + debug!("sever_loop entry"); + + debug!("Starting WSS server, listening at {}:{}", app_core.args.wss_ip, app_core.args.wss_port); + // Bind the server's socket + let sockaddraddr = SocketAddr::new(std::net::IpAddr::V4(app_core.args.wss_ip), app_core.args.wss_port); + let listener: Arc = match TcpListener::bind(sockaddraddr).await { + Ok(listener) => Arc::new(listener), + Err(e) => panic!("listener bind failed {e}"), + }; + + info!("Started WSS server."); + // Create the TLS acceptor. 
+ // TODO: custom acceptor + let der = include_bytes!("localhost.crt"); + let key = include_bytes!("localhost.key"); + let cert = match Identity::from_pkcs8(der, key) { + Ok(cert) => cert, + Err(e) => panic!("Cannot create SSL identity from supplied cert\n{e}") + }; + + let tls_acceptor = + tokio_native_tls::TlsAcceptor::from( + match native_tls::TlsAcceptor::builder(cert).build() { + Ok(builder) => builder, + Err(e) => panic!("Cannot create SSL-acceptor from supplied cert\n{e}") + }); + + // Spawn explicitly in main thread: created task accepts connection, + // but handling is spawned inside another threadpool runtime + let app_core_clone = app_core.clone(); + let _ = app_core.main_runtime_handle.spawn(async move { + let mut conn_idx: i64 = 0; + loop { + let app_core_clone = app_core_clone.clone(); + let cgw_server_clone = app_core_clone.cgw_server.clone(); + let tls_acceptor_clone = tls_acceptor.clone(); + + // Asynchronously wait for an inbound socket. + let (socket, remote_addr) = match listener.accept().await { + Ok((sock, addr)) => (sock, addr), + Err(e) => { + error!("Failed to Accept conn {e}\n"); + continue; + } + }; + + // TODO: we control tls_acceptor, thus at this stage it's our responsibility + // to provide underlying certificates common name inside the ack_connection func + // (CN == mac address of device) + app_core_clone.conn_ack_runtime_handle.spawn(async move { + cgw_server_clone.ack_connection(socket, tls_acceptor_clone, remote_addr, conn_idx).await; + }); + + conn_idx += 1; + } + }).await; +} + +#[tokio::main(flavor = "current_thread")] +async fn main() { + let args = AppArgs::parse(); + let app = Arc::new(AppCore::new(args).await); + + app.run().await; +} diff --git a/src/proto/cgw.proto b/src/proto/cgw.proto new file mode 100644 index 0000000..5c41bf2 --- /dev/null +++ b/src/proto/cgw.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package cgw.remote; + +service Remote { + rpc EnqueueNBAPIRequestStream(stream EnqueueRequest) returns 
(EnqueueResponse); +} + +message EnqueueRequest { + string req = 1; + string key = 2; +} + +message EnqueueResponse { + uint64 ret = 1; + string msg = 2; +} diff --git a/utils/cert_generator/.gitignore b/utils/cert_generator/.gitignore new file mode 100644 index 0000000..df91287 --- /dev/null +++ b/utils/cert_generator/.gitignore @@ -0,0 +1 @@ +certs/ diff --git a/utils/cert_generator/README.md b/utils/cert_generator/README.md new file mode 100644 index 0000000..c2c2d03 --- /dev/null +++ b/utils/cert_generator/README.md @@ -0,0 +1,18 @@ +# Generate Certificates + +``` +# generate CA certs +$ ./generate_certs.sh -a + +# generate server certs +$ ./generate_certs.sh -s + +# generate 10 client certs +$ ./generate_certs.sh -c 10 + +# generate 10 client certs with MAC addr AA:* +$ ./generate_certs.sh -c 10 -m AA:XX:XX:XX:XX:XX +``` + +The certificates will be available in `./certs/`. + diff --git a/utils/cert_generator/ca.conf b/utils/cert_generator/ca.conf new file mode 100644 index 0000000..65fe93a --- /dev/null +++ b/utils/cert_generator/ca.conf @@ -0,0 +1,19 @@ + +[req] +distinguished_name=req +[ca] +basicConstraints=CA:true +keyUsage=keyCertSign,cRLSign +subjectKeyIdentifier=hash +[client] +basicConstraints=CA:FALSE +extendedKeyUsage=clientAuth +keyUsage=digitalSignature,keyEncipherment +[server] +basicConstraints=CA:FALSE +extendedKeyUsage=serverAuth +keyUsage=digitalSignature,keyEncipherment +subjectAltName=@sans +[sans] +DNS.1=localhost +IP.1=127.0.0.1 diff --git a/utils/cert_generator/generate_certs.sh b/utils/cert_generator/generate_certs.sh new file mode 100755 index 0000000..09370b1 --- /dev/null +++ b/utils/cert_generator/generate_certs.sh @@ -0,0 +1,161 @@ +#!/bin/bash +TEMPLATE="02:XX:XX:XX:XX:XX" +HEXCHARS="0123456789ABCDEF" +CONF_FILE=ca.conf +CERT_DIR=certs +CA_DIR=$CERT_DIR/ca +CA_CERT=$CA_DIR/ca.crt +CA_KEY=$CA_DIR/ca.key +SERVER_DIR=$CERT_DIR/server +CLIENT_DIR=$CERT_DIR/client +METHOD_FAST=y + +usage() +{ + echo "Usage: $0 [options]" + echo + echo 
"options:" + echo "-h" + echo -e "\tprint this help" + echo "-a" + echo -e "\tgenerate CA key and certificate" + echo "-s" + echo -e "\tgenerate server key and certificate; sign the certificate" + echo -e "\tusing the CA certificate" + echo "-c NUMBER" + echo -e "\tgenerate *NUMBER* of client keys and certificates; sign" + echo -e "\tall of the certificates using the CA certificate" + echo "-m MASK" + echo -e "\tspecify custom MAC addr mask" + echo "-o" + echo -e "\tslow mode" +} + +rand_mac() +{ + local MAC=$TEMPLATE + while [[ $MAC =~ "X" || $MAC =~ "x" ]] + do + MAC=${MAC/[xX]/${HEXCHARS:$(( $RANDOM % 16 )):1}} + done + echo $MAC +} + +gen_cert() +{ + local req_file=$(mktemp) + local pem=$(mktemp) + local type=$1 + local cn=$2 + local key=$3 + local cert=$4 + # generate key and request to sign + openssl req -config $CONF_FILE -x509 -nodes -newkey rsa:4096 -sha512 -days 365 \ + -extensions $type -subj "/CN=$cn" -out $req_file -keyout $key &> /dev/null + # sign certificate + openssl x509 -extfile $CONF_FILE -CA $CA_CERT -CAkey $CA_KEY -CAcreateserial -sha512 -days 365 \ + -in $req_file -out $pem + if [ $? 
# Generate $1 client key/certificate pairs with gen_cert jobs running in parallel.
#
# Jobs are started in the background and collected with `wait` after every chunk of
# roughly 10% of the total, and once more after the loop so that no job is still
# running when the function returns. Progress is printed as a percentage of $1.
gen_client_batch()
{
    local total=$1
    local c mac pct last_pct=0

    # Chunk size between two synchronisation points; never below one job, which also
    # keeps the modulo below away from a division by zero for small totals.
    local sync_count=$(( total / 10 ))
    if (( sync_count < 1 ))
    then
        sync_count=1
    fi

    for (( c=1; c<=total; c++ ))
    do
        mac=$(rand_mac)
        gen_cert client "$mac" "$CLIENT_DIR/$mac.key" "$CLIENT_DIR/$mac.crt" &

        # Print each whole percent once.
        pct=$(( c * 100 / total ))
        if (( pct != last_pct ))
        then
            echo "${pct}%"
            last_pct=$pct
        fi

        # Wait for our own background jobs instead of polling the process table,
        # which also counts unrelated openssl processes.
        if (( c % sync_count == 0 ))
        then
            wait
        fi
    done

    # Collect the jobs started after the last synchronisation point.
    wait
}
b/utils/client_simulator/.gitignore @@ -0,0 +1,6 @@ +bin/ +lib/ +lib64 +include/ +*.cfg +__pycache__ diff --git a/utils/client_simulator/Dockerfile b/utils/client_simulator/Dockerfile new file mode 100644 index 0000000..798514c --- /dev/null +++ b/utils/client_simulator/Dockerfile @@ -0,0 +1,10 @@ +FROM ubuntu + +COPY ./requirements.txt /tmp/requirements.txt +WORKDIR /opt/client_simulator + +RUN apt-get update -q -y && apt-get -q -y --no-install-recommends install \ + python3 \ + python3-pip +RUN python3 -m pip install -r /tmp/requirements.txt + diff --git a/utils/client_simulator/Makefile b/utils/client_simulator/Makefile new file mode 100644 index 0000000..494a995 --- /dev/null +++ b/utils/client_simulator/Makefile @@ -0,0 +1,34 @@ +IMG_NAME=cgw-client-sim +CONTAINER_NAME=cgw_client_sim +MAC?=XX:XX:XX:XX:XX:XX +COUNT?=1000 +URL=wss://localhost:15002 +CA_CERT_PATH?=$(PWD)/../cert_generator/certs/ca +CLIENT_CERT_PATH?=$(PWD)/../cert_generator/certs/client +MSG_INTERVAL?=10 +MSG_SIZE?=1000 + +.PHONY: build spawn stop start + +build: + docker build -t ${IMG_NAME} . 
+ +spawn: + docker run --name "${CONTAINER_NAME}_${COUNT}_$(subst :,-,$(MAC))" \ + -d --rm --network host \ + -v $(PWD):/opt/client_simulator \ + -v ${CA_CERT_PATH}:/etc/ca \ + -v ${CLIENT_CERT_PATH}:/etc/certs \ + ${IMG_NAME} \ + python3 main.py -M ${MAC} -N ${COUNT} -s ${URL} \ + --ca-cert /etc/ca/ca.crt \ + --client-certs-path /etc/certs \ + --msg-interval ${MSG_INTERVAL} \ + --payload-size ${MSG_SIZE} \ + --wait-for-signal + +stop: + docker stop $$(docker ps -q -f name=$(CONTAINER_NAME)) + +start: + docker kill -s SIGUSR1 $$(docker ps -q -f name=$(CONTAINER_NAME)) \ No newline at end of file diff --git a/utils/client_simulator/README.md b/utils/client_simulator/README.md new file mode 100644 index 0000000..139f1a0 --- /dev/null +++ b/utils/client_simulator/README.md @@ -0,0 +1,39 @@ +# Run client simulation + +``` +# run 10 concurrent client simulations +$ ./main.py -s wss://localhost:50001 -n 10 + +# use only specified MAC addrs +$ ./main.py -s wss://localhost:15002 -N 10 -M AA:XX:XX:XX:XX:XX + +# run 10 concurrent simulations with MAC AA:* and 10 concurrent simulations with MAC BB:* +# AA:* and BB:* simulations are run in separate processes +$ ./main.py -s wss://localhost:15002 -N 10 -M AA:XX:XX:XX:XX:XX -M BB:XX:XX:XX:XX:XX +``` + +To stop the simulation use `Ctrl+C`. 
+ +# Run simulation in docker + +``` +$ make +$ make spawn +$ make start +$ make stop + +# specify mac addr range +$ make spawn MAC=11:22:AA:BB:XX:XX + +# specify number of client connections (default is 1000) +$ make spawn COUNT=100 + +# specify server url +$ make spawn URL=wss://localhost:15002 + +# tell all running containers to start connecting to the server +$ make start + +# stop all containers +$ make stop +``` diff --git a/utils/client_simulator/certs b/utils/client_simulator/certs new file mode 120000 index 0000000..edda191 --- /dev/null +++ b/utils/client_simulator/certs @@ -0,0 +1 @@ +../cert_generator/certs/ \ No newline at end of file diff --git a/utils/client_simulator/data/message_templates.json b/utils/client_simulator/data/message_templates.json new file mode 100644 index 0000000..0c197a8 --- /dev/null +++ b/utils/client_simulator/data/message_templates.json @@ -0,0 +1,5 @@ +{"connect": {"jsonrpc":"2.0","method":"connect","params":{"serial":"MAC","firmware":"Rel 1.6 build 1","uuid":1692198868,"capabilities":{"compatible":"x86_64-kvm_x86_64-r0","model":"DellEMC-S5248f-P-25G-DPB","platform":"switch","label_macaddr":"MAC"}}}, + "state": {"jsonrpc": "2.0", "method": "state", "params": {"serial": "MAC","uuid": 1692198868, "request_uuid": null, "state": {}}}, + "reboot_response": {"jsonrpc": "2.0", "result": {"serial": "MAC", "status": {"error": 0, "text": "", "when": 0}, "id": "ID"}}, + "log": {"jsonrpc": "2.0", "method": "log", "params": {"serial": "MAC", "log": "", "severity": 7, "data": {}}} +} diff --git a/utils/client_simulator/main.py b/utils/client_simulator/main.py new file mode 100755 index 0000000..d17b330 --- /dev/null +++ b/utils/client_simulator/main.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 +from src.utils import parse_args +from src.simulation_runner import main + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/utils/client_simulator/requirements.txt b/utils/client_simulator/requirements.txt new file mode 
# Extra verbosity level, numerically 5 below logging.DEBUG, and its display name.
TRACE_LEVEL = logging.DEBUG - 5
TRACE_NAME = "TRACE"


class ColoredFormatter(logging.Formatter):
    """Formatter that wraps every record in an ANSI colour chosen by its level.

    ``FORMATS`` maps a numeric level to a ``{``-style format string made of the
    colour escape, the common record layout and the reset escape. Levels that
    are missing from ``FORMATS`` are rendered with the ``logging.Formatter``
    default layout and no colour.
    """

    # ANSI escape sequences.
    grey = "\x1b[38;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    white = "\x1b[37;20m"
    cyan = "\x1b[36;20m"
    reset = "\x1b[0m"

    # Layout shared by all levels.
    _layout = "{asctime}|{levelname}|{threadName}|{funcName}:{lineno}\t{message}"

    FORMATS = {
        logging.DEBUG: grey + _layout + reset,
        logging.INFO: cyan + _layout + reset,
        logging.WARNING: yellow + _layout + reset,
        logging.ERROR: red + _layout + reset,
        logging.CRITICAL: bold_red + _layout + reset,
        TRACE_LEVEL: white + _layout + reset
    }

    def format(self, record):
        """Render ``record`` with the format string registered for its level."""
        template = self.FORMATS.get(record.levelno)
        # A throw-away formatter is built per record; a ``None`` template makes
        # it fall back to the default layout.
        return logging.Formatter(template, style="{").format(record)
class Message:
    """Pre-rendered JSON messages for one simulated device.

    The templates returned by ``get_msg_templates()`` are serialised once per
    device with every ``"MAC"`` placeholder replaced by the device MAC:

    * ``connect``, ``state``, ``reboot_response`` - the rendered templates;
    * ``log`` - the rendered ``log`` template whose ``params.data`` carries a
      random payload of ``size`` characters (upper-case letters and digits).
    """

    def __init__(self, mac: str, size: int):
        self.templates = get_msg_templates()

        def render(template) -> str:
            # Serialise a template and fill in the MAC placeholder.
            return json.dumps(template).replace("MAC", mac)

        self.connect = render(self.templates["connect"])
        self.state = render(self.templates["state"])
        self.reboot_response = render(self.templates["reboot_response"])

        # The placeholder is substituted BEFORE the random payload is attached:
        # the payload alphabet can produce the substring "MAC", which must not
        # be rewritten into the device address.
        log_msg = json.loads(render(self.templates["log"]))
        alphabet = string.ascii_uppercase + string.digits
        log_msg["params"]["data"] = {"msg": ''.join(random.choices(alphabet, k=size))}
        self.log = json.dumps(log_msg)

    @staticmethod
    def to_json(msg) -> str:
        """Serialise ``msg`` to a JSON string."""
        return json.dumps(msg)

    @staticmethod
    def from_json(msg) -> dict:
        """Parse a JSON string."""
        return json.loads(msg)
def __init__(self, mac: str, server: str, ca_cert: str, + msg_interval: int, msg_size: int, + client_cert: str, client_key: str, + start_event: multiprocessing.Event, + stop_event: multiprocessing.Event): + self.mac = mac + self.interval = msg_interval + self.messages = Message(self.mac, msg_size) + self.server_addr = server + self.start_event = start_event + self.stop_event = stop_event + self.reboot_time_s = 10 + self._socket = None + + self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.ssl_context.load_cert_chain(client_cert, client_key, "") + self.ssl_context.load_verify_locations(ca_cert) + self.ssl_context.verify_mode = ssl.CERT_REQUIRED + + def send_hello(self, socket: client.ClientConnection): + logger.debug(self.messages.connect) + socket.send(self.messages.connect) + + def send_log(self, socket: client.ClientConnection): + socket.send(self.messages.log) + + def handle_messages(self, socket: client.ClientConnection): + try: + msg = socket.recv(self.interval) + msg = self.messages.from_json(msg) + logger.info(msg) + if msg["method"] == "reboot": + self.handle_reboot(socket, msg) + else: + logger.error(f"Unknown method {msg['method']}") + except TimeoutError: # no messages + pass + except (ConnectionClosedOK, ConnectionClosedError, ConnectionClosed): + logger.critical("Did not expect socket to be closed") + raise + + def handle_reboot(self, socket: client.ClientConnection, msg: dict): + resp = self.messages.from_json(self.messages.reboot_response) + if "id" in msg: + resp["result"]["id"] = msg["id"] + else: + del resp["result"]["id"] + logger.warn("Reboot request is missing 'id' field") + socket.send(self.messages.to_json(resp)) + self.disconnect() + time.sleep(self.reboot_time_s) + self.connect() + self.send_hello(self._socket) + + def connect(self): + if self._socket is None: + self._socket = client.connect(self.server_addr, ssl_context=self.ssl_context, open_timeout=7200) + return self._socket + + def disconnect(self): + if self._socket is 
not None: + self._socket.close() + self._socket = None + + def job(self): + logger.debug("waiting for start trigger") + self.start_event.wait() + if self.stop_event.is_set(): + return + logger.debug("starting simulation") + self.connect() + start = time.time() + try: + self.send_hello(self._socket) + while not self.stop_event.is_set(): + if self._socket is None: + logger.error("Connection to GW is lost. Trying to reconnect...") + self.connect() + if time.time() - start > self.interval: + logger.info(f"Sent log") + self.send_log(self._socket) + start = time.time() + self.handle_messages(self._socket) + finally: + self.disconnect() + logger.debug("simulation done") + + +def get_avail_mac_addrs(path, mask="XX:XX:XX:XX:XX:XX"): + _mask = "".join(("[0-9a-fA-F]" if c == "X" else c) for c in mask.upper()) + certs = sorted(os.listdir(path)) + macs = set(cert.split(".")[0] for cert in certs if "crt" in cert and re.match(_mask, cert)) + return list(macs) + + +def update_fd_limit(): + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + try: + resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) + except ValueError: + logger.critical("Failed to update fd limit") + raise + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + logger.warning(f"changed fd limit {soft, hard}") + + +def process(args: Args, mask: str, start_event: multiprocessing.Event, stop_event: multiprocessing.Event): + signal.signal(signal.SIGINT, signal.SIG_IGN) # ignore Ctrl+C in child processes + threading.current_thread().name = mask + logger.info(f"process started") + macs = get_avail_mac_addrs(args.cert_path, mask) + if len(macs) < args.number_of_connections: + logger.warn(f"expected {args.number_of_connections} certificates, but only found {len(macs)} " + f"({mask = })") + update_fd_limit() + + devices = [Device(mac, args.server, args.ca_path, args.msg_interval, args.msg_size, + os.path.join(args.cert_path, f"{mac}.crt"), + os.path.join(args.cert_path, f"{mac}.key"), + start_event, 
def main(args: Args):
    """
    Run the simulation: one worker process per MAC mask.

    The workers start immediately unless ``args.wait_for_sig`` is set, in
    which case they wait until this process receives SIGUSR1. The function
    returns once every worker has exited, or after Ctrl+C once the workers
    have been told to stop and have been joined.
    """
    verify_cert_availability(args.cert_path, args.masks, args.number_of_connections)
    stop_event = multiprocessing.Event()
    start_event = multiprocessing.Event()
    if not args.wait_for_sig:
        start_event.set()
    signal.signal(signal.SIGUSR1, trigger_start(start_event))
    processes = [multiprocessing.Process(target=process, args=(args, mask, start_event, stop_event))
                 for mask in args.masks]
    try:
        for p in processes:
            p.start()
        time.sleep(1)
        logger.info(f"Started {len(processes)} processes")
        if args.wait_for_sig:
            logger.info("Waiting for SIGUSR1...")
        # Poll instead of sleeping forever so that the parent does not hang
        # when every worker has already exited (e.g. after a start-up error).
        while any(p.is_alive() for p in processes):
            time.sleep(1)
    except KeyboardInterrupt:
        logger.warning("Stopping all processes...")
        stop_event.set()
        # Release workers that are still blocked waiting for the start trigger.
        start_event.set()
    for p in processes:
        # Ctrl+C may arrive before every process was started; joining a
        # process that never started raises.
        if p.pid is not None:
            p.join()
def parse_msg_size(input: str) -> int:
    """
    Convert a human-readable size such as "512", "1k" or "2M" into a number.

    The suffixes are decimal: "k"/"K" multiplies by 1000 and "m"/"M" by
    1000000. Whitespace around the value is ignored.

    Raises ValueError when the text is not a non-negative integer followed
    by at most one suffix.
    """
    # fullmatch on the stripped text treats all surrounding whitespace the
    # same way; "^...$" with re.match only let a trailing newline through.
    match = re.fullmatch(r"(\d+)([kKmM]?)", input.strip())
    if match is None:
        raise ValueError(f"Unable to parse message size \"{input}\"")
    digits, suffix = match.groups()
    multipliers = {"": 1, "k": 1000, "m": 1000000}
    return int(digits) * multipliers[suffix.lower()]
def rand_mac(mask="02:xx:xx:xx:xx:xx"):
    """
    Build a MAC address from ``mask``.

    Every "x"/"X" in the mask becomes a random lower-case hex digit; every
    other character is copied in lower case.
    """
    pieces = []
    for symbol in mask:
        # One random digit is drawn per mask character, placeholder or not,
        # so a seeded generator yields the same sequence as before.
        digit = format(random.randint(0, 15), "x")
        pieces.append(symbol.lower().replace("x", digit))
    return "".join(pieces)
== 0, "Generating certificates failed" + finally: + os.chdir(cwd) diff --git a/utils/kafka_producer/.gitignore b/utils/kafka_producer/.gitignore new file mode 100644 index 0000000..a45680f --- /dev/null +++ b/utils/kafka_producer/.gitignore @@ -0,0 +1,7 @@ +bin/ +lib/ +lib64 +include/ +share/ +*.cfg +__pycache__ diff --git a/utils/kafka_producer/data/message_template.json b/utils/kafka_producer/data/message_template.json new file mode 100644 index 0000000..bb1f1fd --- /dev/null +++ b/utils/kafka_producer/data/message_template.json @@ -0,0 +1,33 @@ +{ + "add_group": { + "type": "infrastructure_group_create", + "infra_group_id": "key", + "infra_name": "name", + "infra_shard_id": 0, + "uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff" + }, + "del_group": { + "type": "infrastructure_group_delete", + "infra_group_id": "key", + "uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff" + }, + "add_to_group": { + "type": "infrastructure_group_device_add", + "infra_group_id": "key", + "infra_group_infra_devices": [], + "uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff" + }, + "del_from_group": { + "type": "infrastructure_group_device_del", + "infra_group_id": "key", + "infra_group_infra_devices": [], + "uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff" + }, + "message_device": { + "type": "infrastructure_group_device_message", + "infra_group_id": "key", + "mac": "mac", + "msg": {}, + "uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff" + } +} \ No newline at end of file diff --git a/utils/kafka_producer/main.py b/utils/kafka_producer/main.py new file mode 100755 index 0000000..99ea442 --- /dev/null +++ b/utils/kafka_producer/main.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +from src.cli_parser import parse_args, Args +from src.producer import Producer +from src.log import logger + + +def main(args: Args): + producer = Producer(args.db, args.topic) + if args.add_groups or args.del_groups: + producer.handle_group_creation(args.add_groups, args.del_groups) + if args.assign_to_group or 
def time(input: str) -> float:
    """
    Convert a duration given on the command line into seconds.

    Accepted formats (the unit letter may be upper or lower case and
    whitespace around the value is ignored):
        *number*
        *number*s
        *number*m
        *number*h

    Returns the number of seconds as a float.

    Example:
        100 -> 100.0
        100s -> 100.0
        5m -> 300.0
        2h -> 7200.0

    Raises ValueError when the text is not a number with an optional unit,
    or when the duration is negative or not a number (NaN).
    """
    text = input.strip()
    seconds_per_unit = {"s": 1, "m": 60, "h": 60 * 60}
    # Only the last character can be a unit; anything else is left to
    # float() to accept or reject.
    factor = seconds_per_unit.get(text[-1:].lower())
    if factor is None:
        seconds = float(text)
    else:
        seconds = float(text[:-1]) * factor
    # "not >=" (rather than "<") also rejects NaN.
    if not seconds >= 0:
        raise ValueError(f"duration must not be negative: \"{input}\"")
    return seconds
class ColoredFormatter(logging.Formatter):
    """
    Log formatter that wraps every line in an ANSI colour chosen by level.

    Records whose level has no entry in ``FORMATS`` are rendered with the
    same layout but without colour.
    """

    grey = "\x1b[38;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    white = "\x1b[37;20m"
    cyan = "\x1b[36;20m"
    reset = "\x1b[0m"
    # Kept under its own name: the former class attribute "format" was
    # silently replaced by the format() method defined below.
    LINE_FORMAT = "{asctime}|{levelname}|{threadName}|{funcName}:{lineno}\t{message}"

    FORMATS = {
        logging.DEBUG: grey + LINE_FORMAT + reset,
        logging.INFO: cyan + LINE_FORMAT + reset,
        logging.WARNING: yellow + LINE_FORMAT + reset,
        logging.ERROR: red + LINE_FORMAT + reset,
        logging.CRITICAL: bold_red + LINE_FORMAT + reset,
    }

    def __init__(self, fmt=None, datefmt=None, *_args, **_kwargs):
        """
        ``datefmt`` is honoured for the timestamp. ``fmt`` and any further
        ``logging.Formatter`` arguments are accepted only to keep the
        constructor signature compatible; the layout is fixed.
        """
        super().__init__(self.LINE_FORMAT, datefmt, style="{")
        # One formatter per level, built once instead of once per record.
        self._by_level = {
            level: logging.Formatter(line, datefmt, style="{")
            for level, line in self.FORMATS.items()
        }

    def format(self, record):
        """Render ``record`` with the colour that belongs to its level."""
        formatter = self._by_level.get(record.levelno)
        if formatter is None:
            # Custom level: same layout, no colour.
            return super().format(record)
        return formatter.format(record)
class MacRange:
    """
    Return an object that produces a sequence of MAC addresses from
    START (inclusive) to END (inclusive). START and END are extracted
    from the input string if it is in the format
    "11:22:AA:BB:00:00-11:22:AA:BB:00:05" (where START=11:22:AA:BB:00:00,
    END=11:22:AA:BB:00:05, and the total amount of MACs in the range is 6).

    Examples (all of these are identical):

        00:00:00:00:XX:XX
        00:00:00:00:00:00-00:00:00:00:FF:FF
        00:00:00:00:00:00^65536

    A mask is expanded by replacing every X with 0 for START and with F for
    END, so it only describes exactly the matching addresses when all X
    digits are trailing.

    Every ``for`` loop (or ``list()`` call) over the object walks the whole
    range from START, independently of any other loop over the same object.

    Raises ValueError
    """
    # Number of distinct 48-bit MAC addresses.
    _MAC_SPACE = 1 << 48

    def __init__(self, input: str = "XX:XX:XX:XX:XX:XX") -> None:
        self.__base_as_num, self.__len = self.__parse_input(input.upper())
        self.__idx = 0

    def __iter__(self):
        # A fresh generator per loop: a cursor shared through "return self"
        # made nested loops spin forever and made a loop that follows a
        # "break" resume in the middle of the range.
        return (self.num2mac(self.__base_as_num + offset)
                for offset in range(self.__len))

    def __next__(self) -> str:
        # Kept for callers that drive the object with next() directly; this
        # cursor is separate from the iterators handed out by __iter__.
        if self.__idx >= len(self):
            self.__idx = 0
            raise StopIteration()
        mac = self.num2mac(self.__base_as_num + self.__idx)
        self.__idx += 1
        return mac

    def __len__(self) -> int:
        return self.__len

    def __str__(self) -> str:
        return f"MacRange[start={self.base}, " \
               f"end={self.num2mac(self.__base_as_num + len(self) - 1)}]"

    def __repr__(self) -> str:
        return f"MacRange('{self.base}^{len(self)}')"

    @property
    def base(self) -> str:
        """First MAC address of the range."""
        return self.num2mac(self.__base_as_num)

    @staticmethod
    def mac2num(mac: str) -> int:
        """Convert "AA:BB:CC:DD:EE:FF" into its integer value."""
        return int(mac.replace(":", ""), base=16)

    @staticmethod
    def num2mac(mac: int) -> str:
        """Convert an integer into upper-case "AA:BB:CC:DD:EE:FF" form."""
        hex = f"{mac:012X}"
        return ":".join([a+b for a, b in zip(hex[::2], hex[1::2])])

    def __parse_input(self, input: str) -> Tuple[int, int]:
        """Return (first address as a number, amount of addresses)."""
        if "X" in input:
            string = f"{input.replace('X', '0')}-{input.replace('X', 'F')}"
        else:
            string = input
        if "-" in string:
            start, end = string.split("-")
            first, last = self.mac2num(start), self.mac2num(end)
            if first > last:
                raise ValueError(f"Invalid MAC range {start}-{end}")
            base, count = first, last - first + 1
        elif "^" in string:
            start, amount = string.split("^")
            base, count = self.mac2num(start), int(amount)
        else:
            base, count = self.mac2num(string), 1
        if count < 1:
            raise ValueError(f"MAC range \"{input}\" must contain at least one address")
        if base + count > self._MAC_SPACE:
            raise ValueError(f"MAC range \"{input}\" does not fit into 48 bits")
        return base, count
@dataclass
class Args:
    """Parsed command line of the kafka producer utility."""
    # (group id, shard id, group name) for every group to create
    add_groups: List[Tuple[str, int, str]]
    # ids of the groups to delete
    del_groups: List[str]
    # (group id, MAC range) pairs to add to a group
    assign_to_group: List[Tuple[str, MacRange]]
    # (group id, MAC range) pairs to remove from a group
    remove_from_group: List[Tuple[str, MacRange]]
    # kafka topic the messages are published to
    topic: str
    # kafka bootstrap server address
    db: str
    # decoded JSON payload for devices; None when no message is to be sent
    message: dict
    # messages per MAC address; 0 when only a duration was requested
    count: int
    # how long to keep sending, in seconds; None when not requested
    time_to_send_s: float
    # pause between messages, in seconds
    interval_s: float
    # id of the group the message is sent to (a string on the command line,
    # previously annotated as int); None when no message is to be sent
    group_id: str
    # devices receiving the message; None when no message is to be sent
    send_to_macs: MacRange