mirror of
https://github.com/optim-enterprises-bv/openlan-cgw.git
synced 2025-10-29 17:32:21 +00:00
36
Cargo.toml
Normal file
36
Cargo.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[package]
|
||||
name = "ucentral-cgw"
|
||||
version = "0.0.1"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1.0.144", features = ["derive"] }
|
||||
serde_json = "1.0.85"
|
||||
env_logger = "0.5.0"
|
||||
log = "0.4.20"
|
||||
tokio = { version = "1.34.0", features = ["full"] }
|
||||
tokio-stream = { version = "*", features = ["full"] }
|
||||
tokio-tungstenite = { version = "*", features = ["native-tls"] }
|
||||
tokio-native-tls = "*"
|
||||
tokio-rustls = "*"
|
||||
tokio-postgres = { version = "0.7.10", features = ["with-eui48-1"]}
|
||||
tokio-pg-mapper = "0.2.0"
|
||||
tungstenite = { version = "*"}
|
||||
mio = "0.6.10"
|
||||
native-tls = "*"
|
||||
futures-util = { version = "0.3.0", default-features = false }
|
||||
futures-channel = "0.3.0"
|
||||
futures-executor = { version = "0.3.0", optional = true }
|
||||
futures = "0.3.0"
|
||||
rlimit = "0.10.1"
|
||||
tonic = "0.10.2"
|
||||
prost = "0.12"
|
||||
rdkafka = "0.35.0"
|
||||
clap = { version = "4.4.8", features = ["derive"] }
|
||||
eui48 = "1.1.0"
|
||||
uuid = { version = "1.6.1", features = ["serde"] }
|
||||
redis-async = "0.16.1"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.10"
|
||||
prost-build = "0.12"
|
||||
4
build.rs
Normal file
4
build.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
tonic_build::compile_protos("src/proto/cgw.proto")?;
|
||||
Ok(())
|
||||
}
|
||||
531
src/cgw_connection_processor.rs
Normal file
531
src/cgw_connection_processor.rs
Normal file
@@ -0,0 +1,531 @@
|
||||
use crate::cgw_connection_server::{
|
||||
CGWConnectionServer,
|
||||
CGWConnectionServerReqMsg,
|
||||
};
|
||||
|
||||
use tokio::{
|
||||
sync::{
|
||||
mpsc::{
|
||||
UnboundedReceiver,
|
||||
unbounded_channel,
|
||||
},
|
||||
},
|
||||
net::{
|
||||
TcpStream,
|
||||
},
|
||||
time::{
|
||||
sleep,
|
||||
Duration,
|
||||
Instant,
|
||||
},
|
||||
};
|
||||
use tokio_native_tls::TlsStream;
|
||||
use tokio_tungstenite::{
|
||||
WebSocketStream,
|
||||
tungstenite::{
|
||||
protocol::{
|
||||
Message,
|
||||
},
|
||||
},
|
||||
};
|
||||
use tungstenite::Message::{
|
||||
Close,
|
||||
Text,
|
||||
Ping,
|
||||
};
|
||||
use futures_util::{
|
||||
StreamExt,
|
||||
SinkExt,
|
||||
FutureExt,
|
||||
stream::{
|
||||
SplitSink,
|
||||
SplitStream
|
||||
},
|
||||
};
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map, Value};
|
||||
|
||||
type CGWUcentralJRPCMessage = Map<String, Value>;
|
||||
type SStream = SplitStream<WebSocketStream<TlsStream<TcpStream>>>;
|
||||
type SSink = SplitSink<WebSocketStream<TlsStream<TcpStream>>, Message>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CGWConnectionProcessorReqMsg {
|
||||
// We got green light from server to process this connection on
|
||||
AddNewConnectionAck,
|
||||
AddNewConnectionShouldClose,
|
||||
SinkRequestToDevice(String, String),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum CGWConnectionState {
|
||||
IsActive,
|
||||
IsForcedToClose,
|
||||
IsDead,
|
||||
IsStale,
|
||||
ClosedGracefully,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug, Default)]
|
||||
struct CGWEventLogParams {
|
||||
serial: String,
|
||||
log: String,
|
||||
severity: i64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug, Default)]
|
||||
struct CGWEventLog {
|
||||
params: CGWEventLogParams,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug, Default)]
|
||||
struct CGWEventConnectParamsCaps {
|
||||
compatible: String,
|
||||
model: String,
|
||||
platform: String,
|
||||
label_macaddr: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug, Default)]
|
||||
struct CGWEventConnectParams {
|
||||
serial: String,
|
||||
firmware: String,
|
||||
uuid: u64,
|
||||
capabilities: CGWEventConnectParamsCaps,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug, Default)]
|
||||
struct CGWEventConnect {
|
||||
params: CGWEventConnectParams,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
enum CGWEvent {
|
||||
Connect(CGWEventConnect),
|
||||
Log(CGWEventLog),
|
||||
Empty,
|
||||
}
|
||||
|
||||
fn cgw_parse_jrpc_event(map: &Map<String, Value>, method: String) -> CGWEvent {
|
||||
if method == "log" {
|
||||
let params = map.get("params").expect("Params are missing");
|
||||
return CGWEvent::Log(CGWEventLog {
|
||||
params: CGWEventLogParams {
|
||||
serial: params["serial"].to_string(),
|
||||
log: params["log"].to_string(),
|
||||
severity: serde_json::from_value(params["severity"].clone()).unwrap(),
|
||||
},
|
||||
});
|
||||
} else if method == "connect" {
|
||||
let params = map.get("params").expect("Params are missing");
|
||||
return CGWEvent::Connect(CGWEventConnect {
|
||||
params: CGWEventConnectParams {
|
||||
serial: params["serial"].to_string(),
|
||||
firmware: params["firmware"].to_string(),
|
||||
uuid: 1,
|
||||
capabilities: CGWEventConnectParamsCaps {
|
||||
compatible: params["capabilities"]["compatible"].to_string(),
|
||||
model: params["capabilities"]["model"].to_string(),
|
||||
platform: params["capabilities"]["platform"].to_string(),
|
||||
label_macaddr: params["capabilities"]["label_macaddr"].to_string(),
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
CGWEvent::Empty
|
||||
}
|
||||
|
||||
async fn cgw_process_jrpc_event(event: &CGWEvent) -> Result<(), String> {
|
||||
// TODO
|
||||
if let CGWEvent::Connect(c) = event {
|
||||
/*
|
||||
info!(
|
||||
"Requesting {} to reboot (immediate request)",
|
||||
c.params.serial
|
||||
);
|
||||
let req = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "reboot",
|
||||
"params": {
|
||||
"serial": c.params.serial,
|
||||
"when": 0
|
||||
},
|
||||
"id": 1
|
||||
});
|
||||
info!("Received connect msg {}", c.params.serial);
|
||||
sender.send(Message::text(req.to_string())).await.ok();
|
||||
*/
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO: heavy rework to enum-variant struct-based
|
||||
async fn cgw_process_jrpc_message(message: Message) -> Result<CGWUcentralJRPCMessage, String> {
|
||||
//let rpcmsg: CGWMethodConnect = CGWMethodConnect::default();
|
||||
//serde_json::from_str(method).unwrap();
|
||||
//
|
||||
let msg = if let Ok(s) = message.into_text() {
|
||||
s
|
||||
} else {
|
||||
return Err("Message to string cast failed".to_string());
|
||||
};
|
||||
|
||||
let map: CGWUcentralJRPCMessage = match serde_json::from_str(&msg) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
error!("Failed to parse input json {e}");
|
||||
return Err("Failed to parse input json".to_string());
|
||||
}
|
||||
};
|
||||
//.expect("Failed to parse input json");
|
||||
|
||||
if !map.contains_key("jsonrpc") {
|
||||
warn!("Received malformed JSONRPC msg");
|
||||
return Err("JSONRPC field is missing in message".to_string());
|
||||
}
|
||||
|
||||
if map.contains_key("method") {
|
||||
if !map.contains_key("params") {
|
||||
warn!("Received JRPC <method> without params.");
|
||||
return Err("Received JRPC <method> without params".to_string());
|
||||
}
|
||||
|
||||
// unwrap can panic
|
||||
let method = map["method"].as_str().unwrap();
|
||||
|
||||
let event: CGWEvent = cgw_parse_jrpc_event(&map, method.to_string());
|
||||
|
||||
match &event {
|
||||
CGWEvent::Log(l) => {
|
||||
debug!(
|
||||
"Received LOG evt from device {}: {}",
|
||||
l.params.serial, l.params.log
|
||||
);
|
||||
}
|
||||
CGWEvent::Connect(c) => {
|
||||
debug!(
|
||||
"Received connect evt from device {}: type {}, fw {}",
|
||||
c.params.serial, c.params.capabilities.platform, c.params.firmware
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
warn!("received not yet implemented method {}", method);
|
||||
return Err(format!("received not yet implemented method {}", method));
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = cgw_process_jrpc_event(&event).await {
|
||||
warn!(
|
||||
"Failed to process jrpc event (unmatched) {}",
|
||||
method.to_string()
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
// TODO
|
||||
} else if map.contains_key("result") {
|
||||
info!("Processing <result> JSONRPC msg");
|
||||
info!("{:?}", map);
|
||||
return Err("Result handling is not yet implemented".to_string());
|
||||
}
|
||||
|
||||
/*
|
||||
match map.get_mut("jsonrpc") {
|
||||
Some(value) => info!("Got value {:?}", value),
|
||||
None => info!("Got no value"),
|
||||
}
|
||||
*/
|
||||
/*
|
||||
if let CGWMethod::Connect { ref someint, .. } = &rpcmsg {
|
||||
info!("secondmatch {}", *someint);
|
||||
return Some(rpcmsg);
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
*/
|
||||
//return Some(CGWMethodConnect::default());
|
||||
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
pub struct CGWConnectionProcessor {
|
||||
cgw_server: Arc<CGWConnectionServer>,
|
||||
pub serial: Option<String>,
|
||||
pub addr: SocketAddr,
|
||||
pub idx: i64,
|
||||
}
|
||||
|
||||
impl CGWConnectionProcessor {
|
||||
pub fn new(server: Arc<CGWConnectionServer>, conn_idx: i64, addr: SocketAddr) -> Self {
|
||||
let conn_processor : CGWConnectionProcessor = CGWConnectionProcessor {
|
||||
cgw_server: server,
|
||||
serial: None,
|
||||
addr: addr,
|
||||
idx: conn_idx,
|
||||
};
|
||||
|
||||
conn_processor
|
||||
}
|
||||
|
||||
pub async fn start(mut self, tls_stream: TlsStream<TcpStream>) {
|
||||
let ws_stream = tokio_tungstenite::accept_async(tls_stream)
|
||||
.await
|
||||
.expect("error during the websocket handshake occurred");
|
||||
|
||||
let (sink, mut stream) = ws_stream.split();
|
||||
|
||||
// check if we have any pending msgs (we expect connect at this point, protocol-wise)
|
||||
// TODO: rework to ignore any WS-related frames untill we get a connect message,
|
||||
// however there's a caveat: we can miss some events logs etc from underlying device
|
||||
// rework should consider all the options
|
||||
let msg = tokio::select! {
|
||||
_val = stream.next() => {
|
||||
match _val {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
error!("no connect message received from {}, closing connection", self.addr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: configurable duration (upon server creation)
|
||||
_val = sleep(Duration::from_millis(30000)) => {
|
||||
error!("no message received from {}, closing connection", self.addr);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// we have a next() result, but it still may be undelying io error: check for it
|
||||
// break connection if we can't work with underlying ws connection (pror err etc)
|
||||
let message = match msg {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
error!("established connection with device, but failed to receive any messages\n{e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let map = match cgw_process_jrpc_message(message).await {
|
||||
Ok(val) => val,
|
||||
Err(e) => {
|
||||
error!("failed to recv connect message from {}, closing connection", self.addr);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let serial = map["params"]["serial"].as_str().unwrap();
|
||||
self.serial = Some(serial.to_string());
|
||||
|
||||
// TODO: we accepted tls stream and split the WS into RX TX part,
|
||||
// now we have to ASK cgw_connection_server's permission whether
|
||||
// we can proceed on with this underlying connection.
|
||||
// cgw_connection_server has an authorative decision whether
|
||||
// we can proceed.
|
||||
let (mbox_tx, mut mbox_rx) = unbounded_channel::<CGWConnectionProcessorReqMsg>();
|
||||
let msg = CGWConnectionServerReqMsg::AddNewConnection(serial.to_string(), mbox_tx);
|
||||
self.cgw_server.enqueue_mbox_message_to_cgw_server(msg).await;
|
||||
|
||||
let ack = mbox_rx.recv().await;
|
||||
if let Some(m) = ack {
|
||||
match m {
|
||||
CGWConnectionProcessorReqMsg::AddNewConnectionAck => {
|
||||
debug!("websocket connection established: {} {}", self.addr, serial);
|
||||
},
|
||||
_ => panic!("Unexpected response from server, expected ACK/NOT ACK)"),
|
||||
}
|
||||
} else {
|
||||
info!("connection server declined connection, websocket connection {} {} cannot be established",
|
||||
self.addr, serial);
|
||||
return;
|
||||
}
|
||||
|
||||
self.process_connection(stream, sink, mbox_rx).await;
|
||||
}
|
||||
|
||||
async fn process_wss_rx_msg(&self, msg: Result<Message, tungstenite::error::Error>) -> Result<CGWConnectionState, &'static str> {
|
||||
match msg {
|
||||
Ok(msg) => {
|
||||
match msg {
|
||||
Close(_t) => {
|
||||
return Ok(CGWConnectionState::ClosedGracefully);
|
||||
},
|
||||
Text(payload) => {
|
||||
self.cgw_server.enqueue_mbox_message_from_device_to_nb_api_c(self.serial.clone().unwrap(), payload);
|
||||
return Ok(CGWConnectionState::IsActive);
|
||||
},
|
||||
Ping(_t) => {
|
||||
return Ok(CGWConnectionState::IsActive);
|
||||
},
|
||||
_ => {
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
match e {
|
||||
tungstenite::error::Error::AlreadyClosed => {
|
||||
return Err("Underlying connection's been closed");
|
||||
},
|
||||
_ => {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(CGWConnectionState::IsActive)
|
||||
}
|
||||
|
||||
async fn process_sink_mbox_rx_msg(&self,sink: &mut SSink, val: Option<CGWConnectionProcessorReqMsg>) -> Result<CGWConnectionState, &str> {
|
||||
if let Some(msg) = val {
|
||||
let processor_mac = self.serial.clone().unwrap();
|
||||
match msg {
|
||||
CGWConnectionProcessorReqMsg::AddNewConnectionShouldClose => {
|
||||
debug!("MBOX_IN: AddNewConnectionShouldClose, processor (mac:{processor_mac}) (ACK OK)");
|
||||
return Ok(CGWConnectionState::IsForcedToClose);
|
||||
},
|
||||
CGWConnectionProcessorReqMsg::SinkRequestToDevice(mac, pload) => {
|
||||
debug!("MBOX_IN: SinkRequestToDevice, processor (mac:{processor_mac}) req for (mac:{mac}) payload:{pload}");
|
||||
sink.send(Message::text(pload)).await.ok();
|
||||
},
|
||||
_ => panic!("Unexpected message received {:?}", msg),
|
||||
}
|
||||
}
|
||||
Ok(CGWConnectionState::IsActive)
|
||||
}
|
||||
|
||||
async fn process_stale_connection_msg(&self, last_contact: Instant) -> Result<CGWConnectionState, &str> {
|
||||
// TODO: configurable duration (upon server creation)
|
||||
if Instant::now().duration_since(last_contact) > Duration::from_secs(70) {
|
||||
warn!("Closing connection {} (idle for too long, stale)", self.addr);
|
||||
Ok(CGWConnectionState::IsStale)
|
||||
} else {
|
||||
Ok(CGWConnectionState::IsActive)
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_connection(
|
||||
self,
|
||||
mut stream: SStream,
|
||||
mut sink: SSink,
|
||||
mut mbox_rx: UnboundedReceiver<CGWConnectionProcessorReqMsg>) {
|
||||
|
||||
#[derive(Debug)]
|
||||
enum WakeupReason {
|
||||
Unspecified,
|
||||
WSSRxMsg(Result<Message, tungstenite::error::Error>),
|
||||
MboxRx(Option<CGWConnectionProcessorReqMsg>),
|
||||
Stale,
|
||||
}
|
||||
|
||||
let mut last_contact = Instant::now();
|
||||
let mut poll_wss_first = true;
|
||||
|
||||
// Get underlying wakeup reason and do initial parsion, like:
|
||||
// - check if WSS stream has a message or an (recv) error
|
||||
// - check if sinkmbox has a message or an (recv) error
|
||||
// - check if connection's been stale for X time
|
||||
//
|
||||
// TODO: try_next intead of sync .next? could potentially
|
||||
// skyrocket CPU usage.
|
||||
loop {
|
||||
let mut wakeup_reason: WakeupReason = WakeupReason::Unspecified;
|
||||
|
||||
// TODO: refactor
|
||||
// Round-robin selection of stream to process:
|
||||
// first, get single message from WSS, then get a single msg from RX MBOX
|
||||
// It's done to ensure we process WSS and RX MBOX equally with same priority
|
||||
// Also, we have to make sure we don't sleep-wait for any of the streams to
|
||||
// make sure we don't cancel futures that are used for stream processing,
|
||||
// especially TCP stream, which is not cancel-safe
|
||||
if poll_wss_first {
|
||||
if let Some(val) = stream.next().now_or_never() {
|
||||
if let Some(res) = val {
|
||||
if let Ok(msg) = res {
|
||||
wakeup_reason = WakeupReason::WSSRxMsg(Ok(msg));
|
||||
} else if let Err(msg) = res {
|
||||
wakeup_reason = WakeupReason::WSSRxMsg(Result::Err(msg));
|
||||
}
|
||||
} else if let None = val {
|
||||
wakeup_reason = WakeupReason::WSSRxMsg(Result::Err(tungstenite::error::Error::AlreadyClosed));
|
||||
}
|
||||
} else if let Some(val) = mbox_rx.recv().now_or_never() {
|
||||
wakeup_reason = WakeupReason::MboxRx(val)
|
||||
}
|
||||
|
||||
poll_wss_first = !poll_wss_first;
|
||||
} else {
|
||||
if let Some(val) = mbox_rx.recv().now_or_never() {
|
||||
wakeup_reason = WakeupReason::MboxRx(val)
|
||||
} else if let Some(val) = stream.next().now_or_never() {
|
||||
if let Some(res) = val {
|
||||
if let Ok(msg) = res {
|
||||
wakeup_reason = WakeupReason::WSSRxMsg(Ok(msg));
|
||||
} else if let Err(msg) = res {
|
||||
wakeup_reason = WakeupReason::WSSRxMsg(Result::Err(msg));
|
||||
}
|
||||
} else if let None = val {
|
||||
wakeup_reason = WakeupReason::WSSRxMsg(Result::Err(tungstenite::error::Error::AlreadyClosed));
|
||||
}
|
||||
}
|
||||
poll_wss_first = !poll_wss_first;
|
||||
}
|
||||
|
||||
// TODO: somehow workaround the sleeping?
|
||||
// Both WSS and RX MBOX are empty: chill for a while
|
||||
if let WakeupReason::Unspecified = wakeup_reason {
|
||||
sleep(Duration::from_millis(1000)).await;
|
||||
wakeup_reason = WakeupReason::Stale;
|
||||
}
|
||||
|
||||
let rc = match wakeup_reason {
|
||||
WakeupReason::WSSRxMsg(res) => {
|
||||
last_contact = Instant::now();
|
||||
self.process_wss_rx_msg(res).await
|
||||
},
|
||||
WakeupReason::MboxRx(mbox_message) => {
|
||||
self.process_sink_mbox_rx_msg(&mut sink, mbox_message).await
|
||||
},
|
||||
WakeupReason::Stale => {
|
||||
self.process_stale_connection_msg(last_contact.clone()).await
|
||||
},
|
||||
_ => {
|
||||
panic!("Failed to get wakeup reason for {} conn", self.addr);
|
||||
},
|
||||
};
|
||||
|
||||
match rc {
|
||||
Err(e) => {
|
||||
warn!("{}", e);
|
||||
break;
|
||||
},
|
||||
Ok(state) => {
|
||||
if let CGWConnectionState::IsActive = state {
|
||||
continue;
|
||||
} else if let CGWConnectionState::IsForcedToClose = state {
|
||||
// Return, because server already closed our mbox tx counterpart (rx),
|
||||
// hence we don't need to send ConnectionClosed message. Server
|
||||
// already knows we're closed.
|
||||
return;
|
||||
} else if let CGWConnectionState::ClosedGracefully = state {
|
||||
warn!("Remote client {} closed connection gracefully", self.serial.clone().unwrap());
|
||||
break;
|
||||
} else if let CGWConnectionState::IsStale = state {
|
||||
warn!("Remote client {} closed due to inactivity", self.serial.clone().unwrap());
|
||||
break;
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
let mac = self.serial.clone().unwrap();
|
||||
let msg = CGWConnectionServerReqMsg::ConnectionClosed(self.serial.unwrap());
|
||||
self.cgw_server.enqueue_mbox_message_to_cgw_server(msg).await;
|
||||
debug!("MBOX_OUT: ConnectionClosed, processor (mac:{})", mac);
|
||||
}
|
||||
}
|
||||
878
src/cgw_connection_server.rs
Normal file
878
src/cgw_connection_server.rs
Normal file
@@ -0,0 +1,878 @@
|
||||
use crate::{
|
||||
AppArgs,
|
||||
};
|
||||
|
||||
use crate::cgw_nb_api_listener::{
|
||||
CGWNBApiClient,
|
||||
};
|
||||
|
||||
use crate::cgw_connection_processor::{
|
||||
CGWConnectionProcessor,
|
||||
CGWConnectionProcessorReqMsg,
|
||||
};
|
||||
|
||||
use crate::cgw_db_accessor::{
|
||||
CGWDBInfrastructureGroup,
|
||||
};
|
||||
|
||||
use crate::cgw_remote_discovery::{
|
||||
CGWRemoteDiscovery,
|
||||
};
|
||||
|
||||
use tokio::{
|
||||
sync::{
|
||||
RwLock,
|
||||
mpsc::{
|
||||
UnboundedSender,
|
||||
UnboundedReceiver,
|
||||
unbounded_channel,
|
||||
},
|
||||
},
|
||||
time::{
|
||||
sleep,
|
||||
Duration,
|
||||
},
|
||||
net::{
|
||||
TcpStream,
|
||||
},
|
||||
runtime::{
|
||||
Builder,
|
||||
Runtime,
|
||||
},
|
||||
};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use std::sync::atomic::{
|
||||
AtomicUsize,
|
||||
Ordering,
|
||||
};
|
||||
|
||||
use serde_json::{
|
||||
Map,
|
||||
Value,
|
||||
};
|
||||
|
||||
use serde::{
|
||||
Deserialize,
|
||||
Serialize,
|
||||
};
|
||||
|
||||
use uuid::{
|
||||
Uuid,
|
||||
};
|
||||
|
||||
type DeviceSerial = String;
|
||||
type CGWConnmapType = Arc<RwLock<HashMap<String, UnboundedSender<CGWConnectionProcessorReqMsg>>>>;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CGWConnMap {
|
||||
map: CGWConnmapType,
|
||||
}
|
||||
|
||||
impl CGWConnMap {
|
||||
pub fn new() -> Self {
|
||||
let hash_map: HashMap<String, UnboundedSender<CGWConnectionProcessorReqMsg>> = HashMap::new();
|
||||
let map: Arc<RwLock<HashMap<String, UnboundedSender<CGWConnectionProcessorReqMsg>>>> = Arc::new(RwLock::new(hash_map));
|
||||
let connmap = CGWConnMap {
|
||||
map: map,
|
||||
};
|
||||
connmap
|
||||
}
|
||||
}
|
||||
|
||||
type CGWConnectionServerMboxRx = UnboundedReceiver<CGWConnectionServerReqMsg>;
|
||||
type CGWConnectionServerMboxTx = UnboundedSender<CGWConnectionServerReqMsg>;
|
||||
type CGWConnectionServerNBAPIMboxTx = UnboundedSender<CGWConnectionNBAPIReqMsg>;
|
||||
type CGWConnectionServerNBAPIMboxRx = UnboundedReceiver<CGWConnectionNBAPIReqMsg>;
|
||||
|
||||
// The following pair used internally by server itself to bind
|
||||
// Processor's Req/Res
|
||||
#[derive(Debug)]
|
||||
pub enum CGWConnectionServerReqMsg {
|
||||
// Connection-related messages
|
||||
AddNewConnection(DeviceSerial, UnboundedSender<CGWConnectionProcessorReqMsg>),
|
||||
ConnectionClosed(DeviceSerial),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CGWConnectionNBAPIReqMsgOrigin {
|
||||
FromNBAPI,
|
||||
FromRemoteCGW,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum CGWConnectionNBAPIReqMsg {
|
||||
// Enqueue Key, Request, bool = isMessageRelayed
|
||||
EnqueueNewMessageFromNBAPIListener(String, String, CGWConnectionNBAPIReqMsgOrigin),
|
||||
}
|
||||
|
||||
pub struct CGWConnectionServer {
|
||||
local_cgw_id: i32,
|
||||
// CGWConnectionServer write into this mailbox,
|
||||
// and other correspondig Server task Reads RX counterpart
|
||||
mbox_internal_tx: CGWConnectionServerMboxTx,
|
||||
|
||||
// Object that owns underlying mac:connection map
|
||||
connmap: CGWConnMap,
|
||||
|
||||
// Runtime that schedules all the WSS-messages related tasks
|
||||
wss_rx_tx_runtime: Arc<Runtime>,
|
||||
|
||||
// Dedicated runtime (threadpool) for handling internal mbox:
|
||||
// ACK/nACK connection, handle duplicates (clone/open) etc.
|
||||
mbox_internal_runtime_handle: Arc<Runtime>,
|
||||
|
||||
// Dedicated runtime (threadpool) for handling NB-API mbox:
|
||||
// RX NB-API requests, parse, relay (if needed)
|
||||
mbox_nb_api_runtime_handle: Arc<Runtime>,
|
||||
|
||||
// Dedicated runtime (threadpool) for handling NB-API TX routine:
|
||||
// TX NB-API requests (if async send is needed)
|
||||
mbox_nb_api_tx_runtime_handle: Arc<Runtime>,
|
||||
|
||||
// Dedicated runtime (threadpool) for handling (relaying) msgs:
|
||||
// relay-task is spawned inside it, and the produced stream of
|
||||
// remote-cgw messages is being relayed inside this context
|
||||
mbox_relay_msg_runtime_handle: Arc<Runtime>,
|
||||
|
||||
// CGWConnectionServer write into this mailbox,
|
||||
// and other correspondig NB API client is responsible for doing an RX over
|
||||
// receive handle counterpart
|
||||
nb_api_client: Arc<CGWNBApiClient>,
|
||||
|
||||
// Interface used to access all discovered CGW instances
|
||||
// (used for relaying non-local CGW requests from NB-API to target CGW)
|
||||
cgw_remote_discovery: Arc<CGWRemoteDiscovery>,
|
||||
|
||||
// Handler that helps this object to wrap relayed NB-API messages
|
||||
// dedicated for this particular local CGW instance
|
||||
mbox_relayed_messages_handle: CGWConnectionServerNBAPIMboxTx,
|
||||
}
|
||||
|
||||
/*
|
||||
* TODO: split into base struct + enum type field
|
||||
* this requires alot of refactoring in places where msg is used
|
||||
* this is needed to always have uuid of malformed / discarded msg.
|
||||
* e.g.
|
||||
* struct CGWNBApiParsedMsg {
|
||||
* uuid: Uuid,
|
||||
* gid: i32,
|
||||
* type: enum CGWNBApiParsedMsgType,
|
||||
* }
|
||||
*
|
||||
* enum CGWNBApiParsedMsgType {
|
||||
* InfrastructureGroupCreate,
|
||||
* ...
|
||||
* InfrastructureGroupInfraMsg(DeviceSerial, String),
|
||||
* }
|
||||
*/
|
||||
enum CGWNBApiParsedMsg {
|
||||
// TODO: fix kafka_simulator to provide reserved_size
|
||||
InfrastructureGroupCreate(Uuid, i32),
|
||||
InfrastructureGroupDelete(Uuid, i32),
|
||||
InfrastructureGroupInfraAdd(Uuid, i32, Vec<DeviceSerial>),
|
||||
InfrastructureGroupInfraDel(Uuid, i32, Vec<DeviceSerial>),
|
||||
InfrastructureGroupInfraMsg(Uuid, i32, DeviceSerial, String),
|
||||
}
|
||||
|
||||
impl CGWConnectionServer {
|
||||
pub async fn new(app_args: &AppArgs) -> Arc<Self> {
|
||||
let wss_runtime_handle = Arc::new(Builder::new_multi_thread()
|
||||
.worker_threads(app_args.wss_t_num)
|
||||
.thread_name_fn(|| {
|
||||
static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
|
||||
let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
|
||||
format!("cgw-wss-t-{}", id)
|
||||
})
|
||||
.thread_stack_size(3 * 1024 * 1024)
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap());
|
||||
let internal_mbox_runtime_handle = Arc::new(Builder::new_multi_thread()
|
||||
.worker_threads(1)
|
||||
.thread_name("cgw-mbox")
|
||||
.thread_stack_size(1 * 1024 * 1024)
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap());
|
||||
let nb_api_mbox_runtime_handle = Arc::new(Builder::new_multi_thread()
|
||||
.worker_threads(1)
|
||||
.thread_name("cgw-mbox-nbapi")
|
||||
.thread_stack_size(1 * 1024 * 1024)
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap());
|
||||
let relay_msg_mbox_runtime_handle = Arc::new(Builder::new_multi_thread()
|
||||
.worker_threads(1)
|
||||
.thread_name("cgw-relay-mbox-nbapi")
|
||||
.thread_stack_size(1 * 1024 * 1024)
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap());
|
||||
let nb_api_mbox_tx_runtime_handle = Arc::new(Builder::new_multi_thread()
|
||||
.worker_threads(1)
|
||||
.thread_name("cgw-mbox-nbapi-tx")
|
||||
.thread_stack_size(1 * 1024 * 1024)
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap());
|
||||
|
||||
let (internal_tx, internal_rx) = unbounded_channel::<CGWConnectionServerReqMsg>();
|
||||
let (nb_api_tx, nb_api_rx) = unbounded_channel::<CGWConnectionNBAPIReqMsg>();
|
||||
|
||||
// Give NB API client a handle where it can do a TX (CLIENT -> CGW_SERVER)
|
||||
// RX is handled in internal_mbox of CGW_Server
|
||||
let nb_api_c = CGWNBApiClient::new(app_args, &nb_api_tx);
|
||||
|
||||
let server = Arc::new(CGWConnectionServer {
|
||||
local_cgw_id: app_args.cgw_id,
|
||||
connmap: CGWConnMap::new(),
|
||||
wss_rx_tx_runtime: wss_runtime_handle,
|
||||
mbox_internal_runtime_handle: internal_mbox_runtime_handle,
|
||||
mbox_nb_api_runtime_handle: nb_api_mbox_runtime_handle,
|
||||
mbox_nb_api_tx_runtime_handle: nb_api_mbox_tx_runtime_handle,
|
||||
mbox_internal_tx: internal_tx,
|
||||
nb_api_client: nb_api_c,
|
||||
cgw_remote_discovery: Arc::new(CGWRemoteDiscovery::new(app_args).await),
|
||||
mbox_relayed_messages_handle: nb_api_tx,
|
||||
mbox_relay_msg_runtime_handle: relay_msg_mbox_runtime_handle,
|
||||
});
|
||||
|
||||
let server_clone = server.clone();
|
||||
// Task for processing mbox_internal_rx, task owns the RX part
|
||||
server.mbox_internal_runtime_handle.spawn(async move {
|
||||
server_clone.process_internal_mbox(internal_rx).await;
|
||||
});
|
||||
|
||||
let server_clone = server.clone();
|
||||
server.mbox_nb_api_runtime_handle.spawn(async move {
|
||||
server_clone.process_internal_nb_api_mbox(nb_api_rx).await;
|
||||
});
|
||||
|
||||
server
|
||||
}
|
||||
|
||||
pub async fn enqueue_mbox_message_to_cgw_server(&self, req: CGWConnectionServerReqMsg) {
|
||||
let _ = self.mbox_internal_tx.send(req);
|
||||
}
|
||||
|
||||
pub fn enqueue_mbox_message_from_device_to_nb_api_c(&self, mac: DeviceSerial, req: String) {
|
||||
// TODO: device (mac) -> group id matching
|
||||
let key = String::from("TBD_INFRA_GROUP");
|
||||
let nb_api_client_clone = self.nb_api_client.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = nb_api_client_clone.enqueue_mbox_message_from_cgw_server(key, req).await;
|
||||
});
|
||||
}
|
||||
|
||||
pub fn enqueue_mbox_message_from_cgw_to_nb_api(&self, gid: i32, req: String) {
|
||||
let nb_api_client_clone = self.nb_api_client.clone();
|
||||
self.mbox_nb_api_tx_runtime_handle.spawn(async move {
|
||||
let _ = nb_api_client_clone.enqueue_mbox_message_from_cgw_server(gid.to_string(), req).await;
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn enqueue_mbox_relayed_message_to_cgw_server(&self, key: String, req: String) {
|
||||
let msg = CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, req, CGWConnectionNBAPIReqMsgOrigin::FromRemoteCGW);
|
||||
let _ = self.mbox_relayed_messages_handle.send(msg);
|
||||
}
|
||||
|
||||
fn parse_nbapi_msg(&self, pload: &String) -> Option<CGWNBApiParsedMsg> {
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct InfraGroupCreate {
|
||||
r#type: String,
|
||||
infra_group_id: String,
|
||||
infra_name: String,
|
||||
infra_shard_id: i32,
|
||||
uuid: Uuid,
|
||||
}
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct InfraGroupDelete {
|
||||
r#type: String,
|
||||
infra_group_id: String,
|
||||
uuid: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct InfraGroupInfraAdd {
|
||||
r#type: String,
|
||||
infra_group_id: String,
|
||||
infra_group_infra_devices: Vec<String>,
|
||||
uuid: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct InfraGroupInfraDel {
|
||||
r#type: String,
|
||||
infra_group_id: String,
|
||||
infra_group_infra_devices: Vec<String>,
|
||||
uuid: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct InfraGroupMsgJSON {
|
||||
r#type: String,
|
||||
infra_group_id: String,
|
||||
mac: String,
|
||||
msg: Map<String, Value>,
|
||||
uuid: Uuid,
|
||||
}
|
||||
|
||||
let rc = serde_json::from_str(pload);
|
||||
if let Err(e) = rc {
|
||||
error!("{e}\n{pload}");
|
||||
return None;
|
||||
}
|
||||
|
||||
let map: Map<String, Value> = rc.unwrap();
|
||||
|
||||
let rc = map.get(&String::from("type"));
|
||||
if let None = rc {
|
||||
error!("No msg_type found in\n{pload}");
|
||||
return None;
|
||||
}
|
||||
let rc = rc.unwrap();
|
||||
|
||||
let msg_type = rc.as_str().unwrap();
|
||||
let rc = map.get(&String::from("infra_group_id"));
|
||||
if let None = rc {
|
||||
error!("No infra_group_id found in\n{pload}");
|
||||
return None;
|
||||
}
|
||||
let rc = rc.unwrap();
|
||||
let group_id: i32 = rc.as_str().unwrap().parse().unwrap();
|
||||
|
||||
//debug!("Got msg {msg_type}, infra {group_id}");
|
||||
|
||||
match msg_type {
|
||||
"infrastructure_group_create" => {
|
||||
let json_msg: InfraGroupCreate = serde_json::from_str(&pload).unwrap();
|
||||
//debug!("{:?}", json_msg);
|
||||
return Some(CGWNBApiParsedMsg::InfrastructureGroupCreate(json_msg.uuid, group_id));
|
||||
},
|
||||
"infrastructure_group_delete" => {
|
||||
let json_msg: InfraGroupDelete = serde_json::from_str(&pload).unwrap();
|
||||
//debug!("{:?}", json_msg);
|
||||
return Some(CGWNBApiParsedMsg::InfrastructureGroupDelete(json_msg.uuid, group_id));
|
||||
},
|
||||
"infrastructure_group_device_add" => {
|
||||
let json_msg: InfraGroupInfraAdd = serde_json::from_str(&pload).unwrap();
|
||||
//debug!("{:?}", json_msg);
|
||||
return Some(CGWNBApiParsedMsg::InfrastructureGroupInfraAdd(json_msg.uuid, group_id, json_msg.infra_group_infra_devices));
|
||||
}
|
||||
"infrastructure_group_device_del" => {
|
||||
let json_msg: InfraGroupInfraDel = serde_json::from_str(&pload).unwrap();
|
||||
//debug!("{:?}", json_msg);
|
||||
return Some(CGWNBApiParsedMsg::InfrastructureGroupInfraDel(json_msg.uuid, group_id, json_msg.infra_group_infra_devices));
|
||||
}
|
||||
"infrastructure_group_device_message" => {
|
||||
let json_msg: InfraGroupMsgJSON = serde_json::from_str(&pload).unwrap();
|
||||
//debug!("{:?}", json_msg);
|
||||
return Some(CGWNBApiParsedMsg::InfrastructureGroupInfraMsg(json_msg.uuid, group_id, json_msg.mac, serde_json::to_string(&json_msg.msg).unwrap()));
|
||||
},
|
||||
&_ => {
|
||||
debug!("Unknown type {msg_type} received");
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
async fn process_internal_nb_api_mbox(self: Arc<Self>, mut rx_mbox: CGWConnectionServerNBAPIMboxRx) {
|
||||
debug!("process_nb_api_mbox entry");
|
||||
|
||||
let buf_capacity = 2000;
|
||||
let mut buf: Vec<CGWConnectionNBAPIReqMsg> = Vec::with_capacity(buf_capacity);
|
||||
let mut num_of_msg_read = 0;
|
||||
// As of now, expect at max 100 CGWS remote instances without buffers realloc
|
||||
// This only means that original capacity of all buffers is allocated to <100>,
|
||||
// it can still increase on demand or need automatically (upon insert, push_back etc)
|
||||
let cgw_buf_prealloc_size = 100;
|
||||
|
||||
let mut local_parsed_cgw_msg_buf: Vec<CGWNBApiParsedMsg> = Vec::with_capacity(buf_capacity);
|
||||
|
||||
loop {
|
||||
if num_of_msg_read < buf_capacity {
|
||||
// Try to recv_many, but don't sleep too much
|
||||
// in case if no messaged pending and we have
|
||||
// TODO: rework to pull model (pull on demand),
|
||||
// compared to curr impl: push model (nb api listener forcefully
|
||||
// pushed all fetched data from kafka).
|
||||
// Currently recv_many may sleep if previous read >= 1,
|
||||
// but no new messages pending
|
||||
//
|
||||
// It's also possible that this logic staggers the processing,
|
||||
// in case when every new message is received <=9 ms for example:
|
||||
// Single message received, waiting for new up to 10 ms.
|
||||
// New received on 9th ms. Repeat.
|
||||
// And this could repeat up untill buffer is full, or no new messages
|
||||
// appear on the 10ms scale.
|
||||
// Highly unlikly scenario, but still possible.
|
||||
let rd_num = tokio::select! {
|
||||
v = rx_mbox.recv_many(&mut buf, buf_capacity - num_of_msg_read) => {
|
||||
v
|
||||
}
|
||||
v = sleep(Duration::from_millis(10)) => {
|
||||
0
|
||||
}
|
||||
};
|
||||
num_of_msg_read += rd_num;
|
||||
|
||||
// We read some messages, try to continue and read more
|
||||
// If none read - break from recv, process all buffers that've
|
||||
// been filled-up so far (both local and remote).
|
||||
// Upon done - repeat.
|
||||
if rd_num >= 1 {
|
||||
continue;
|
||||
} else {
|
||||
if num_of_msg_read == 0 {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Received {num_of_msg_read} messages from NB API, processing...");
|
||||
|
||||
// We rely on this map only for a single iteration of received messages:
|
||||
// say, we receive 10 messages but 20 in queue, this means that gid->cgw_id
|
||||
// cache is clear at first, the filled up when processing first 10 messages,
|
||||
// the clear/reassigned again for next 10 msgs (10->20).
|
||||
// This is done to ensure that we don't fallback for redis too much,
|
||||
// but still somewhat fully rely on it.
|
||||
//
|
||||
self.cgw_remote_discovery.sync_gid_to_cgw_map().await;
|
||||
|
||||
local_parsed_cgw_msg_buf.clear();
|
||||
|
||||
// TODO: rework to avoid re-allocating these buffers on each loop iteration
|
||||
// (get mut slice of vec / clear when done?)
|
||||
let mut relayed_cgw_msg_buf: Vec<(i32, CGWConnectionNBAPIReqMsg)> = Vec::with_capacity(num_of_msg_read + 1);
|
||||
let mut local_cgw_msg_buf: Vec<CGWConnectionNBAPIReqMsg> = Vec::with_capacity(num_of_msg_read + 1);
|
||||
|
||||
while ! buf.is_empty() {
|
||||
let msg = buf.remove(0);
|
||||
|
||||
if let CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, origin) = msg {
|
||||
let gid_numeric = match key.parse::<i32>() {
|
||||
Err(e) => {
|
||||
warn!("Invalid KEY received from KAFKA bus message, ignoring\n{e}");
|
||||
continue;
|
||||
},
|
||||
Ok(v) => v
|
||||
};
|
||||
|
||||
let parsed_msg = match self.parse_nbapi_msg(&payload) {
|
||||
Some(val) => val,
|
||||
None => {
|
||||
warn!("Failed to parse recv msg with key {key}, discarded");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// The one shard that received add/del is responsible for
|
||||
// handling it at place.
|
||||
// Any other msg is either relayed / handled locally later.
|
||||
// The reason for this is following: current shard is responsible
|
||||
// for assignment of GID to shard, thus it has to make
|
||||
// assignment as soon as possible to deduce relaying action in
|
||||
// the following message pool that is being handled.
|
||||
// Same for delete.
|
||||
if let CGWNBApiParsedMsg::InfrastructureGroupCreate(uuid, gid) = parsed_msg {
|
||||
// DB stuff - create group for remote shards to be aware of change
|
||||
let group = CGWDBInfrastructureGroup {
|
||||
id: gid,
|
||||
reserved_size: 1000i32,
|
||||
actual_size: 0i32,
|
||||
};
|
||||
match self.cgw_remote_discovery.create_infra_group(&group).await {
|
||||
Ok(dst_cgw_id) => {
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Group has been created successfully gid {gid}"));
|
||||
},
|
||||
Err(e) => {
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Failed to create new group (duplicate create?), gid {gid}"));
|
||||
warn!("Create group gid {gid} received, but it already exists");
|
||||
}
|
||||
}
|
||||
// This type of msg is handled in place, not added to buf
|
||||
// for later processing.
|
||||
continue;
|
||||
} else if let CGWNBApiParsedMsg::InfrastructureGroupDelete(uuid, gid) = parsed_msg {
|
||||
match self.cgw_remote_discovery.destroy_infra_group(gid).await {
|
||||
Ok(()) => {
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Group has been destroyed successfully gid {gid}, uuid {uuid}"));
|
||||
},
|
||||
Err(e) => {
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Failed to destroy group (doesn't exist?), gid {gid}, uuid {uuid}"));
|
||||
warn!("Destroy group gid {gid} received, but it does not exist");
|
||||
}
|
||||
}
|
||||
// This type of msg is handled in place, not added to buf
|
||||
// for later processing.
|
||||
continue;
|
||||
}
|
||||
|
||||
// We received NB API msg, check origin:
|
||||
// If it's a relayed message, we must not relay it further
|
||||
// If msg originated from Kafka originally, it's safe to relay it (if needed)
|
||||
if let CGWConnectionNBAPIReqMsgOrigin::FromRemoteCGW = origin {
|
||||
local_parsed_cgw_msg_buf.push(parsed_msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
match self.cgw_remote_discovery.get_infra_group_owner_id(key.parse::<i32>().unwrap()).await {
|
||||
Some(dst_cgw_id) => {
|
||||
if dst_cgw_id == self.local_cgw_id {
|
||||
local_cgw_msg_buf.push(CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, origin));
|
||||
} else {
|
||||
relayed_cgw_msg_buf.push((dst_cgw_id, CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, origin)));
|
||||
}
|
||||
},
|
||||
None => {
|
||||
warn!("Received msg for gid {gid_numeric}, while this group is unassigned to any of CGWs: rejecting");
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid_numeric,
|
||||
format!("Received message for unknown group {gid_numeric} - unassigned?"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let discovery_clone = self.cgw_remote_discovery.clone();
|
||||
let self_clone = self.clone();
|
||||
|
||||
// Future to Handle (relay) messages for remote CGW
|
||||
let relay_task_hdl = self.mbox_relay_msg_runtime_handle.spawn(async move {
|
||||
let mut remote_cgws_map: HashMap<String, (i32, Vec<(String, String)>)> = HashMap::with_capacity(cgw_buf_prealloc_size);
|
||||
|
||||
while ! relayed_cgw_msg_buf.is_empty() {
|
||||
let msg = relayed_cgw_msg_buf.remove(0);
|
||||
if let (dst_cgw_id, CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, _origin)) = msg {
|
||||
debug!("Received MSG for remote CGW k:{}, local id {} relaying msg to remote...", key, self_clone.local_cgw_id);
|
||||
if let Some(v) = remote_cgws_map.get_mut(&key) {
|
||||
v.1.push((key, payload));
|
||||
} else {
|
||||
let mut tmp_vec: Vec<(String, String)> = Vec::with_capacity(num_of_msg_read);
|
||||
tmp_vec.push((key.clone(), payload));
|
||||
remote_cgws_map.insert(key, (dst_cgw_id, tmp_vec));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for value in remote_cgws_map.into_values() {
|
||||
let discovery_clone = discovery_clone.clone();
|
||||
let cgw_id = value.0;
|
||||
let msg_stream = value.1;
|
||||
let self_clone = self_clone.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(()) = discovery_clone.relay_request_stream_to_remote_cgw(cgw_id, msg_stream).await {
|
||||
self_clone.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
-1,
|
||||
format!("Failed to relay MSG stream to remote CGW{cgw_id}, UUIDs: not implemented (TODO)"));
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Handle messages for local CGW
|
||||
// Parse all messages first, then process
|
||||
// TODO: try to parallelize at least parsing of msg:
|
||||
// iterate each msg, get index, spawn task that would
|
||||
// write indexed parsed msg into output parsed msg buf.
|
||||
let connmap_clone = self.connmap.map.clone();
|
||||
while ! local_cgw_msg_buf.is_empty() {
|
||||
let msg = local_cgw_msg_buf.remove(0);
|
||||
if let CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, _origin) = msg {
|
||||
let gid_numeric: i32 = key.parse::<i32>().unwrap();
|
||||
debug!("Received message for local CGW k:{key}, local id {}", self.local_cgw_id);
|
||||
let msg = self.parse_nbapi_msg(&payload);
|
||||
if let None = msg {
|
||||
error!("Failed to parse msg from NBAPI (malformed?)");
|
||||
continue;
|
||||
}
|
||||
|
||||
match msg.unwrap() {
|
||||
CGWNBApiParsedMsg::InfrastructureGroupInfraAdd(uuid, gid, mac_list) => {
|
||||
if let None = self.cgw_remote_discovery.get_infra_group_owner_id(gid_numeric).await {
|
||||
warn!("Unexpected: tried to add infra list to nonexisting group (gid {gid}, uuid {uuid}");
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Failed to insert MACs from infra list, gid {gid}, uuid {uuid}: group does not exist."));
|
||||
}
|
||||
|
||||
match self.cgw_remote_discovery.create_ifras_list(gid, mac_list).await {
|
||||
Ok(()) => {
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Infra list has been created successfully gid {gid}, uuid {uuid}"));
|
||||
},
|
||||
Err(macs) => {
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Failed to insert few MACs from infra list, gid {gid}, uuid {uuid}; List of failed MACs:{}",
|
||||
macs.iter().map(|x| x.to_string() + ",").collect::<String>()));
|
||||
warn!("Failed to create few MACs from infras list (partial create)");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
},
|
||||
CGWNBApiParsedMsg::InfrastructureGroupInfraDel(uuid, gid, mac_list) => {
|
||||
if let None = self.cgw_remote_discovery.get_infra_group_owner_id(gid_numeric).await {
|
||||
warn!("Unexpected: tried to delete infra list from nonexisting group (gid {gid}, uuid {uuid}");
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Failed to delete MACs from infra list, gid {gid}, uuid {uuid}: group does not exist."));
|
||||
}
|
||||
|
||||
match self.cgw_remote_discovery.destroy_ifras_list(gid, mac_list).await {
|
||||
Ok(()) => {
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Infra list has been destroyed successfully gid {gid}, uuid {uuid}"));
|
||||
},
|
||||
Err(macs) => {
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Failed to destroy few MACs from infra list (not created?), gid {gid}, uuid {uuid}; List of failed MACs:{}",
|
||||
macs.iter().map(|x| x.to_string() + ",").collect::<String>()));
|
||||
warn!("Failed to destroy few MACs from infras list (partial delete)");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
},
|
||||
CGWNBApiParsedMsg::InfrastructureGroupInfraMsg(uuid, gid, mac, msg) => {
|
||||
if let None = self.cgw_remote_discovery.get_infra_group_owner_id(gid_numeric).await {
|
||||
warn!("Unexpected: tried to sink down msg to device of nonexisting group (gid {gid}, uuid {uuid}");
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Failed to sink down msg to device of nonexisting group, gid {gid}, uuid {uuid}: group does not exist."));
|
||||
}
|
||||
|
||||
debug!("Sending msg to device {mac}");
|
||||
let rd_lock = connmap_clone.read().await;
|
||||
let rc = rd_lock.get(&mac);
|
||||
if let None = rc {
|
||||
error!("Cannot find suitable connection for {mac}");
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Failed to send msg (device not connected?), gid {gid}, uuid {uuid}"));
|
||||
continue;
|
||||
}
|
||||
|
||||
let proc_mbox_tx = rc.unwrap();
|
||||
if let Err(e) = proc_mbox_tx.send(CGWConnectionProcessorReqMsg::SinkRequestToDevice(mac, msg)) {
|
||||
error!("Failed to send message to remote device (msg uuid({uuid}))");
|
||||
self.enqueue_mbox_message_from_cgw_to_nb_api(
|
||||
gid,
|
||||
format!("Failed to send msg, gid {gid}, uuid {uuid}"));
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
debug!("Received unimplemented/unexpected group create/del msg, ignoring");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Do not proceed parsing local / remote msgs untill previous relaying has been
|
||||
// finished
|
||||
tokio::join!(relay_task_hdl);
|
||||
|
||||
buf.clear();
|
||||
num_of_msg_read = 0;
|
||||
}
|
||||
panic!("RX or TX counterpart of nb_api channel part destroyed, while processing task is still active");
|
||||
}
|
||||
|
||||
async fn process_internal_mbox(self: Arc<Self>, mut rx_mbox: CGWConnectionServerMboxRx) {
|
||||
debug!("process_internal_mbox entry");
|
||||
|
||||
let buf_capacity = 1000;
|
||||
let mut buf: Vec<CGWConnectionServerReqMsg> = Vec::with_capacity(buf_capacity);
|
||||
let mut num_of_msg_read = 0;
|
||||
|
||||
loop {
|
||||
if num_of_msg_read < buf_capacity {
|
||||
// Try to recv_many, but don't sleep too much
|
||||
// in case if no messaged pending and we have
|
||||
// TODO: rework?
|
||||
// Currently recv_many may sleep if previous read >= 1,
|
||||
// but no new messages pending
|
||||
let rd_num = tokio::select! {
|
||||
v = rx_mbox.recv_many(&mut buf, buf_capacity - num_of_msg_read) => {
|
||||
v
|
||||
}
|
||||
v = sleep(Duration::from_millis(10)) => {
|
||||
0
|
||||
}
|
||||
};
|
||||
num_of_msg_read += rd_num;
|
||||
|
||||
// We read some messages, try to continue and read more
|
||||
// If none read - break from recv, process all buffers that've
|
||||
// been filled-up so far (both local and remote).
|
||||
// Upon done - repeat.
|
||||
if rd_num >= 1 {
|
||||
if num_of_msg_read < 100 {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if num_of_msg_read == 0 {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut connmap_w_lock = self.connmap.map.write().await;
|
||||
|
||||
while ! buf.is_empty() {
|
||||
let msg = buf.remove(0);
|
||||
|
||||
if let CGWConnectionServerReqMsg::AddNewConnection(serial, conn_processor_mbox_tx) = msg {
|
||||
// if connection is unique: simply insert new conn
|
||||
//
|
||||
// if duplicate exists: notify server about such incident.
|
||||
// it's up to server to notify underlying task that it should
|
||||
// drop the connection.
|
||||
// from now on simply insert new connection into hashmap and proceed on
|
||||
// processing it.
|
||||
let serial_clone: DeviceSerial = serial.clone();
|
||||
if let Some(c) = connmap_w_lock.remove(&serial_clone) {
|
||||
tokio::spawn(async move {
|
||||
warn!("Duplicate connection (mac:{}) detected, closing OLD connection in favor of NEW", serial_clone);
|
||||
let msg: CGWConnectionProcessorReqMsg = CGWConnectionProcessorReqMsg::AddNewConnectionShouldClose;
|
||||
c.send(msg).unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
// clone a sender handle, as we still have to send ACK back using underlying
|
||||
// tx mbox handle
|
||||
let conn_processor_mbox_tx_clone = conn_processor_mbox_tx.clone();
|
||||
|
||||
info!("connmap: connection with {} established, new num_of_connections:{}", serial, connmap_w_lock.len() + 1);
|
||||
connmap_w_lock.insert(serial, conn_processor_mbox_tx);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let msg: CGWConnectionProcessorReqMsg = CGWConnectionProcessorReqMsg::AddNewConnectionAck;
|
||||
conn_processor_mbox_tx_clone.send(msg).unwrap();
|
||||
});
|
||||
} else if let CGWConnectionServerReqMsg::ConnectionClosed(serial) = msg {
|
||||
info!("connmap: removed {} serial from connmap, new num_of_connections:{}", serial, connmap_w_lock.len() - 1);
|
||||
connmap_w_lock.remove(&serial);
|
||||
}
|
||||
}
|
||||
|
||||
buf.clear();
|
||||
num_of_msg_read = 0;
|
||||
}
|
||||
|
||||
panic!("RX or TX counterpart of mbox_internal channel part destroyed, while processing task is still active");
|
||||
}
|
||||
|
||||
pub async fn ack_connection(self: Arc<Self>, socket: TcpStream, tls_acceptor: tokio_native_tls::TlsAcceptor, addr: SocketAddr, conn_idx: i64) {
|
||||
// Only ACK connection. We will either drop it or accept it once processor starts
|
||||
// (we'll handle it via "mailbox" notify handle in process_internal_mbox)
|
||||
let server_clone = self.clone();
|
||||
|
||||
self.wss_rx_tx_runtime.spawn(async move {
|
||||
// Accept the TLS connection.
|
||||
let tls_stream = match tls_acceptor.accept(socket).await {
|
||||
Ok(a) => a,
|
||||
Err(e) => {
|
||||
warn!("Err {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let conn_processor = CGWConnectionProcessor::new(server_clone, conn_idx, addr);
|
||||
conn_processor.start(tls_stream).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn get_connect_json_msg() -> &'static str {
|
||||
r#"
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "connect",
|
||||
"params": {
|
||||
"serial": "00000000ca4b",
|
||||
"firmware": "SONiC-OS-4.1.0_vs_daily_221213_1931_422-campus",
|
||||
"uuid": 1,
|
||||
"capabilities": {
|
||||
"compatible": "+++x86_64-kvm_x86_64-r0",
|
||||
"model": "DellEMC-S5248f-P-25G-DPB",
|
||||
"platform": "switch",
|
||||
"label_macaddr": "00:00:00:00:ca:4b"
|
||||
}
|
||||
}
|
||||
}"#
|
||||
}
|
||||
|
||||
fn get_log_json_msg() -> &'static str {
|
||||
r#"
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "log",
|
||||
"params": {
|
||||
"serial": "00000000ca4b",
|
||||
"log": "uc-client: connection error: Unable to connect",
|
||||
"severity": 3
|
||||
}
|
||||
}"#
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_parse_connect_event() {
|
||||
let msg = get_connect_json_msg();
|
||||
|
||||
let map: Map<String, Value> =
|
||||
serde_json::from_str(msg).expect("Failed to parse input json");
|
||||
let method = map["method"].as_str().unwrap();
|
||||
let event: CGWEvent = cgw_parse_jrpc_event(&map, method.to_string());
|
||||
|
||||
match event {
|
||||
CGWEvent::Connect(_) => {
|
||||
assert!(true);
|
||||
}
|
||||
_ => {
|
||||
assert!(false, "Expected event to be of <Connect> type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_parse_log_event() {
|
||||
let msg = get_log_json_msg();
|
||||
|
||||
let map: Map<String, Value> =
|
||||
serde_json::from_str(msg).expect("Failed to parse input json");
|
||||
let method = map["method"].as_str().unwrap();
|
||||
let event: CGWEvent = cgw_parse_jrpc_event(&map, method.to_string());
|
||||
|
||||
match event {
|
||||
CGWEvent::Log(_) => {
|
||||
assert!(true);
|
||||
}
|
||||
_ => {
|
||||
assert!(false, "Expected event to be of <Log> type");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
205
src/cgw_db_accessor.rs
Normal file
205
src/cgw_db_accessor.rs
Normal file
@@ -0,0 +1,205 @@
|
||||
use crate::{
|
||||
AppArgs,
|
||||
};
|
||||
|
||||
use eui48::{
|
||||
MacAddress,
|
||||
};
|
||||
|
||||
use tokio_postgres::{
|
||||
Client,
|
||||
NoTls,
|
||||
row::{
|
||||
Row,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CGWDBInfra {
|
||||
pub mac: MacAddress,
|
||||
pub infra_group_id: i32,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CGWDBInfrastructureGroup {
|
||||
pub id: i32,
|
||||
pub reserved_size: i32,
|
||||
pub actual_size: i32,
|
||||
}
|
||||
|
||||
impl From<Row> for CGWDBInfra {
|
||||
fn from(row: Row) -> Self {
|
||||
let serial: MacAddress = row.get("mac");
|
||||
let gid: i32 = row.get("infra_group_id");
|
||||
Self {
|
||||
mac: serial,
|
||||
infra_group_id: gid,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Row> for CGWDBInfrastructureGroup {
|
||||
fn from(row: Row) -> Self {
|
||||
let infra_id: i32 = row.get("id");
|
||||
let res_size: i32 = row.get("reserved_size");
|
||||
let act_size: i32 = row.get("actual_size");
|
||||
Self {
|
||||
id: infra_id,
|
||||
reserved_size: res_size,
|
||||
actual_size: act_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CGWDBAccessor {
|
||||
cl: Client,
|
||||
}
|
||||
|
||||
impl CGWDBAccessor {
|
||||
pub async fn new(app_args: &AppArgs) -> Self {
|
||||
let conn_str = format!("host={host} port={port} user={user} dbname={db} password={pass}",
|
||||
host = app_args.db_ip,
|
||||
port = app_args.db_port,
|
||||
user = app_args.db_username,
|
||||
db = app_args.db_name,
|
||||
pass = app_args.db_password);
|
||||
debug!("Trying to connect to remote db ({}:{})...",
|
||||
app_args.db_ip.to_string(),
|
||||
app_args.db_port.to_string());
|
||||
debug!("Conn args {conn_str}");
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(&conn_str, NoTls).await.unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
info!("Connected to remote DB");
|
||||
|
||||
CGWDBAccessor {
|
||||
cl: client,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* INFRA_GROUP db API uses the following table decl
|
||||
* TODO: id = int, not varchar; requires kafka simulator changes
|
||||
CREATE TABLE infrastructure_groups (
|
||||
id VARCHAR (50) PRIMARY KEY,
|
||||
reserved_size INT,
|
||||
actual_size INT
|
||||
);
|
||||
*
|
||||
*/
|
||||
|
||||
pub async fn insert_new_infra_group(&self, g: &CGWDBInfrastructureGroup) -> Result<(), &'static str> {
|
||||
let q = self.cl.prepare("INSERT INTO infrastructure_groups (id, reserved_size, actual_size) VALUES ($1, $2, $3)").await.unwrap();
|
||||
let res = self.cl.execute(&q, &[&g.id, &g.reserved_size, &g.actual_size]).await;
|
||||
|
||||
match res {
|
||||
Ok(n) => return Ok(()),
|
||||
Err(e) => {
|
||||
error!("Failed to insert a new infra group {}: {:?}", g.id, e.to_string());
|
||||
return Err("Insert new infra group failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_infra_group(&self, gid: i32) -> Result<(), &'static str> {
|
||||
// TODO: query-base approach instead of static string
|
||||
let req = self.cl.prepare("DELETE FROM infrastructure_groups WHERE id = $1").await.unwrap();
|
||||
let res = self.cl.execute(&req, &[&gid]).await;
|
||||
|
||||
match res {
|
||||
Ok(n) => {
|
||||
if n > 0 {
|
||||
return Ok(());
|
||||
} else {
|
||||
return Err("Failed to delete group from DB: gid does not exist");
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
error!("Failed to delete an infra group {gid}: {:?}", e.to_string());
|
||||
return Err("Delete infra group failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_all_infra_groups(&self) -> Option<Vec<CGWDBInfrastructureGroup>> {
|
||||
let mut list: Vec<CGWDBInfrastructureGroup> = Vec::with_capacity(1000);
|
||||
|
||||
let res = self.cl.query("SELECT * from infrastructure_groups", &[]).await;
|
||||
|
||||
match res {
|
||||
Ok(r) => {
|
||||
for x in r {
|
||||
let infra_group = CGWDBInfrastructureGroup::from(x);
|
||||
list.push(infra_group);
|
||||
}
|
||||
return Some(list);
|
||||
}
|
||||
Err(e) => {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_infra_group(&self, gid: i32) -> Option<CGWDBInfrastructureGroup> {
|
||||
let q = self.cl.prepare("SELECT * from infrastructure_groups WHERE id = $1").await.unwrap();
|
||||
let row = self.cl.query_one(&q, &[&gid]).await;
|
||||
|
||||
match row {
|
||||
Ok(r) => return Some(CGWDBInfrastructureGroup::from(r)),
|
||||
Err(e) => {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* INFRA db API uses the following table decl
|
||||
CREATE TABLE infras (
|
||||
mac MACADDR PRIMARY KEY,
|
||||
infra_group_id INT,
|
||||
FOREIGN KEY(infra_group_id) REFERENCES infrastructure_groups(id) ON DELETE CASCADE
|
||||
);
|
||||
*/
|
||||
|
||||
pub async fn insert_new_infra(&self, infra: &CGWDBInfra) -> Result<(), &'static str> {
|
||||
let q = self.cl.prepare("INSERT INTO infras (mac, infra_group_id) VALUES ($1, $2)").await.unwrap();
|
||||
let res = self.cl.execute(
|
||||
&q,
|
||||
&[&infra.mac, &infra.infra_group_id]).await;
|
||||
|
||||
match res {
|
||||
Ok(n) => return Ok(()),
|
||||
Err(e) => {
|
||||
error!("Failed to insert a new infra: {:?}", e.to_string());
|
||||
return Err("Insert new infra failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_infra(&self, serial: MacAddress) -> Result<(), &'static str> {
|
||||
let q = self.cl.prepare("DELETE FROM infras WHERE mac = $1").await.unwrap();
|
||||
let res = self.cl.execute(
|
||||
&q,
|
||||
&[&serial]).await;
|
||||
|
||||
match res {
|
||||
Ok(n) => {
|
||||
if n > 0 {
|
||||
return Ok(());
|
||||
} else {
|
||||
return Err("Failed to delete infra from DB: MAC does not exist");
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
error!("Failed to delete infra: {:?}", e.to_string());
|
||||
return Err("Delete infra failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
260
src/cgw_nb_api_listener.rs
Normal file
260
src/cgw_nb_api_listener.rs
Normal file
@@ -0,0 +1,260 @@
|
||||
use crate::{
|
||||
AppArgs,
|
||||
};
|
||||
|
||||
use crate::cgw_connection_server::{
|
||||
CGWConnectionNBAPIReqMsg,
|
||||
CGWConnectionNBAPIReqMsgOrigin,
|
||||
};
|
||||
|
||||
use std::{
|
||||
sync::{
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use tokio::{
|
||||
sync::{
|
||||
mpsc::{
|
||||
UnboundedSender,
|
||||
},
|
||||
},
|
||||
time::{
|
||||
Duration,
|
||||
},
|
||||
runtime::{
|
||||
Builder,
|
||||
Runtime,
|
||||
},
|
||||
};
|
||||
use futures::stream::{TryStreamExt};
|
||||
use rdkafka::client::ClientContext;
|
||||
use rdkafka::config::{ClientConfig, RDKafkaLogLevel};
|
||||
use rdkafka::{
|
||||
consumer::{
|
||||
Consumer,
|
||||
ConsumerContext,
|
||||
Rebalance,
|
||||
stream_consumer::{
|
||||
StreamConsumer,
|
||||
},
|
||||
},
|
||||
producer::{
|
||||
FutureProducer,
|
||||
FutureRecord,
|
||||
},
|
||||
};
|
||||
use rdkafka::error::KafkaResult;
|
||||
use rdkafka::message::{Message};
|
||||
use rdkafka::topic_partition_list::TopicPartitionList;
|
||||
|
||||
type CGWConnectionServerMboxTx = UnboundedSender<CGWConnectionNBAPIReqMsg>;
|
||||
type CGWCNCConsumerType = StreamConsumer<CustomContext>;
|
||||
type CGWCNCProducerType = FutureProducer;
|
||||
|
||||
struct CustomContext;
|
||||
impl ClientContext for CustomContext {}
|
||||
|
||||
impl ConsumerContext for CustomContext {
|
||||
fn pre_rebalance(&self, rebalance: &Rebalance) {
|
||||
let mut part_list = String::new();
|
||||
if let rdkafka::consumer::Rebalance::Assign(partitions) = rebalance {
|
||||
for x in partitions.elements() {
|
||||
part_list += &(x.partition().to_string() + " ");
|
||||
}
|
||||
debug!("pre_rebalance callback, assigned partition(s): {part_list}");
|
||||
}
|
||||
|
||||
part_list.clear();
|
||||
|
||||
if let rdkafka::consumer::Rebalance::Revoke(partitions) = rebalance {
|
||||
for x in partitions.elements() {
|
||||
part_list += &(x.partition().to_string() + " ");
|
||||
}
|
||||
debug!("pre_rebalance callback, revoked partition(s): {part_list}");
|
||||
}
|
||||
}
|
||||
|
||||
fn post_rebalance(&self, rebalance: &Rebalance) {
|
||||
let mut part_list = String::new();
|
||||
|
||||
if let rdkafka::consumer::Rebalance::Assign(partitions) = rebalance {
|
||||
for x in partitions.elements() {
|
||||
part_list += &(x.partition().to_string() + " ");
|
||||
}
|
||||
debug!("post_rebalance callback, assigned partition(s): {part_list}");
|
||||
}
|
||||
|
||||
part_list.clear();
|
||||
|
||||
if let rdkafka::consumer::Rebalance::Revoke(partitions) = rebalance {
|
||||
for x in partitions.elements() {
|
||||
part_list += &(x.partition().to_string() + " ");
|
||||
}
|
||||
debug!("post_rebalance callback, revoked partition(s): {part_list}");
|
||||
}
|
||||
}
|
||||
|
||||
fn commit_callback(&self, result: KafkaResult<()>, _offsets: &TopicPartitionList) {
|
||||
let mut part_list = String::new();
|
||||
for x in _offsets.elements() {
|
||||
part_list += &(x.partition().to_string() + " ");
|
||||
}
|
||||
debug!("commit_callback callback, partition(s): {part_list}");
|
||||
debug!("Consumer callback: commited offset");
|
||||
}
|
||||
}
|
||||
|
||||
static GROUP_ID: &'static str = "CGW";
|
||||
const CONSUMER_TOPICS: [&'static str; 1] = ["CnC"];
|
||||
const PRODUCER_TOPICS: &'static str = "CnC_Res";
|
||||
|
||||
struct CGWCNCProducer {
|
||||
p: CGWCNCProducerType,
|
||||
}
|
||||
|
||||
struct CGWCNCConsumer {
|
||||
c: CGWCNCConsumerType,
|
||||
}
|
||||
|
||||
impl CGWCNCConsumer {
|
||||
pub fn new(app_args: &AppArgs) -> Self {
|
||||
let consum: CGWCNCConsumerType = Self::create_consumer(app_args);
|
||||
CGWCNCConsumer {
|
||||
c: consum,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_consumer(app_args: &AppArgs) -> CGWCNCConsumerType {
|
||||
let context = CustomContext;
|
||||
|
||||
debug!("Trying to connect to kafka broker ({}:{})...",
|
||||
app_args.kafka_ip.to_string(),
|
||||
app_args.kafka_port.to_string());
|
||||
|
||||
let consumer: CGWCNCConsumerType = ClientConfig::new()
|
||||
.set("group.id", GROUP_ID)
|
||||
.set("client.id", GROUP_ID.to_string() + &app_args.cgw_id.to_string())
|
||||
.set("group.instance.id", app_args.cgw_id.to_string())
|
||||
.set("bootstrap.servers", app_args.kafka_ip.to_string() + ":" + &app_args.kafka_port.to_string())
|
||||
.set("enable.partition.eof", "false")
|
||||
.set("session.timeout.ms", "6000")
|
||||
.set("enable.auto.commit", "true")
|
||||
//.set("statistics.interval.ms", "30000")
|
||||
//.set("auto.offset.reset", "smallest")
|
||||
.set_log_level(RDKafkaLogLevel::Debug)
|
||||
.create_with_context(context)
|
||||
.expect("Consumer creation failed");
|
||||
|
||||
consumer
|
||||
.subscribe(&CONSUMER_TOPICS)
|
||||
.expect("Failed to subscribe to {CONSUMER_TOPICS} topics");
|
||||
|
||||
info!("Connected to kafka broker");
|
||||
|
||||
consumer
|
||||
}
|
||||
}
|
||||
|
||||
impl CGWCNCProducer {
|
||||
pub fn new() -> Self {
|
||||
let prod: CGWCNCProducerType = Self::create_producer();
|
||||
CGWCNCProducer {
|
||||
p: prod,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_producer() -> CGWCNCProducerType {
|
||||
let producer: FutureProducer = ClientConfig::new()
|
||||
.set("bootstrap.servers", "172.20.10.136:9092")
|
||||
.set("message.timeout.ms", "5000")
|
||||
.create()
|
||||
.expect("Producer creation error");
|
||||
|
||||
producer
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CGWNBApiClient {
|
||||
working_runtime_handle: Runtime,
|
||||
cgw_server_tx_mbox: CGWConnectionServerMboxTx,
|
||||
prod: CGWCNCProducer,
|
||||
// TBD: stplit different implementators through a defined trait,
|
||||
// that implements async R W operations?
|
||||
}
|
||||
|
||||
impl CGWNBApiClient {
|
||||
pub fn new(app_args: &AppArgs, cgw_tx: &CGWConnectionServerMboxTx) -> Arc<Self> {
|
||||
let working_runtime_h = Builder::new_multi_thread()
|
||||
.worker_threads(1)
|
||||
.thread_name("cgw-nb-api-l")
|
||||
.thread_stack_size(1 * 1024 * 1024)
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
let cl = Arc::new(CGWNBApiClient {
|
||||
working_runtime_handle: working_runtime_h,
|
||||
cgw_server_tx_mbox: cgw_tx.clone(),
|
||||
prod: CGWCNCProducer::new(),
|
||||
});
|
||||
|
||||
let cl_clone = cl.clone();
|
||||
let consumer: CGWCNCConsumer = CGWCNCConsumer::new(app_args);
|
||||
cl.working_runtime_handle.spawn(async move {
|
||||
loop {
|
||||
let cl_clone = cl_clone.clone();
|
||||
let stream_processor = consumer.c.stream().try_for_each(|borrowed_message| {
|
||||
let cl_clone = cl_clone.clone();
|
||||
async move {
|
||||
// Process each message
|
||||
// Borrowed messages can't outlive the consumer they are received from, so they need to
|
||||
// be owned in order to be sent to a separate thread.
|
||||
//record_owned_message_receipt(&owned_message).await;
|
||||
let owned = borrowed_message.detach();
|
||||
|
||||
let key = match owned.key_view::<str>() {
|
||||
None => "",
|
||||
Some(Ok(s)) => s,
|
||||
Some(Err(e)) => {
|
||||
warn!("Error while deserializing message payload: {:?}", e);
|
||||
""
|
||||
}
|
||||
};
|
||||
|
||||
let payload = match owned.payload_view::<str>() {
|
||||
None => "",
|
||||
Some(Ok(s)) => s,
|
||||
Some(Err(e)) => {
|
||||
warn!("Error while deserializing message payload: {:?}", e);
|
||||
""
|
||||
}
|
||||
};
|
||||
cl_clone.enqueue_mbox_message_to_cgw_server(key.to_string(), payload.to_string()).await;
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
stream_processor.await.expect("stream processing failed");
|
||||
}
|
||||
});
|
||||
|
||||
cl
|
||||
}
|
||||
|
||||
pub async fn enqueue_mbox_message_from_cgw_server(&self, key: String, payload: String) {
|
||||
let produce_future = self.prod.p.send(
|
||||
FutureRecord::to(&PRODUCER_TOPICS)
|
||||
.key(&key)
|
||||
.payload(&payload),
|
||||
Duration::from_secs(0),
|
||||
);
|
||||
match produce_future.await {
|
||||
Err((e, _)) => println!("Error: {:?}", e),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
async fn enqueue_mbox_message_to_cgw_server(&self, key: String, payload: String) {
|
||||
debug!("MBOX_OUT: EnqueueNewMessageFromNBAPIListener, k:{key}");
|
||||
let msg = CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, CGWConnectionNBAPIReqMsgOrigin::FromNBAPI);
|
||||
let _ = self.cgw_server_tx_mbox.send(msg);
|
||||
}
|
||||
}
|
||||
75
src/cgw_remote_client.rs
Normal file
75
src/cgw_remote_client.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
use crate::{
|
||||
AppArgs,
|
||||
};
|
||||
|
||||
pub mod cgw_remote {
|
||||
tonic::include_proto!("cgw.remote");
|
||||
}
|
||||
|
||||
use tonic::{
|
||||
transport::{
|
||||
Uri,
|
||||
channel::{
|
||||
Channel,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
use cgw_remote::{
|
||||
EnqueueRequest,
|
||||
remote_client::{
|
||||
RemoteClient,
|
||||
},
|
||||
};
|
||||
|
||||
use tokio::{
|
||||
time::{
|
||||
Duration,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CGWRemoteClient {
|
||||
remote_client: RemoteClient<Channel>,
|
||||
}
|
||||
|
||||
impl CGWRemoteClient {
|
||||
pub fn new(hostname: String) -> Self {
|
||||
let uri = Uri::from_maybe_shared(hostname).unwrap();
|
||||
let r_channel = Channel::builder(uri)
|
||||
.timeout(Duration::from_secs(20))
|
||||
.connect_timeout(Duration::from_secs(20))
|
||||
.connect_lazy();
|
||||
|
||||
let client = RemoteClient::new(r_channel);
|
||||
|
||||
CGWRemoteClient {
|
||||
remote_client: client,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn relay_request_stream(&self, stream: Vec<(String, String)>) -> Result<(), ()> {
|
||||
let mut cl_clone = self.remote_client.clone();
|
||||
|
||||
let mut messages: Vec<EnqueueRequest> = vec![];
|
||||
let mut it = stream.into_iter();
|
||||
|
||||
while let Some(x) = it.next() {
|
||||
messages.push(EnqueueRequest {
|
||||
key: x.0,
|
||||
req: x.1,
|
||||
});
|
||||
}
|
||||
|
||||
let rq = tonic::Request::new(tokio_stream::iter(messages.clone()));
|
||||
match cl_clone.enqueue_nbapi_request_stream(rq).await {
|
||||
Err(e) => {
|
||||
error!("Failed to relay req: {:?}", e);
|
||||
Err(())
|
||||
}
|
||||
Ok(r) => {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
559
src/cgw_remote_discovery.rs
Normal file
559
src/cgw_remote_discovery.rs
Normal file
@@ -0,0 +1,559 @@
|
||||
use crate::{
|
||||
AppArgs,
|
||||
cgw_remote_client::{
|
||||
CGWRemoteClient,
|
||||
},
|
||||
cgw_db_accessor::{
|
||||
CGWDBAccessor,
|
||||
CGWDBInfrastructureGroup,
|
||||
CGWDBInfra,
|
||||
},
|
||||
};
|
||||
|
||||
use std::{
|
||||
collections::{
|
||||
HashMap,
|
||||
},
|
||||
net::{
|
||||
Ipv4Addr,
|
||||
IpAddr,
|
||||
SocketAddr,
|
||||
},
|
||||
sync::{
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use redis_async::{
|
||||
resp_array,
|
||||
};
|
||||
|
||||
use eui48::{
|
||||
MacAddress,
|
||||
};
|
||||
|
||||
use tokio::{
|
||||
sync::{
|
||||
RwLock,
|
||||
}
|
||||
};
|
||||
|
||||
// Used in remote lookup
|
||||
static REDIS_KEY_SHARD_ID_PREFIX:&'static str = "shard_id_";
|
||||
static REDIS_KEY_SHARD_ID_FIELDS_NUM: usize = 12;
|
||||
static REDIS_KEY_SHARD_VALUE_ASSIGNED_G_NUM:&'static str = "assigned_groups_num";
|
||||
|
||||
// Used in group assign / reassign
|
||||
static REDIS_KEY_GID:&'static str = "group_id_";
|
||||
static REDIS_KEY_GID_VALUE_GID:&'static str = "gid";
|
||||
static REDIS_KEY_GID_VALUE_SHARD_ID:&'static str = "shard_id";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CGWREDISDBShard {
|
||||
id: i32,
|
||||
server_ip: IpAddr,
|
||||
server_port: u16,
|
||||
assigned_groups_num: i32,
|
||||
capacity: i32,
|
||||
threshold: i32,
|
||||
}
|
||||
|
||||
impl From<Vec<String>> for CGWREDISDBShard {
|
||||
fn from(values: Vec<String>) -> Self {
|
||||
assert!(values.len() >= REDIS_KEY_SHARD_ID_FIELDS_NUM,
|
||||
"Unexpected size of parsed vector: at least {REDIS_KEY_SHARD_ID_FIELDS_NUM} expected");
|
||||
assert!(values[0] == "id", "redis.res[0] != id, unexpected.");
|
||||
assert!(values[2] == "server_ip", "redis.res[2] != server_ip, unexpected.");
|
||||
assert!(values[4] == "server_port", "redis.res[4] != server_port, unexpected.");
|
||||
assert!(values[6] == "assigned_groups_num", "redis.res[6] != assigned_groups_num, unexpected.");
|
||||
assert!(values[8] == "capacity", "redis.res[8] != capacity, unexpected.");
|
||||
assert!(values[10] == "threshold", "redis.res[10] != threshold, unexpected.");
|
||||
|
||||
CGWREDISDBShard {
|
||||
id: values[1].parse::<i32>().unwrap(),
|
||||
server_ip: values[3].parse::<IpAddr>().unwrap(),
|
||||
server_port: values[5].parse::<u16>().unwrap(),
|
||||
assigned_groups_num: values[7].parse::<i32>().unwrap(),
|
||||
capacity: values[9].parse::<i32>().unwrap(),
|
||||
threshold: values[11].parse::<i32>().unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<Vec<String>> for CGWREDISDBShard {
|
||||
fn into(self) -> Vec<String> {
|
||||
vec!["id".to_string(), self.id.to_string(),
|
||||
"server_ip".to_string(), self.server_ip.to_string(),
|
||||
"server_port".to_string(), self.server_port.to_string(),
|
||||
"assigned_groups_num".to_string(), self.assigned_groups_num.to_string(),
|
||||
"capacity".to_string(), self.capacity.to_string(),
|
||||
"threshold".to_string(), self.threshold.to_string()]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CGWRemoteConfig {
|
||||
pub remote_id: i32,
|
||||
pub server_ip: Ipv4Addr,
|
||||
pub server_port: u16,
|
||||
}
|
||||
|
||||
impl CGWRemoteConfig {
|
||||
pub fn new(id: i32, ip_conf: Ipv4Addr, port: u16) -> Self {
|
||||
CGWRemoteConfig {
|
||||
remote_id: id,
|
||||
server_ip: ip_conf,
|
||||
server_port: port,
|
||||
}
|
||||
}
|
||||
pub fn to_socket_addr(&self) -> SocketAddr {
|
||||
SocketAddr::new(std::net::IpAddr::V4(self.server_ip), self.server_port)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CGWRemoteIface {
|
||||
pub shard: CGWREDISDBShard,
|
||||
client: CGWRemoteClient,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CGWRemoteDiscovery {
|
||||
db_accessor: Arc<CGWDBAccessor>,
|
||||
redis_client: redis_async::client::paired::PairedConnection,
|
||||
gid_to_cgw_cache: Arc<RwLock::<HashMap<i32, i32>>>,
|
||||
remote_cgws_map: Arc<RwLock::<HashMap<i32, CGWRemoteIface>>>,
|
||||
local_shard_id: i32,
|
||||
}
|
||||
|
||||
impl CGWRemoteDiscovery {
|
||||
pub async fn new(app_args: &AppArgs) -> Self {
|
||||
let rc = CGWRemoteDiscovery {
|
||||
db_accessor: Arc::new(CGWDBAccessor::new(app_args).await),
|
||||
redis_client: redis_async::client::paired::paired_connect(
|
||||
app_args.redis_db_ip.to_string(),
|
||||
app_args.redis_db_port).await.unwrap(),
|
||||
gid_to_cgw_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
local_shard_id: app_args.cgw_id,
|
||||
remote_cgws_map: Arc::new(RwLock::new(HashMap::new())),
|
||||
};
|
||||
|
||||
let _ = rc.sync_gid_to_cgw_map().await;
|
||||
let _ = rc.sync_remote_cgw_map().await;
|
||||
|
||||
if let None = rc.remote_cgws_map.read().await.get(&rc.local_shard_id) {
|
||||
let redisdb_shard_info = CGWREDISDBShard {
|
||||
id: app_args.cgw_id,
|
||||
server_ip: std::net::IpAddr::V4(app_args.grpc_ip.clone()),
|
||||
server_port: u16::try_from(app_args.grpc_port).unwrap(),
|
||||
assigned_groups_num: 0i32,
|
||||
capacity: 1000i32,
|
||||
threshold: 50i32,
|
||||
};
|
||||
let redis_req_data: Vec<String> = redisdb_shard_info.into();
|
||||
|
||||
let _ = rc.redis_client.send::<i32>(
|
||||
resp_array!["DEL",
|
||||
format!("{REDIS_KEY_SHARD_ID_PREFIX}{}", app_args.cgw_id)]).await;
|
||||
|
||||
if let Err(e) = rc.redis_client.send::<String>(
|
||||
resp_array!["HSET", format!("{REDIS_KEY_SHARD_ID_PREFIX}{}", app_args.cgw_id)]
|
||||
.append(redis_req_data)).await {
|
||||
panic!("Failed to create record about shard in REDIS, e:{e}");
|
||||
}
|
||||
|
||||
let _ = rc.sync_remote_cgw_map().await;
|
||||
}
|
||||
|
||||
debug!("Found {} remote CGWs:", rc.remote_cgws_map.read().await.len() - 1);
|
||||
|
||||
for (key, val) in rc.remote_cgws_map.read().await.iter() {
|
||||
if val.shard.id == rc.local_shard_id {
|
||||
continue;
|
||||
}
|
||||
debug!("Shard #{}, IP {}:{}", val.shard.id, val.shard.server_ip, val.shard.server_port);
|
||||
}
|
||||
|
||||
rc
|
||||
}
|
||||
|
||||
pub async fn sync_gid_to_cgw_map(&self) {
|
||||
let mut lock = self.gid_to_cgw_cache.write().await;
|
||||
|
||||
// Clear hashmap
|
||||
lock.clear();
|
||||
|
||||
let redis_keys: Vec<String> =
|
||||
match self.redis_client.send::<Vec<String>>(
|
||||
resp_array!["KEYS", format!("{}*", REDIS_KEY_GID)]).await {
|
||||
Err(e) => {
|
||||
panic!("Failed to get KEYS list from REDIS, e:{e}");
|
||||
},
|
||||
Ok(r) => r
|
||||
};
|
||||
|
||||
for key in redis_keys {
|
||||
let gid : i32 = match self.redis_client.send::<String>(
|
||||
resp_array!["HGET", &key, REDIS_KEY_GID_VALUE_GID]).await {
|
||||
Ok(res) => {
|
||||
match res.parse::<i32>() {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
warn!("Found proper key '{key}' entry, but failed to parse GID from it:\n{e}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("Found proper key '{key}' entry, but failed to fetch GID from it:\n{e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let shard_id : i32 = match self.redis_client.send::<String>(
|
||||
resp_array!["HGET", &key, REDIS_KEY_GID_VALUE_SHARD_ID]).await {
|
||||
Ok(res) => {
|
||||
match res.parse::<i32>() {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
warn!("Found proper key '{key}' entry, but failed to parse SHARD_ID from it:\n{e}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("Found proper key '{key}' entry, but failed to fetch SHARD_ID from it:\n{e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Found group {key}, gid: {gid}, shard_id: {shard_id}");
|
||||
|
||||
match lock.insert(gid, shard_id) {
|
||||
None => continue,
|
||||
Some(v) => warn!("Populated gid_to_cgw_map with previous value being alerady set, unexpected")
|
||||
}
|
||||
}
|
||||
debug!("Found total {} groups with their respective owners", lock.len());
|
||||
}
|
||||
|
||||
async fn sync_remote_cgw_map(&self) -> Result<(), &'static str> {
|
||||
let mut lock = self.remote_cgws_map.write().await;
|
||||
|
||||
// Clear hashmap
|
||||
lock.clear();
|
||||
|
||||
let redis_keys: Vec<String> =
|
||||
match self.redis_client.send::<Vec<String>>(
|
||||
resp_array!["KEYS", format!("{}*", REDIS_KEY_SHARD_ID_PREFIX)]).await {
|
||||
Err(e) => {
|
||||
warn!("Failed to get cgw shard KEYS list from REDIS, e:{e}");
|
||||
return Err("Remote CGW Shards list fetch from REDIS failed");
|
||||
},
|
||||
Ok(r) => r
|
||||
};
|
||||
|
||||
for key in redis_keys {
|
||||
match self.redis_client.send::<Vec<String>>(
|
||||
resp_array!["HGETALL", &key]).await {
|
||||
Ok(res) => {
|
||||
let shrd: CGWREDISDBShard = match CGWREDISDBShard::try_from(res) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse CGWREDISDBShard, {key}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let endpoint_str =
|
||||
String::from("http://") +
|
||||
&shrd.server_ip.to_string() +
|
||||
":" +
|
||||
&shrd.server_port.to_string();
|
||||
let cgw_iface = CGWRemoteIface {
|
||||
shard: shrd,
|
||||
client: CGWRemoteClient::new(endpoint_str),
|
||||
};
|
||||
lock.insert(cgw_iface.shard.id, cgw_iface);
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("Found proper key '{key}' entry, but failed to fetch Shard info from it:\n{e}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_infra_group_owner_id(&self, gid: i32) -> Option<i32> {
|
||||
// try to use internal cache first
|
||||
if let Some(id) = self.gid_to_cgw_cache.read().await.get(&gid) {
|
||||
return Some(*id);
|
||||
}
|
||||
|
||||
// then try to use redis
|
||||
self.sync_gid_to_cgw_map().await;
|
||||
|
||||
if let Some(id) = self.gid_to_cgw_cache.read().await.get(&gid) {
|
||||
return Some(*id);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
async fn increment_cgw_assigned_groups_num(&self, id: i32) -> Result<(), &'static str> {
|
||||
debug!("inc {id}");
|
||||
/*
|
||||
if let Err(e) = self.redis_client.send::<i32>(
|
||||
*/
|
||||
match self.redis_client.send::<i32>(
|
||||
resp_array!["HINCRBY",
|
||||
format!("{}{id}", REDIS_KEY_SHARD_ID_PREFIX),
|
||||
REDIS_KEY_SHARD_VALUE_ASSIGNED_G_NUM,
|
||||
"1"]).await {
|
||||
Ok(v) => {
|
||||
debug!("ret {:?}", v);
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("Failed to increment CGW{id} assigned group num count, e:{e}");
|
||||
return Err("Failed to increment assigned group num count");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn decrement_cgw_assigned_groups_num(&self, id: i32) -> Result<(), &'static str> {
|
||||
if let Err(e) = self.redis_client.send::<i32>(
|
||||
resp_array!["HINCRBY",
|
||||
format!("{}{id}", REDIS_KEY_SHARD_ID_PREFIX),
|
||||
REDIS_KEY_SHARD_VALUE_ASSIGNED_G_NUM,
|
||||
"-1"]).await {
|
||||
warn!("Failed to decrement CGW{id} assigned group num count, e:{e}");
|
||||
return Err("Failed to decrement assigned group num count");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_infra_group_cgw_assignee(&self) -> Result<i32, &'static str> {
|
||||
let lock = self.remote_cgws_map.read().await;
|
||||
let mut hash_vec: Vec<(&i32, &CGWRemoteIface)> = lock.iter().collect();
|
||||
|
||||
hash_vec.sort_by(|a, b| b.1.shard.assigned_groups_num.cmp(&a.1.shard.assigned_groups_num));
|
||||
|
||||
for x in hash_vec {
|
||||
debug!("id_{} capacity {} t {} assigned {}",
|
||||
x.1.shard.id, x.1.shard.capacity,
|
||||
x.1.shard.threshold,
|
||||
x.1.shard.assigned_groups_num);
|
||||
let max_capacity: i32 = x.1.shard.capacity + x.1.shard.threshold;
|
||||
if x.1.shard.assigned_groups_num + 1 <= max_capacity {
|
||||
debug!("Found CGW shard to assign group to (id {})", x.1.shard.id);
|
||||
return Ok(x.1.shard.id);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
warn!("Every available CGW is exceeding capacity+threshold limit, using least loaded one...");
|
||||
if let Some(least_loaded_cgw) = lock.iter()
|
||||
.min_by(|a, b| a.1.shard.assigned_groups_num.cmp(&b.1.shard.assigned_groups_num))
|
||||
.map(|(k, _v)| _v) {
|
||||
warn!("Found least loaded CGW id: {}", least_loaded_cgw.shard.id);
|
||||
return Ok(least_loaded_cgw.shard.id);
|
||||
}
|
||||
|
||||
return Err("Unexpected: Failed to find the least loaded CGW shard");
|
||||
}
|
||||
|
||||
async fn assign_infra_group_to_cgw(&self, gid: i32) -> Result<i32, &'static str> {
|
||||
// Delete key (if exists), recreate with new owner
|
||||
let _ = self.deassign_infra_group_to_cgw(gid).await;
|
||||
|
||||
let dst_cgw_id: i32 = match self.get_infra_group_cgw_assignee().await {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
warn!("Failed to assign {gid} to any shard, reason:{e}");
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = self.redis_client.send::<String>(
|
||||
resp_array!["HSET",
|
||||
format!("{REDIS_KEY_GID}{gid}"),
|
||||
REDIS_KEY_GID_VALUE_GID,
|
||||
gid.to_string(),
|
||||
REDIS_KEY_GID_VALUE_SHARD_ID,
|
||||
dst_cgw_id.to_string()]).await {
|
||||
error!("Failed to update REDIS gid{gid} owner to shard{dst_cgw_id}, e:{e}");
|
||||
return Err("Hot-cache (REDIS DB) update owner failed");
|
||||
}
|
||||
|
||||
let mut lock = self.gid_to_cgw_cache.write().await;
|
||||
lock.insert(gid, dst_cgw_id);
|
||||
|
||||
debug!("REDIS: assigned gid{gid} to shard{dst_cgw_id}");
|
||||
|
||||
Ok(dst_cgw_id)
|
||||
}
|
||||
|
||||
pub async fn deassign_infra_group_to_cgw(&self, gid: i32) -> Result<(), &'static str> {
|
||||
if let Err(e) = self.redis_client.send::<i64>(
|
||||
resp_array!["DEL", format!("{REDIS_KEY_GID}{gid}")]).await {
|
||||
error!("Failed to deassigned REDIS gid{gid} owner, e:{e}");
|
||||
return Err("Hot-cache (REDIS DB) deassign owner failed");
|
||||
}
|
||||
|
||||
debug!("REDIS: deassigned gid{gid} from controlled CGW");
|
||||
|
||||
let mut lock = self.gid_to_cgw_cache.write().await;
|
||||
lock.remove(&gid);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_infra_group(&self, g: &CGWDBInfrastructureGroup) -> Result<i32, &'static str> {
|
||||
//TODO: transaction-based insert/assigned_group_num update (DB)
|
||||
let rc = self.db_accessor.insert_new_infra_group(g).await;
|
||||
if let Err(e) = rc {
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
let shard_id: i32 = match self.assign_infra_group_to_cgw(g.id).await {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
let _ = self.db_accessor.delete_infra_group(g.id).await;
|
||||
return Err("Assign group to CGW shard failed");
|
||||
}
|
||||
};
|
||||
|
||||
let rc = self.increment_cgw_assigned_groups_num(shard_id).await;
|
||||
if let Err(e) = rc {
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
Ok(shard_id)
|
||||
}
|
||||
|
||||
pub async fn destroy_infra_group(&self, gid: i32) -> Result<(), &'static str> {
|
||||
let cgw_id: Option<i32> = self.get_infra_group_owner_id(gid).await;
|
||||
if let Some(id) = cgw_id {
|
||||
let _ = self.deassign_infra_group_to_cgw(gid).await;
|
||||
let _ = self.decrement_cgw_assigned_groups_num(id).await;
|
||||
}
|
||||
|
||||
//TODO: transaction-based insert/assigned_group_num update (DB)
|
||||
let rc = self.db_accessor.delete_infra_group(gid).await;
|
||||
if let Err(e) = rc {
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_ifras_list(&self, gid: i32, infras: Vec<String>) -> Result<(), Vec<String>> {
|
||||
// TODO: assign list to shards; currently - only created bulk, no assignment
|
||||
let mut futures = Vec::with_capacity(infras.len());
|
||||
// Results store vec of MACs we failed to add
|
||||
let mut failed_infras: Vec<String> = Vec::with_capacity(futures.len());
|
||||
for x in infras.iter() {
|
||||
let db_accessor_clone = self.db_accessor.clone();
|
||||
let infra = CGWDBInfra {
|
||||
mac: MacAddress::parse_str(&x).unwrap(),
|
||||
infra_group_id: gid,
|
||||
};
|
||||
futures.push(
|
||||
tokio::spawn(
|
||||
async move {
|
||||
if let Err(_) = db_accessor_clone.insert_new_infra(&infra).await {
|
||||
Err(infra.mac.to_string(eui48::MacAddressFormat::HexString))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for (i, future) in futures.iter_mut().enumerate() {
|
||||
match future.await {
|
||||
Ok(res) => {
|
||||
if let Err(mac) = res {
|
||||
failed_infras.push(mac);
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
failed_infras.push(infras[i].clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if failed_infras.len() > 0 {
|
||||
return Err(failed_infras);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn destroy_ifras_list(&self, gid: i32, infras: Vec<String>) -> Result<(), Vec<String>> {
|
||||
let mut futures = Vec::with_capacity(infras.len());
|
||||
// Results store vec of MACs we failed to add
|
||||
let mut failed_infras: Vec<String> = Vec::with_capacity(futures.len());
|
||||
for x in infras.iter() {
|
||||
let db_accessor_clone = self.db_accessor.clone();
|
||||
let mac = MacAddress::parse_str(&x).unwrap();
|
||||
futures.push(
|
||||
tokio::spawn(
|
||||
async move {
|
||||
if let Err(_) = db_accessor_clone.delete_infra(mac).await {
|
||||
Err(mac.to_string(eui48::MacAddressFormat::HexString))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for (i, future) in futures.iter_mut().enumerate() {
|
||||
match future.await {
|
||||
Ok(res) => {
|
||||
if let Err(mac) = res {
|
||||
failed_infras.push(mac);
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
failed_infras.push(infras[i].clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if failed_infras.len() > 0 {
|
||||
return Err(failed_infras);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn relay_request_stream_to_remote_cgw(
|
||||
&self,
|
||||
shard_id: i32,
|
||||
stream: Vec<(String, String)>)
|
||||
-> Result<(), ()> {
|
||||
// try to use internal cache first
|
||||
if let Some(cl) = self.remote_cgws_map.read().await.get(&shard_id) {
|
||||
if let Err(()) = cl.client.relay_request_stream(stream).await {
|
||||
return Err(())
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// then try to use redis
|
||||
let _ = self.sync_remote_cgw_map().await;
|
||||
|
||||
if let Some(cl) = self.remote_cgws_map.read().await.get(&shard_id) {
|
||||
if let Err(()) = cl.client.relay_request_stream(stream).await {
|
||||
return Err(())
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
error!("No suitable CGW instance #{shard_id} was discovered, cannot relay msg");
|
||||
return Err(());
|
||||
}
|
||||
}
|
||||
93
src/cgw_remote_server.rs
Normal file
93
src/cgw_remote_server.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use crate::{
|
||||
AppArgs,
|
||||
};
|
||||
|
||||
pub mod cgw_remote {
|
||||
tonic::include_proto!("cgw.remote");
|
||||
}
|
||||
|
||||
use tonic::{transport::Server, Request, Response, Status};
|
||||
|
||||
use cgw_remote::{
|
||||
EnqueueRequest,
|
||||
EnqueueResponse,
|
||||
remote_server::{
|
||||
RemoteServer,
|
||||
Remote,
|
||||
},
|
||||
};
|
||||
|
||||
use tokio::{
|
||||
time::{
|
||||
Duration,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cgw_remote_discovery::{
|
||||
CGWRemoteConfig,
|
||||
};
|
||||
|
||||
use crate::cgw_connection_server:: {
|
||||
CGWConnectionServer,
|
||||
};
|
||||
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
use std::{
|
||||
sync::{
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
struct CGWRemote {
|
||||
cgw_srv: Arc<CGWConnectionServer>,
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl Remote for CGWRemote {
|
||||
async fn enqueue_nbapi_request_stream(&self, request: Request<tonic::Streaming<EnqueueRequest>>, ) -> Result<Response<EnqueueResponse>, Status> {
|
||||
let mut rq_stream = request.into_inner();
|
||||
while let Some(rq) = rq_stream.next().await {
|
||||
let rq = rq.unwrap();
|
||||
self.cgw_srv.enqueue_mbox_relayed_message_to_cgw_server(rq.key, rq.req).await;
|
||||
}
|
||||
|
||||
let reply = cgw_remote::EnqueueResponse {
|
||||
ret: 0,
|
||||
msg: "DONE".to_string(),
|
||||
};
|
||||
Ok(Response::new(reply))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CGWRemoteServer {
|
||||
cfg: CGWRemoteConfig,
|
||||
}
|
||||
|
||||
impl CGWRemoteServer {
|
||||
pub fn new(app_args: &AppArgs) -> Self {
|
||||
let remote_cfg = CGWRemoteConfig::new(app_args.cgw_id, app_args.grpc_ip, app_args.grpc_port);
|
||||
let remote_server = CGWRemoteServer {
|
||||
cfg: remote_cfg,
|
||||
};
|
||||
remote_server
|
||||
}
|
||||
pub async fn start(&self, srv: Arc<CGWConnectionServer>) {
|
||||
// GRPC server
|
||||
// TODO: use CGWRemoteServerConfig wrap
|
||||
let icgw_serv = CGWRemote{
|
||||
cgw_srv: srv,
|
||||
};
|
||||
let grpc_srv = Server::builder()
|
||||
.tcp_keepalive(Some(Duration::from_secs(7200)))
|
||||
.http2_keepalive_timeout(Some(Duration::from_secs(7200)))
|
||||
.http2_keepalive_interval(Some(Duration::from_secs(7200)))
|
||||
.http2_max_pending_accept_reset_streams(Some(1000))
|
||||
.add_service(RemoteServer::new(icgw_serv));
|
||||
|
||||
info!("Starting GRPC server id {} - listening at {}:{}", self.cfg.remote_id, self.cfg.server_ip, self.cfg.server_port);
|
||||
let res = grpc_srv.serve(self.cfg.to_socket_addr()).await;
|
||||
error!("grpc server returned {:?}", res);
|
||||
// end of GRPC server build / start declaration
|
||||
}
|
||||
}
|
||||
24
src/localhost.crt
Executable file
24
src/localhost.crt
Executable file
@@ -0,0 +1,24 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIEADCCAmigAwIBAgICAcgwDQYJKoZIhvcNAQELBQAwLDEqMCgGA1UEAwwhcG9u
|
||||
eXRvd24gUlNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTE2MDgxMzE2MDcwNFoX
|
||||
DTIyMDIwMzE2MDcwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wggEiMA0G
|
||||
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCpVhh1/FNP2qvWenbZSghari/UThwe
|
||||
dynfnHG7gc3JmygkEdErWBO/CHzHgsx7biVE5b8sZYNEDKFojyoPHGWK2bQM/FTy
|
||||
niJCgNCLdn6hUqqxLAml3cxGW77hAWu94THDGB1qFe+eFiAUnDmob8gNZtAzT6Ky
|
||||
b/JGJdrEU0wj+Rd7wUb4kpLInNH/Jc+oz2ii2AjNbGOZXnRz7h7Kv3sO9vABByYe
|
||||
LcCj3qnhejHMqVhbAT1MD6zQ2+YKBjE52MsQKU/xhUpu9KkUyLh0cxkh3zrFiKh4
|
||||
Vuvtc+n7aeOv2jJmOl1dr0XLlSHBlmoKqH6dCTSbddQLmlK7dms8vE01AgMBAAGj
|
||||
gb4wgbswDAYDVR0TAQH/BAIwADALBgNVHQ8EBAMCBsAwHQYDVR0OBBYEFMeUzGYV
|
||||
bXwJNQVbY1+A8YXYZY8pMEIGA1UdIwQ7MDmAFJvEsUi7+D8vp8xcWvnEdVBGkpoW
|
||||
oR6kHDAaMRgwFgYDVQQDDA9wb255dG93biBSU0EgQ0GCAXswOwYDVR0RBDQwMoIO
|
||||
dGVzdHNlcnZlci5jb22CFXNlY29uZC50ZXN0c2VydmVyLmNvbYIJbG9jYWxob3N0
|
||||
MA0GCSqGSIb3DQEBCwUAA4IBgQBsk5ivAaRAcNgjc7LEiWXFkMg703AqDDNx7kB1
|
||||
RDgLalLvrjOfOp2jsDfST7N1tKLBSQ9bMw9X4Jve+j7XXRUthcwuoYTeeo+Cy0/T
|
||||
1Q78ctoX74E2nB958zwmtRykGrgE/6JAJDwGcgpY9kBPycGxTlCN926uGxHsDwVs
|
||||
98cL6ZXptMLTR6T2XP36dAJZuOICSqmCSbFR8knc/gjUO36rXTxhwci8iDbmEVaf
|
||||
BHpgBXGU5+SQ+QM++v6bHGf4LNQC5NZ4e4xvGax8ioYu/BRsB/T3Lx+RlItz4zdU
|
||||
XuxCNcm3nhQV2ZHquRdbSdoyIxV5kJXel4wCmOhWIq7A2OBKdu5fQzIAzzLi65EN
|
||||
RPAKsKB4h7hGgvciZQ7dsMrlGw0DLdJ6UrFyiR5Io7dXYT/+JP91lP5xsl6Lhg9O
|
||||
FgALt7GSYRm2cZdgi9pO9rRr83Br1VjQT1vHz6yoZMXSqc4A2zcN2a2ZVq//rHvc
|
||||
FZygs8miAhWPzqnpmgTj1cPiU1M=
|
||||
-----END CERTIFICATE-----
|
||||
275
src/main.rs
Normal file
275
src/main.rs
Normal file
@@ -0,0 +1,275 @@
|
||||
#![warn(rust_2018_idioms)]
|
||||
mod cgw_connection_server;
|
||||
mod cgw_nb_api_listener;
|
||||
mod cgw_connection_processor;
|
||||
mod cgw_remote_discovery;
|
||||
mod cgw_remote_server;
|
||||
mod cgw_remote_client;
|
||||
mod cgw_db_accessor;
|
||||
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
|
||||
use tokio::{
|
||||
net::{
|
||||
TcpListener,
|
||||
},
|
||||
time::{
|
||||
sleep,
|
||||
Duration,
|
||||
},
|
||||
runtime::{
|
||||
Builder,
|
||||
Handle,
|
||||
Runtime,
|
||||
},
|
||||
};
|
||||
|
||||
use native_tls::Identity;
|
||||
use std::{
|
||||
net::{
|
||||
Ipv4Addr,
|
||||
SocketAddr,
|
||||
},
|
||||
sync::{
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use rlimit::{
|
||||
setrlimit,
|
||||
Resource,
|
||||
};
|
||||
|
||||
use cgw_connection_server::{
|
||||
CGWConnectionServer,
|
||||
};
|
||||
|
||||
use cgw_remote_server::{
|
||||
CGWRemoteServer,
|
||||
};
|
||||
|
||||
use clap::{
|
||||
ValueEnum,
|
||||
Parser,
|
||||
};
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
|
||||
enum AppCoreLogLevel {
|
||||
/// Print debug-level messages and above
|
||||
Debug,
|
||||
/// Print info-level messages and above
|
||||
///
|
||||
///
|
||||
Info,
|
||||
}
|
||||
|
||||
/// CGW server
|
||||
#[derive(Parser, Clone)]
|
||||
#[command(version, about, long_about = None)]
|
||||
pub struct AppArgs {
|
||||
/// CGW unique identifier (u64)
|
||||
#[arg(short, long, default_value_t = 0)]
|
||||
cgw_id: i32,
|
||||
|
||||
/// Number of thread in a threadpool dedicated for handling secure websocket connections
|
||||
#[arg(short, long, default_value_t = 4)]
|
||||
wss_t_num: usize,
|
||||
/// Loglevel of application
|
||||
#[arg(value_enum, default_value_t = AppCoreLogLevel::Debug)]
|
||||
log_level: AppCoreLogLevel,
|
||||
|
||||
/// IP to listen for incoming WSS connection
|
||||
#[arg(long, default_value_t = Ipv4Addr::new(0, 0, 0, 0))]
|
||||
wss_ip: Ipv4Addr,
|
||||
/// PORT to listen for incoming WSS connection
|
||||
#[arg(long, default_value_t = 15002)]
|
||||
wss_port: u16,
|
||||
|
||||
/// IP to listen for incoming GRPC connection
|
||||
#[arg(long, default_value_t = Ipv4Addr::new(0, 0, 0, 0))]
|
||||
grpc_ip: Ipv4Addr,
|
||||
/// PORT to listen for incoming GRPC connection
|
||||
#[arg(long, default_value_t = 50051)]
|
||||
grpc_port: u16,
|
||||
|
||||
/// IP to connect to KAFKA broker
|
||||
#[arg(long, default_value_t = Ipv4Addr::new(127, 0, 0, 1))]
|
||||
kafka_ip: Ipv4Addr,
|
||||
/// PORT to connect to KAFKA broker
|
||||
#[arg(long, default_value_t = 9092)]
|
||||
kafka_port: u16,
|
||||
/// KAFKA topic from where to consume messages
|
||||
#[arg(long, default_value_t = String::from("CnC"))]
|
||||
kafka_consume_topic: String,
|
||||
/// KAFKA topic where to produce messages
|
||||
#[arg(long, default_value_t = String::from("CnC_Res"))]
|
||||
kafka_produce_topic: String,
|
||||
|
||||
/// IP to connect to DB (PSQL)
|
||||
#[arg(long, default_value_t = Ipv4Addr::new(127, 0, 0, 1))]
|
||||
db_ip: Ipv4Addr,
|
||||
/// PORT to connect to DB (PSQL)
|
||||
#[arg(long, default_value_t = 5432)]
|
||||
db_port: u16,
|
||||
/// DB name to connect to in DB (PSQL)
|
||||
#[arg(long, default_value_t = String::from("cgw"))]
|
||||
db_name: String,
|
||||
/// DB user name use with connection to in DB (PSQL)
|
||||
#[arg(long, default_value_t = String::from("cgw"))]
|
||||
db_username: String,
|
||||
/// DB user password use with connection to in DB (PSQL)
|
||||
#[arg(long, default_value_t = String::from("123"))]
|
||||
db_password: String,
|
||||
|
||||
/// IP to connect to DB (REDIS)
|
||||
#[arg(long, default_value_t = Ipv4Addr::new(127, 0, 0, 1))]
|
||||
redis_db_ip: Ipv4Addr,
|
||||
/// PORT to connect to DB (REDIS)
|
||||
#[arg(long, default_value_t = 6379)]
|
||||
redis_db_port: u16,
|
||||
}
|
||||
|
||||
pub struct AppCore {
|
||||
cgw_server: Arc<CGWConnectionServer>,
|
||||
main_runtime_handle: Arc<Handle>,
|
||||
grpc_server_runtime_handle: Arc<Runtime>,
|
||||
conn_ack_runtime_handle: Arc<Runtime>,
|
||||
args: AppArgs,
|
||||
}
|
||||
|
||||
impl AppCore {
|
||||
async fn new(app_args: AppArgs) -> Self {
|
||||
Self::setup_app(&app_args);
|
||||
let current_runtime = Arc::new(Handle::current());
|
||||
|
||||
let c_ack_runtime_handle = Arc::new(Builder::new_multi_thread()
|
||||
.worker_threads(1)
|
||||
.thread_name("cgw-c-ack")
|
||||
.thread_stack_size(1 * 1024 * 1024)
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap());
|
||||
let rpc_runtime_handle = Arc::new(Builder::new_multi_thread()
|
||||
.worker_threads(1)
|
||||
.thread_name("grpc-recv-t")
|
||||
.thread_stack_size(1 * 1024 * 1024)
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap());
|
||||
let app_core = AppCore {
|
||||
cgw_server: CGWConnectionServer::new(&app_args).await,
|
||||
main_runtime_handle: current_runtime,
|
||||
conn_ack_runtime_handle: c_ack_runtime_handle,
|
||||
args: app_args,
|
||||
grpc_server_runtime_handle: rpc_runtime_handle
|
||||
};
|
||||
app_core
|
||||
}
|
||||
|
||||
fn setup_app(args: &AppArgs) {
|
||||
let nofile_rlimit = Resource::NOFILE.get().unwrap();
|
||||
println!("{:?}", nofile_rlimit);
|
||||
let nofile_hard_limit = nofile_rlimit.1;
|
||||
assert!(setrlimit(Resource::NOFILE, nofile_hard_limit, nofile_hard_limit).is_ok());
|
||||
let nofile_rlimit = Resource::NOFILE.get().unwrap();
|
||||
println!("{:?}", nofile_rlimit);
|
||||
|
||||
match args.log_level {
|
||||
AppCoreLogLevel::Debug => ::std::env::set_var("RUST_LOG", "ucentral_cgw=debug"),
|
||||
AppCoreLogLevel::Info => ::std::env::set_var("RUST_LOG", "ucentral_cgw=info"),
|
||||
}
|
||||
env_logger::init();
|
||||
}
|
||||
|
||||
async fn run(self: Arc<AppCore>) {
|
||||
let main_runtime_handle = self.main_runtime_handle.clone();
|
||||
let core_clone = self.clone();
|
||||
|
||||
let cgw_remote_server = CGWRemoteServer::new(&self.args);
|
||||
let cgw_srv_clone = self.cgw_server.clone();
|
||||
self.grpc_server_runtime_handle.spawn(async move {
|
||||
debug!("cgw_remote_server.start entry");
|
||||
cgw_remote_server.start(cgw_srv_clone).await;
|
||||
debug!("cgw_remote_server.start exit");
|
||||
});
|
||||
|
||||
main_runtime_handle.spawn(async move {
|
||||
server_loop(core_clone).await
|
||||
});
|
||||
|
||||
// TODO:
|
||||
// Add signal processing and etcetera app-related handlers.
|
||||
loop {
|
||||
sleep(Duration::from_millis(5000)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: a method of an object (TlsAcceptor? CGWConnectionServer?), not a plain function
|
||||
async fn server_loop(app_core: Arc<AppCore>) -> () {
|
||||
debug!("sever_loop entry");
|
||||
|
||||
debug!("Starting WSS server, listening at {}:{}", app_core.args.wss_ip, app_core.args.wss_port);
|
||||
// Bind the server's socket
|
||||
let sockaddraddr = SocketAddr::new(std::net::IpAddr::V4(app_core.args.wss_ip), app_core.args.wss_port);
|
||||
let listener: Arc<TcpListener> = match TcpListener::bind(sockaddraddr).await {
|
||||
Ok(listener) => Arc::new(listener),
|
||||
Err(e) => panic!("listener bind failed {e}"),
|
||||
};
|
||||
|
||||
info!("Started WSS server.");
|
||||
// Create the TLS acceptor.
|
||||
// TODO: custom acceptor
|
||||
let der = include_bytes!("localhost.crt");
|
||||
let key = include_bytes!("localhost.key");
|
||||
let cert = match Identity::from_pkcs8(der, key) {
|
||||
Ok(cert) => cert,
|
||||
Err(e) => panic!("Cannot create SSL identity from supplied cert\n{e}")
|
||||
};
|
||||
|
||||
let tls_acceptor =
|
||||
tokio_native_tls::TlsAcceptor::from(
|
||||
match native_tls::TlsAcceptor::builder(cert).build() {
|
||||
Ok(builder) => builder,
|
||||
Err(e) => panic!("Cannot create SSL-acceptor from supplied cert\n{e}")
|
||||
});
|
||||
|
||||
// Spawn explicitly in main thread: created task accepts connection,
|
||||
// but handling is spawned inside another threadpool runtime
|
||||
let app_core_clone = app_core.clone();
|
||||
let _ = app_core.main_runtime_handle.spawn(async move {
|
||||
let mut conn_idx: i64 = 0;
|
||||
loop {
|
||||
let app_core_clone = app_core_clone.clone();
|
||||
let cgw_server_clone = app_core_clone.cgw_server.clone();
|
||||
let tls_acceptor_clone = tls_acceptor.clone();
|
||||
|
||||
// Asynchronously wait for an inbound socket.
|
||||
let (socket, remote_addr) = match listener.accept().await {
|
||||
Ok((sock, addr)) => (sock, addr),
|
||||
Err(e) => {
|
||||
error!("Failed to Accept conn {e}\n");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: we control tls_acceptor, thus at this stage it's our responsibility
|
||||
// to provide underlying certificates common name inside the ack_connection func
|
||||
// (CN == mac address of device)
|
||||
app_core_clone.conn_ack_runtime_handle.spawn(async move {
|
||||
cgw_server_clone.ack_connection(socket, tls_acceptor_clone, remote_addr, conn_idx).await;
|
||||
});
|
||||
|
||||
conn_idx += 1;
|
||||
}
|
||||
}).await;
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
async fn main() {
|
||||
let args = AppArgs::parse();
|
||||
let app = Arc::new(AppCore::new(args).await);
|
||||
|
||||
app.run().await;
|
||||
}
|
||||
17
src/proto/cgw.proto
Normal file
17
src/proto/cgw.proto
Normal file
@@ -0,0 +1,17 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package cgw.remote;
|
||||
|
||||
service Remote {
|
||||
rpc EnqueueNBAPIRequestStream(stream EnqueueRequest) returns (EnqueueResponse);
|
||||
}
|
||||
|
||||
message EnqueueRequest {
|
||||
string req = 1;
|
||||
string key = 2;
|
||||
}
|
||||
|
||||
message EnqueueResponse {
|
||||
uint64 ret = 1;
|
||||
string msg = 2;
|
||||
}
|
||||
1
utils/cert_generator/.gitignore
vendored
Normal file
1
utils/cert_generator/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
certs/
|
||||
18
utils/cert_generator/README.md
Normal file
18
utils/cert_generator/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# Generate Certificates
|
||||
|
||||
```
|
||||
# generate CA certs
|
||||
$ ./generate_certs.sh -a
|
||||
|
||||
# generate server certs
|
||||
$ ./generate_certs.sh -s
|
||||
|
||||
# generate 10 client certs
|
||||
$ ./generate_certs.sh -c 10
|
||||
|
||||
# generate 10 client certs with MAC addr AA:*
|
||||
$ ./generate_certs.sh -c 10 -m AA:XX:XX:XX:XX:XX
|
||||
```
|
||||
|
||||
The certificates will be available in `./certs/`.
|
||||
|
||||
19
utils/cert_generator/ca.conf
Normal file
19
utils/cert_generator/ca.conf
Normal file
@@ -0,0 +1,19 @@
|
||||
|
||||
[req]
|
||||
distinguished_name=req
|
||||
[ca]
|
||||
basicConstraints=CA:true
|
||||
keyUsage=keyCertSign,cRLSign
|
||||
subjectKeyIdentifier=hash
|
||||
[client]
|
||||
basicConstraints=CA:FALSE
|
||||
extendedKeyUsage=clientAuth
|
||||
keyUsage=digitalSignature,keyEncipherment
|
||||
[server]
|
||||
basicConstraints=CA:FALSE
|
||||
extendedKeyUsage=serverAuth
|
||||
keyUsage=digitalSignature,keyEncipherment
|
||||
subjectAltName=@sans
|
||||
[sans]
|
||||
DNS.1=localhost
|
||||
IP.1=127.0.0.1
|
||||
161
utils/cert_generator/generate_certs.sh
Executable file
161
utils/cert_generator/generate_certs.sh
Executable file
@@ -0,0 +1,161 @@
|
||||
#!/bin/bash
|
||||
TEMPLATE="02:XX:XX:XX:XX:XX"
|
||||
HEXCHARS="0123456789ABCDEF"
|
||||
CONF_FILE=ca.conf
|
||||
CERT_DIR=certs
|
||||
CA_DIR=$CERT_DIR/ca
|
||||
CA_CERT=$CA_DIR/ca.crt
|
||||
CA_KEY=$CA_DIR/ca.key
|
||||
SERVER_DIR=$CERT_DIR/server
|
||||
CLIENT_DIR=$CERT_DIR/client
|
||||
METHOD_FAST=y
|
||||
|
||||
usage()
|
||||
{
|
||||
echo "Usage: $0 [options]"
|
||||
echo
|
||||
echo "options:"
|
||||
echo "-h"
|
||||
echo -e "\tprint this help"
|
||||
echo "-a"
|
||||
echo -e "\tgenerate CA key and certificate"
|
||||
echo "-s"
|
||||
echo -e "\tgenerate server key and certificate; sign the certificate"
|
||||
echo -e "\tusing the CA certificate"
|
||||
echo "-c NUMBER"
|
||||
echo -e "\tgenerate *NUMBER* of client keys and certificates; sign"
|
||||
echo -e "\tall of the certificates using the CA certificate"
|
||||
echo "-m MASK"
|
||||
echo -e "\tspecify custom MAC addr mask"
|
||||
echo "-o"
|
||||
echo -e "\tslow mode"
|
||||
}
|
||||
|
||||
rand_mac()
|
||||
{
|
||||
local MAC=$TEMPLATE
|
||||
while [[ $MAC =~ "X" || $MAC =~ "x" ]]
|
||||
do
|
||||
MAC=${MAC/[xX]/${HEXCHARS:$(( $RANDOM % 16 )):1}}
|
||||
done
|
||||
echo $MAC
|
||||
}
|
||||
|
||||
gen_cert()
|
||||
{
|
||||
local req_file=$(mktemp)
|
||||
local pem=$(mktemp)
|
||||
local type=$1
|
||||
local cn=$2
|
||||
local key=$3
|
||||
local cert=$4
|
||||
# generate key and request to sign
|
||||
openssl req -config $CONF_FILE -x509 -nodes -newkey rsa:4096 -sha512 -days 365 \
|
||||
-extensions $type -subj "/CN=$cn" -out $req_file -keyout $key &> /dev/null
|
||||
# sign certificate
|
||||
openssl x509 -extfile $CONF_FILE -CA $CA_CERT -CAkey $CA_KEY -CAcreateserial -sha512 -days 365 \
|
||||
-in $req_file -out $pem
|
||||
if [ $? == "0" ]
|
||||
then
|
||||
cat $pem $CA_CERT > $cert
|
||||
else
|
||||
>&2 echo Failed to generate certificate
|
||||
rm $key
|
||||
fi
|
||||
rm $req_file
|
||||
rm $pem
|
||||
}
|
||||
|
||||
gen_client_batch()
|
||||
{
|
||||
batch=$(($1 / 100))
|
||||
sync_count=$(($1 / 10))
|
||||
baseline=$(ps aux | grep openssl | wc -l)
|
||||
for (( c=1; c<=$1; c++ ))
|
||||
do
|
||||
mac=$(rand_mac)
|
||||
gen_cert client $mac $CLIENT_DIR/$mac.key $CLIENT_DIR/$mac.crt &
|
||||
if [ "$(( $c % $batch ))" -eq "0" ]
|
||||
then
|
||||
echo $(($c/$batch))%
|
||||
fi
|
||||
if [ "$(( $c % $sync_count ))" -eq "0" ]
|
||||
then
|
||||
until [ $(ps aux | grep openssl | wc -l) -eq "$baseline" ]; do sleep 1; done
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
gen_client()
|
||||
{
|
||||
for x in $(seq $1)
|
||||
do
|
||||
echo $x
|
||||
mac=$(rand_mac)
|
||||
gen_cert client $mac $CLIENT_DIR/$mac.key $CLIENT_DIR/$mac.crt
|
||||
done
|
||||
}
|
||||
|
||||
while getopts "ac:shm:o" arg; do
|
||||
case $arg in
|
||||
a)
|
||||
GEN_CA=y
|
||||
;;
|
||||
s)
|
||||
GEN_SER=y
|
||||
;;
|
||||
c)
|
||||
GEN_CLI=y
|
||||
NUM_CERTS=$OPTARG
|
||||
if [ $NUM_CERTS -lt 100 ]
|
||||
then
|
||||
METHOD_FAST=n
|
||||
fi
|
||||
;;
|
||||
m)
|
||||
TEMPLATE=$OPTARG
|
||||
;;
|
||||
o)
|
||||
METHOD_FAST=n
|
||||
;;
|
||||
h)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [ "$GEN_CA" == "y" ]
|
||||
then
|
||||
echo Generating root CA certificate
|
||||
mkdir -p $CA_DIR
|
||||
openssl req -config $CONF_FILE -x509 -nodes -newkey rsa:4096 -sha512 -days 365 \
|
||||
-extensions ca -subj "/CN=CA" -out $CA_CERT -keyout $CA_KEY &> /dev/null
|
||||
fi
|
||||
|
||||
if [ "$GEN_SER" == "y" ]
|
||||
then
|
||||
mkdir -p $SERVER_DIR
|
||||
echo Generating server certificate
|
||||
gen_cert server localhost $SERVER_DIR/gw.key $SERVER_DIR/gw.crt
|
||||
fi
|
||||
|
||||
if [ "$GEN_CLI" == "y" ]
|
||||
then
|
||||
echo Generating $NUM_CERTS client certificates
|
||||
mkdir -p $CLIENT_DIR
|
||||
if [ $METHOD_FAST == "y" ]
|
||||
then
|
||||
# because of race condition some of the certificates might fail to generate
|
||||
# but this is ~15 times faster than generating certificates one by one
|
||||
gen_client_batch $NUM_CERTS
|
||||
else
|
||||
gen_client $NUM_CERTS
|
||||
fi
|
||||
fi
|
||||
|
||||
echo done
|
||||
6
utils/client_simulator/.gitignore
vendored
Normal file
6
utils/client_simulator/.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
bin/
|
||||
lib/
|
||||
lib64
|
||||
include/
|
||||
*.cfg
|
||||
__pycache__
|
||||
10
utils/client_simulator/Dockerfile
Normal file
10
utils/client_simulator/Dockerfile
Normal file
@@ -0,0 +1,10 @@
|
||||
FROM ubuntu
|
||||
|
||||
COPY ./requirements.txt /tmp/requirements.txt
|
||||
WORKDIR /opt/client_simulator
|
||||
|
||||
RUN apt-get update -q -y && apt-get -q -y --no-install-recommends install \
|
||||
python3 \
|
||||
python3-pip
|
||||
RUN python3 -m pip install -r /tmp/requirements.txt
|
||||
|
||||
34
utils/client_simulator/Makefile
Normal file
34
utils/client_simulator/Makefile
Normal file
@@ -0,0 +1,34 @@
|
||||
IMG_NAME=cgw-client-sim
|
||||
CONTAINER_NAME=cgw_client_sim
|
||||
MAC?=XX:XX:XX:XX:XX:XX
|
||||
COUNT?=1000
|
||||
URL=wss://localhost:15002
|
||||
CA_CERT_PATH?=$(PWD)/../cert_generator/certs/ca
|
||||
CLIENT_CERT_PATH?=$(PWD)/../cert_generator/certs/client
|
||||
MSG_INTERVAL?=10
|
||||
MSG_SIZE?=1000
|
||||
|
||||
.PHONY: build spawn stop start
|
||||
|
||||
build:
|
||||
docker build -t ${IMG_NAME} .
|
||||
|
||||
spawn:
|
||||
docker run --name "${CONTAINER_NAME}_${COUNT}_$(subst :,-,$(MAC))" \
|
||||
-d --rm --network host \
|
||||
-v $(PWD):/opt/client_simulator \
|
||||
-v ${CA_CERT_PATH}:/etc/ca \
|
||||
-v ${CLIENT_CERT_PATH}:/etc/certs \
|
||||
${IMG_NAME} \
|
||||
python3 main.py -M ${MAC} -N ${COUNT} -s ${URL} \
|
||||
--ca-cert /etc/ca/ca.crt \
|
||||
--client-certs-path /etc/certs \
|
||||
--msg-interval ${MSG_INTERVAL} \
|
||||
--payload-size ${MSG_SIZE} \
|
||||
--wait-for-signal
|
||||
|
||||
stop:
|
||||
docker stop $$(docker ps -q -f name=$(CONTAINER_NAME))
|
||||
|
||||
start:
|
||||
docker kill -s SIGUSR1 $$(docker ps -q -f name=$(CONTAINER_NAME))
|
||||
39
utils/client_simulator/README.md
Normal file
39
utils/client_simulator/README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# Run client simulation
|
||||
|
||||
```
|
||||
# run 10 concurrent client simulations
|
||||
$ ./main.py -s wss://localhost:50001 -n 10
|
||||
|
||||
# use only specified MAC addrs
|
||||
$ ./main.py -s wss://localhost:15002 -N 10 -M AA:XX:XX:XX:XX:XX
|
||||
|
||||
# run 10 concurrent simulations with MAC AA:* and 10 concurrent simulations with MAC BB:*
|
||||
# AA:* and BB:* simulations are run in separate processes
|
||||
$ ./main.py -s wss://localhost:15002 -N 10 -M AA:XX:XX:XX:XX:XX -M BB:XX:XX:XX:XX:XX
|
||||
```
|
||||
|
||||
To stop the simulation use `Ctrl+C`.
|
||||
|
||||
# Run simulation in docker
|
||||
|
||||
```
|
||||
$ make
|
||||
$ make spawn
|
||||
$ make start
|
||||
$ make stop
|
||||
|
||||
# specify mac addr range
|
||||
$ make spawn MAC=11:22:AA:BB:XX:XX
|
||||
|
||||
# specify number of client connections (default is 1000)
|
||||
$ make spawn COUNT=100
|
||||
|
||||
# specify server url
|
||||
$ make spawn URL=wss://localhost:15002
|
||||
|
||||
# tell all running containers to start connecting to the server
|
||||
$ make start
|
||||
|
||||
# stop all containers
|
||||
$ make stop
|
||||
```
|
||||
1
utils/client_simulator/certs
Symbolic link
1
utils/client_simulator/certs
Symbolic link
@@ -0,0 +1 @@
|
||||
../cert_generator/certs/
|
||||
5
utils/client_simulator/data/message_templates.json
Normal file
5
utils/client_simulator/data/message_templates.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{"connect": {"jsonrpc":"2.0","method":"connect","params":{"serial":"MAC","firmware":"Rel 1.6 build 1","uuid":1692198868,"capabilities":{"compatible":"x86_64-kvm_x86_64-r0","model":"DellEMC-S5248f-P-25G-DPB","platform":"switch","label_macaddr":"MAC"}}},
|
||||
"state": {"jsonrpc": "2.0", "method": "state", "params": {"serial": "MAC","uuid": 1692198868, "request_uuid": null, "state": {}}},
|
||||
"reboot_response": {"jsonrpc": "2.0", "result": {"serial": "MAC", "status": {"error": 0, "text": "", "when": 0}, "id": "ID"}},
|
||||
"log": {"jsonrpc": "2.0", "method": "log", "params": {"serial": "MAC", "log": "", "severity": 7, "data": {}}}
|
||||
}
|
||||
8
utils/client_simulator/main.py
Executable file
8
utils/client_simulator/main.py
Executable file
@@ -0,0 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
from src.utils import parse_args
|
||||
from src.simulation_runner import main
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
1
utils/client_simulator/requirements.txt
Normal file
1
utils/client_simulator/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
websockets==12.0
|
||||
0
utils/client_simulator/src/__init__.py
Normal file
0
utils/client_simulator/src/__init__.py
Normal file
57
utils/client_simulator/src/log.py
Normal file
57
utils/client_simulator/src/log.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import logging
|
||||
|
||||
|
||||
TRACE_LEVEL = logging.DEBUG - 5
|
||||
TRACE_NAME = "TRACE"
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
|
||||
grey = "\x1b[38;20m"
|
||||
yellow = "\x1b[33;20m"
|
||||
red = "\x1b[31;20m"
|
||||
bold_red = "\x1b[31;1m"
|
||||
white = "\x1b[37;20m"
|
||||
cyan = "\x1b[36;20m"
|
||||
reset = "\x1b[0m"
|
||||
format = "{asctime}|{levelname}|{threadName}|{funcName}:{lineno}\t{message}"
|
||||
|
||||
FORMATS = {
|
||||
logging.DEBUG: grey + format + reset,
|
||||
logging.INFO: cyan + format + reset,
|
||||
logging.WARNING: yellow + format + reset,
|
||||
logging.ERROR: red + format + reset,
|
||||
logging.CRITICAL: bold_red + format + reset,
|
||||
TRACE_LEVEL: white + format + reset
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
log_fmt = self.FORMATS.get(record.levelno)
|
||||
formatter = logging.Formatter(log_fmt, style="{")
|
||||
return formatter.format(record)
|
||||
|
||||
|
||||
def __trace(self, msg, *args, **kwargs):
|
||||
"""
|
||||
Log 'msg % args' with severity 'TRACE'.
|
||||
|
||||
logger.trace("periodic log")
|
||||
"""
|
||||
if self.isEnabledFor(TRACE_LEVEL):
|
||||
self._log(TRACE_LEVEL, msg, args, **kwargs)
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, datefmt="%H:%M:%S")
|
||||
|
||||
logging.addLevelName(TRACE_LEVEL, TRACE_NAME)
|
||||
setattr(logging, TRACE_NAME, TRACE_LEVEL)
|
||||
setattr(logging.getLoggerClass(), TRACE_NAME.lower(), __trace)
|
||||
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(logging.NOTSET)
|
||||
console.setFormatter(ColoredFormatter())
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.propagate = False
|
||||
logger.addHandler(console)
|
||||
logging.getLogger('websockets.client').setLevel(logging.INFO)
|
||||
203
utils/client_simulator/src/simulation_runner.py
Normal file
203
utils/client_simulator/src/simulation_runner.py
Normal file
@@ -0,0 +1,203 @@
|
||||
#!/usr/bin/env python3
|
||||
from .utils import get_msg_templates, Args
|
||||
from .log import logger
|
||||
from websockets.sync import client
|
||||
from websockets.exceptions import ConnectionClosedOK, ConnectionClosedError, ConnectionClosed
|
||||
from typing import List
|
||||
import multiprocessing
|
||||
import threading
|
||||
import resource
|
||||
import string
|
||||
import random
|
||||
import signal
|
||||
import copy
|
||||
import time
|
||||
import json
|
||||
import ssl
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, mac: str, size: int):
|
||||
self.templates = get_msg_templates()
|
||||
self.connect = json.dumps(self.templates["connect"]).replace("MAC", mac)
|
||||
self.state = json.dumps(self.templates["state"]).replace("MAC", mac)
|
||||
self.reboot_response = json.dumps(self.templates["reboot_response"]).replace("MAC", mac)
|
||||
self.log = copy.deepcopy(self.templates["log"])
|
||||
self.log["params"]["data"] = {"msg": ''.join(random.choices(string.ascii_uppercase + string.digits, k=size))}
|
||||
self.log = json.dumps(self.log).replace("MAC", mac)
|
||||
|
||||
@staticmethod
|
||||
def to_json(msg) -> str:
|
||||
return json.dumps(msg)
|
||||
|
||||
@staticmethod
|
||||
def from_json(msg) -> dict:
|
||||
return json.loads(msg)
|
||||
|
||||
|
||||
class Device:
|
||||
def __init__(self, mac: str, server: str, ca_cert: str,
|
||||
msg_interval: int, msg_size: int,
|
||||
client_cert: str, client_key: str,
|
||||
start_event: multiprocessing.Event,
|
||||
stop_event: multiprocessing.Event):
|
||||
self.mac = mac
|
||||
self.interval = msg_interval
|
||||
self.messages = Message(self.mac, msg_size)
|
||||
self.server_addr = server
|
||||
self.start_event = start_event
|
||||
self.stop_event = stop_event
|
||||
self.reboot_time_s = 10
|
||||
self._socket = None
|
||||
|
||||
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
self.ssl_context.load_cert_chain(client_cert, client_key, "")
|
||||
self.ssl_context.load_verify_locations(ca_cert)
|
||||
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
|
||||
|
||||
def send_hello(self, socket: client.ClientConnection):
|
||||
logger.debug(self.messages.connect)
|
||||
socket.send(self.messages.connect)
|
||||
|
||||
def send_log(self, socket: client.ClientConnection):
|
||||
socket.send(self.messages.log)
|
||||
|
||||
def handle_messages(self, socket: client.ClientConnection):
|
||||
try:
|
||||
msg = socket.recv(self.interval)
|
||||
msg = self.messages.from_json(msg)
|
||||
logger.info(msg)
|
||||
if msg["method"] == "reboot":
|
||||
self.handle_reboot(socket, msg)
|
||||
else:
|
||||
logger.error(f"Unknown method {msg['method']}")
|
||||
except TimeoutError: # no messages
|
||||
pass
|
||||
except (ConnectionClosedOK, ConnectionClosedError, ConnectionClosed):
|
||||
logger.critical("Did not expect socket to be closed")
|
||||
raise
|
||||
|
||||
def handle_reboot(self, socket: client.ClientConnection, msg: dict):
|
||||
resp = self.messages.from_json(self.messages.reboot_response)
|
||||
if "id" in msg:
|
||||
resp["result"]["id"] = msg["id"]
|
||||
else:
|
||||
del resp["result"]["id"]
|
||||
logger.warn("Reboot request is missing 'id' field")
|
||||
socket.send(self.messages.to_json(resp))
|
||||
self.disconnect()
|
||||
time.sleep(self.reboot_time_s)
|
||||
self.connect()
|
||||
self.send_hello(self._socket)
|
||||
|
||||
def connect(self):
|
||||
if self._socket is None:
|
||||
self._socket = client.connect(self.server_addr, ssl_context=self.ssl_context, open_timeout=7200)
|
||||
return self._socket
|
||||
|
||||
def disconnect(self):
|
||||
if self._socket is not None:
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
|
||||
def job(self):
|
||||
logger.debug("waiting for start trigger")
|
||||
self.start_event.wait()
|
||||
if self.stop_event.is_set():
|
||||
return
|
||||
logger.debug("starting simulation")
|
||||
self.connect()
|
||||
start = time.time()
|
||||
try:
|
||||
self.send_hello(self._socket)
|
||||
while not self.stop_event.is_set():
|
||||
if self._socket is None:
|
||||
logger.error("Connection to GW is lost. Trying to reconnect...")
|
||||
self.connect()
|
||||
if time.time() - start > self.interval:
|
||||
logger.info(f"Sent log")
|
||||
self.send_log(self._socket)
|
||||
start = time.time()
|
||||
self.handle_messages(self._socket)
|
||||
finally:
|
||||
self.disconnect()
|
||||
logger.debug("simulation done")
|
||||
|
||||
|
||||
def get_avail_mac_addrs(path, mask="XX:XX:XX:XX:XX:XX"):
|
||||
_mask = "".join(("[0-9a-fA-F]" if c == "X" else c) for c in mask.upper())
|
||||
certs = sorted(os.listdir(path))
|
||||
macs = set(cert.split(".")[0] for cert in certs if "crt" in cert and re.match(_mask, cert))
|
||||
return list(macs)
|
||||
|
||||
|
||||
def update_fd_limit():
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
try:
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
|
||||
except ValueError:
|
||||
logger.critical("Failed to update fd limit")
|
||||
raise
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
logger.warning(f"changed fd limit {soft, hard}")
|
||||
|
||||
|
||||
def process(args: Args, mask: str, start_event: multiprocessing.Event, stop_event: multiprocessing.Event):
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN) # ignore Ctrl+C in child processes
|
||||
threading.current_thread().name = mask
|
||||
logger.info(f"process started")
|
||||
macs = get_avail_mac_addrs(args.cert_path, mask)
|
||||
if len(macs) < args.number_of_connections:
|
||||
logger.warn(f"expected {args.number_of_connections} certificates, but only found {len(macs)} "
|
||||
f"({mask = })")
|
||||
update_fd_limit()
|
||||
|
||||
devices = [Device(mac, args.server, args.ca_path, args.msg_interval, args.msg_size,
|
||||
os.path.join(args.cert_path, f"{mac}.crt"),
|
||||
os.path.join(args.cert_path, f"{mac}.key"),
|
||||
start_event, stop_event)
|
||||
for mac, _ in zip(macs, range(args.number_of_connections))]
|
||||
threads = [threading.Thread(target=d.job, name=d.mac) for d in devices]
|
||||
[t.start() for t in threads]
|
||||
[t.join() for t in threads]
|
||||
|
||||
|
||||
def verify_cert_availability(cert_path: str, masks: List[str], count: int):
|
||||
for mask in masks:
|
||||
macs = get_avail_mac_addrs(cert_path, mask)
|
||||
assert len(macs) >= count, \
|
||||
f"Simulation requires {count} certificates, but only found {len(macs)}"
|
||||
|
||||
|
||||
def trigger_start(evt):
|
||||
def fn(signum, frame):
|
||||
logger.info("Signal received, starting simulation...")
|
||||
evt.set()
|
||||
return fn
|
||||
|
||||
|
||||
def main(args: Args):
|
||||
verify_cert_availability(args.cert_path, args.masks, args.number_of_connections)
|
||||
stop_event = multiprocessing.Event()
|
||||
start_event = multiprocessing.Event()
|
||||
if not args.wait_for_sig:
|
||||
start_event.set()
|
||||
signal.signal(signal.SIGUSR1, trigger_start(start_event))
|
||||
processes = [multiprocessing.Process(target=process, args=(args, mask, start_event, stop_event))
|
||||
for mask in args.masks]
|
||||
try:
|
||||
for p in processes:
|
||||
p.start()
|
||||
time.sleep(1)
|
||||
logger.info(f"Started {len(processes)} processes")
|
||||
if args.wait_for_sig:
|
||||
logger.info("Waiting for SIGUSR1...")
|
||||
while True:
|
||||
time.sleep(100)
|
||||
except KeyboardInterrupt:
|
||||
logger.warn("Stopping all processes...")
|
||||
stop_event.set()
|
||||
start_event.set()
|
||||
[p.join() for p in processes]
|
||||
119
utils/client_simulator/src/utils.py
Normal file
119
utils/client_simulator/src/utils.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
import argparse
|
||||
import random
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
|
||||
|
||||
TEMPLATE_LOCATION = "./data/message_templates.json"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
number_of_connections: int
|
||||
masks: List[str]
|
||||
ca_path: str
|
||||
cert_path: str
|
||||
msg_size: int
|
||||
msg_interval: int
|
||||
wait_for_sig: bool
|
||||
server_proto: str = "ws"
|
||||
server_address: str = "localhost"
|
||||
server_port: int = 50001
|
||||
|
||||
@property
|
||||
def server(self):
|
||||
return f"{self.server_proto}://{self.server_address}:{self.server_port}"
|
||||
|
||||
|
||||
def parse_msg_size(input: str) -> int:
|
||||
match = re.match(r"^(\d+)([kKmM]?)$", input)
|
||||
if match is None:
|
||||
raise ValueError(f"Unable to parse message size \"{input}\"")
|
||||
num, prefix = match.groups()
|
||||
num = int(num)
|
||||
if prefix and prefix in "kK":
|
||||
num *= 1000
|
||||
elif prefix and prefix in "mM":
|
||||
num *= 1000000
|
||||
return num
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Used to simulate multiple clients that connect to a single server.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
parser.add_argument("-s", "--server", metavar="ADDRESS", required=True,
|
||||
default="ws://localhost:50001",
|
||||
help="server address")
|
||||
parser.add_argument("-N", "--number-of-connections", metavar="NUMBER", type=int,
|
||||
default=1,
|
||||
help="number of concurrent connections per thread pool")
|
||||
parser.add_argument("-M", "--mac-mask", metavar="XX:XX:XX:XX:XX:XX", action="append",
|
||||
default=[],
|
||||
help="the mask determines what MAC addresses will be used by clients "
|
||||
"in the thread pool. Specifying multiple masks will increase "
|
||||
"the number of thread pools used.")
|
||||
parser.add_argument("-a", "--ca-cert", metavar="CERT",
|
||||
default="./certs/ca/ca.crt",
|
||||
help="path to CA certificate")
|
||||
parser.add_argument("-c", "--client-certs-path", metavar="PATH",
|
||||
default="./certs/client",
|
||||
help="path to client certificates directory")
|
||||
parser.add_argument("-t", "--msg-interval", metavar="SECONDS", type=int,
|
||||
default=10,
|
||||
help="time between client messages to gw")
|
||||
parser.add_argument("-p", "--payload-size", metavar="SIZE", type=str,
|
||||
default="1k",
|
||||
help="size of each client message")
|
||||
parser.add_argument("-w", "--wait-for-signal", action="store_true",
|
||||
help="wait for SIGUSR1 before running simulation")
|
||||
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
args = Args(number_of_connections=parsed_args.number_of_connections,
|
||||
masks=parsed_args.mac_mask,
|
||||
ca_path=parsed_args.ca_cert,
|
||||
cert_path=parsed_args.client_certs_path,
|
||||
msg_interval=parsed_args.msg_interval,
|
||||
msg_size=parse_msg_size(parsed_args.payload_size),
|
||||
wait_for_sig=parsed_args.wait_for_signal)
|
||||
|
||||
if len(args.masks) == 0:
|
||||
args.masks.append("XX:XX:XX:XX:XX:XX")
|
||||
|
||||
# PROTO :// ADDRESS : PORT
|
||||
match = re.match(r"(?:(wss?)://)?([\d\w\.]+):?(\d+)?", parsed_args.server)
|
||||
if match is None:
|
||||
raise ValueError(f"Unable to parse server address {parsed_args.server}")
|
||||
proto, addr, port = match.groups()
|
||||
if proto is not None:
|
||||
args.server_proto = proto
|
||||
if addr is not None:
|
||||
args.server_address = addr
|
||||
if port is not None:
|
||||
args.server_port = port
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def rand_mac(mask="02:xx:xx:xx:xx:xx"):
|
||||
return ''.join([n.lower().replace('x', f'{random.randint(0, 15):x}') for n in mask])
|
||||
|
||||
|
||||
def get_msg_templates():
|
||||
with open(TEMPLATE_LOCATION, "r") as templates:
|
||||
return json.loads(templates.read())
|
||||
|
||||
|
||||
def gen_certificates(mask: str, count=int):
|
||||
cwd = os.getcwd()
|
||||
os.chdir("../cert_generator")
|
||||
try:
|
||||
rc = os.system(f"./generate_certs.sh -c {count} -m \"{mask}\" -o")
|
||||
assert rc == 0, "Generating certificates failed"
|
||||
finally:
|
||||
os.chdir(cwd)
|
||||
7
utils/kafka_producer/.gitignore
vendored
Normal file
7
utils/kafka_producer/.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
bin/
|
||||
lib/
|
||||
lib64
|
||||
include/
|
||||
share/
|
||||
*.cfg
|
||||
__pycache__
|
||||
33
utils/kafka_producer/data/message_template.json
Normal file
33
utils/kafka_producer/data/message_template.json
Normal file
@@ -0,0 +1,33 @@
|
||||
{
|
||||
"add_group": {
|
||||
"type": "infrastructure_group_create",
|
||||
"infra_group_id": "key",
|
||||
"infra_name": "name",
|
||||
"infra_shard_id": 0,
|
||||
"uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff"
|
||||
},
|
||||
"del_group": {
|
||||
"type": "infrastructure_group_delete",
|
||||
"infra_group_id": "key",
|
||||
"uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff"
|
||||
},
|
||||
"add_to_group": {
|
||||
"type": "infrastructure_group_device_add",
|
||||
"infra_group_id": "key",
|
||||
"infra_group_infra_devices": [],
|
||||
"uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff"
|
||||
},
|
||||
"del_from_group": {
|
||||
"type": "infrastructure_group_device_del",
|
||||
"infra_group_id": "key",
|
||||
"infra_group_infra_devices": [],
|
||||
"uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff"
|
||||
},
|
||||
"message_device": {
|
||||
"type": "infrastructure_group_device_message",
|
||||
"infra_group_id": "key",
|
||||
"mac": "mac",
|
||||
"msg": {},
|
||||
"uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff"
|
||||
}
|
||||
}
|
||||
23
utils/kafka_producer/main.py
Executable file
23
utils/kafka_producer/main.py
Executable file
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env python3
|
||||
from src.cli_parser import parse_args, Args
|
||||
from src.producer import Producer
|
||||
from src.log import logger
|
||||
|
||||
|
||||
def main(args: Args):
|
||||
producer = Producer(args.db, args.topic)
|
||||
if args.add_groups or args.del_groups:
|
||||
producer.handle_group_creation(args.add_groups, args.del_groups)
|
||||
if args.assign_to_group or args.remove_from_group:
|
||||
producer.handle_device_assignment(args.assign_to_group, args.remove_from_group)
|
||||
if args.message:
|
||||
producer.handle_device_messages(args.message, args.group_id, args.send_to_macs,
|
||||
args.count, args.time_to_send_s, args.interval_s)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
args = parse_args()
|
||||
main(args)
|
||||
except KeyboardInterrupt:
|
||||
logger.warn("exiting...")
|
||||
1
utils/kafka_producer/requirements.txt
Normal file
1
utils/kafka_producer/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
kafka-python==2.0.2
|
||||
0
utils/kafka_producer/src/__init__.py
Normal file
0
utils/kafka_producer/src/__init__.py
Normal file
116
utils/kafka_producer/src/cli_parser.py
Normal file
116
utils/kafka_producer/src/cli_parser.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from .utils import MacRange, Args
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
||||
def time(input: str) -> float:
|
||||
"""
|
||||
Expects a string in the format:
|
||||
*number*
|
||||
*number*s
|
||||
*number*m
|
||||
*number*h
|
||||
|
||||
Returns the number of seconds as an integer. Example:
|
||||
100 -> 100
|
||||
100s -> 100
|
||||
5m -> 300
|
||||
2h -> 7200
|
||||
"""
|
||||
if all(char not in input for char in "smh"):
|
||||
return float(input)
|
||||
n, t = float(input[:-1]), input[-1:]
|
||||
if t == "h":
|
||||
return n * 60 * 60
|
||||
if t == "m":
|
||||
return n * 60
|
||||
return n
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Creates entries in kafka.")
|
||||
|
||||
parser.add_argument("-g", "--new-group", metavar=("GROUP-ID", "SHARD-ID", "NAME"),
|
||||
nargs=3, action="append",
|
||||
help="create a new group")
|
||||
parser.add_argument("-G", "--rm-group", metavar=("GROUP-ID"),
|
||||
nargs=1, action="append",
|
||||
help="delete an existing group")
|
||||
parser.add_argument("-d", "--assign-to-group", metavar=("GROUP-ID", "MAC-RANGE"),
|
||||
nargs=2, action="append",
|
||||
help="add a range of mac addrs to a group")
|
||||
parser.add_argument("-D", "--remove-from-group", metavar=("GROUP-ID", "MAC-RANGE"),
|
||||
nargs=2, action="append",
|
||||
help="remove mac addrs from a group")
|
||||
parser.add_argument("-T", "--topic", default="CnC",
|
||||
help="kafka topic (default: \"CnC\")")
|
||||
parser.add_argument("-s", "--bootstrap-server", metavar="ADDRESS", default="172.20.10.136:9092",
|
||||
help="kafka address (default: \"172.20.10.136:9092\")")
|
||||
parser.add_argument("-m", "--send-message", metavar="JSON", type=str,
|
||||
help="this message will be sent down from the GW to all devices "
|
||||
"specified in the --send-to-mac")
|
||||
parser.add_argument("-c", "--send-count", metavar="COUNT", type=int,
|
||||
help="how many messages will be sent (per mac address, default: 1)")
|
||||
parser.add_argument("-t", "--send-for", metavar="TIME", type=time,
|
||||
help="how long to send the messages for")
|
||||
parser.add_argument("-i", "--send-interval", metavar="INTERVAL", type=time, default="1",
|
||||
help="time between messages (default: \"1.0s\")")
|
||||
parser.add_argument("-p", "--send-to-group", metavar="GROUP-ID", type=str)
|
||||
parser.add_argument("-r", "--send-to-mac", metavar="MAC-RANGE", type=MacRange,
|
||||
help="range of mac addrs that will be receiving the messages")
|
||||
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
if parsed_args.send_message is not None and (
|
||||
parsed_args.send_to_group is None or
|
||||
parsed_args.send_to_mac is None
|
||||
):
|
||||
parser.error("--send-message requires --send-to-group and --send-to-mac")
|
||||
|
||||
message = None
|
||||
if parsed_args.send_message is not None:
|
||||
try:
|
||||
message = json.loads(parsed_args.send_message)
|
||||
except json.JSONDecodeError:
|
||||
parser.error("--send-message must be in JSON format")
|
||||
|
||||
count = 1
|
||||
if parsed_args.send_count is not None:
|
||||
count = parsed_args.send_count
|
||||
elif parsed_args.send_for is not None:
|
||||
count = 0
|
||||
|
||||
args = Args(
|
||||
[], [], [], [],
|
||||
topic=parsed_args.topic,
|
||||
db=parsed_args.bootstrap_server,
|
||||
message=message,
|
||||
count=count,
|
||||
time_to_send_s=parsed_args.send_for,
|
||||
interval_s=parsed_args.send_interval,
|
||||
group_id=parsed_args.send_to_group,
|
||||
send_to_macs=parsed_args.send_to_mac,
|
||||
)
|
||||
if parsed_args.new_group is not None:
|
||||
for group, shard, name in parsed_args.new_group:
|
||||
try:
|
||||
args.add_groups.append((group, int(shard), name))
|
||||
except ValueError:
|
||||
parser.error(f"--new-group: failed to parse shard id \"{shard}\"")
|
||||
if parsed_args.rm_group is not None:
|
||||
for (group,) in parsed_args.rm_group:
|
||||
args.del_groups.append(group)
|
||||
if parsed_args.assign_to_group is not None:
|
||||
try:
|
||||
for group, mac in parsed_args.assign_to_group:
|
||||
args.assign_to_group.append((group, MacRange(mac)))
|
||||
except ValueError:
|
||||
parser.error(f"--assign-to-group: failed to parse MAC range \"{mac}\"")
|
||||
if parsed_args.remove_from_group is not None:
|
||||
try:
|
||||
for group, mac in parsed_args.remove_from_group:
|
||||
args.remove_from_group.append((group, MacRange(mac)))
|
||||
except ValueError:
|
||||
parser.error(f"--remove-from-group: failed to parse MAC range \"{mac}\"")
|
||||
|
||||
return args
|
||||
37
utils/kafka_producer/src/log.py
Normal file
37
utils/kafka_producer/src/log.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import logging
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
|
||||
grey = "\x1b[38;20m"
|
||||
yellow = "\x1b[33;20m"
|
||||
red = "\x1b[31;20m"
|
||||
bold_red = "\x1b[31;1m"
|
||||
white = "\x1b[37;20m"
|
||||
cyan = "\x1b[36;20m"
|
||||
reset = "\x1b[0m"
|
||||
format = "{asctime}|{levelname}|{threadName}|{funcName}:{lineno}\t{message}"
|
||||
|
||||
FORMATS = {
|
||||
logging.DEBUG: grey + format + reset,
|
||||
logging.INFO: cyan + format + reset,
|
||||
logging.WARNING: yellow + format + reset,
|
||||
logging.ERROR: red + format + reset,
|
||||
logging.CRITICAL: bold_red + format + reset,
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
log_fmt = self.FORMATS.get(record.levelno)
|
||||
formatter = logging.Formatter(log_fmt, style="{")
|
||||
return formatter.format(record)
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, datefmt="%H:%M:%S")
|
||||
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(logging.NOTSET)
|
||||
console.setFormatter(ColoredFormatter())
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.propagate = False
|
||||
logger.addHandler(console)
|
||||
74
utils/kafka_producer/src/producer.py
Normal file
74
utils/kafka_producer/src/producer.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from .utils import Message, MacRange
|
||||
from .log import logger
|
||||
|
||||
from typing import List, Tuple
|
||||
import kafka
|
||||
import time
|
||||
import uuid
|
||||
import sys
|
||||
|
||||
|
||||
class Producer:
|
||||
def __init__(self, db: str, topic: str) -> None:
|
||||
self.db = db
|
||||
self.conn = None
|
||||
self.topic = topic
|
||||
self.message = Message()
|
||||
|
||||
def __enter__(self) -> kafka.KafkaProducer:
|
||||
return self.connect()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.disconnect()
|
||||
|
||||
def connect(self) -> kafka.KafkaProducer:
|
||||
if self.conn is None:
|
||||
self.conn = kafka.KafkaProducer(bootstrap_servers=self.db, client_id="producer")
|
||||
logger.info("connected to kafka")
|
||||
else:
|
||||
logger.info("already connected to kafka")
|
||||
return self.conn
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self.conn is None:
|
||||
return
|
||||
self.conn.close()
|
||||
logger.info("disconnected from kafka")
|
||||
self.conn = None
|
||||
|
||||
def handle_group_creation(self, create: List[Tuple[str, int, str]], delete: List[str]) -> None:
|
||||
with self as conn:
|
||||
for group, shard_id, name in create:
|
||||
conn.send(self.topic, self.message.group_create(group, shard_id, name),
|
||||
bytes(group, encoding="utf-8"))
|
||||
for group in delete:
|
||||
conn.send(self.topic, self.message.group_delete(group),
|
||||
bytes(group, encoding="utf-8"))
|
||||
|
||||
def handle_device_assignment(self, add: List[Tuple[str, MacRange]], remove: List[Tuple[str, MacRange]]) -> None:
|
||||
with self as conn:
|
||||
for group, mac_range in add:
|
||||
logger.debug(f"{group = }, {mac_range = }")
|
||||
conn.send(self.topic, self.message.add_dev_to_group(group, mac_range),
|
||||
bytes(group, encoding="utf-8"))
|
||||
for group, mac_range in remove:
|
||||
conn.send(self.topic, self.message.remove_dev_from_group(group, mac_range),
|
||||
bytes(group, encoding="utf-8"))
|
||||
|
||||
def handle_device_messages(self, message: dict, group: str, mac_range: MacRange,
|
||||
count: int, time_s: int, interval_s: int) -> None:
|
||||
if not time_s:
|
||||
end = sys.maxsize
|
||||
else:
|
||||
end = time.time() + time_s
|
||||
if not count:
|
||||
count = sys.maxsize
|
||||
|
||||
with self as conn:
|
||||
for seq in range(count):
|
||||
for mac in mac_range:
|
||||
conn.send(self.topic, self.message.to_device(group, mac, message, seq),
|
||||
bytes(group, encoding="utf-8"))
|
||||
#time.sleep(interval_s)
|
||||
#if time.time() > end:
|
||||
# break
|
||||
152
utils/kafka_producer/src/utils.py
Normal file
152
utils/kafka_producer/src/utils.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
from typing import Tuple
|
||||
import copy
|
||||
import json
|
||||
import uuid
|
||||
|
||||
|
||||
class MacRange:
|
||||
"""
|
||||
Return an object that produces a sequence of MAC addresses from
|
||||
START (inclusive) to END (inclusive). START and END are exctracted
|
||||
from the input string if it is in the format
|
||||
"11:22:AA:BB:00:00-11:22:AA:BB:00:05" (where START=11:22:AA:BB:00:00,
|
||||
END=11:22:AA:BB:00:05, and the total amount of MACs in the range is 6).
|
||||
|
||||
Examples (all of these are identical):
|
||||
|
||||
00:00:00:00:XX:XX
|
||||
00:00:00:00:00:00-00:00:00:00:FF:FF
|
||||
00:00:00:00:00:00^65536
|
||||
|
||||
Raises ValueError
|
||||
"""
|
||||
def __init__(self, input: str = "XX:XX:XX:XX:XX:XX") -> None:
|
||||
self.__base_as_num, self.__len = self.__parse_input(input.upper())
|
||||
self.__idx = 0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self) -> str:
|
||||
if self.__idx >= len(self):
|
||||
self.__idx = 0
|
||||
raise StopIteration()
|
||||
mac = self.num2mac(self.__base_as_num + self.__idx)
|
||||
self.__idx += 1
|
||||
return mac
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.__len
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"MacRange[start={self.base}, " \
|
||||
f"end={self.num2mac(self.__base_as_num + len(self) - 1)}]"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"MacRange('{self.base}^{len(self)}')"
|
||||
|
||||
@property
|
||||
def base(self) -> str:
|
||||
return self.num2mac(self.__base_as_num)
|
||||
|
||||
@staticmethod
|
||||
def mac2num(mac: str) -> int:
|
||||
return int(mac.replace(":", ""), base=16)
|
||||
|
||||
@staticmethod
|
||||
def num2mac(mac: int) -> str:
|
||||
hex = f"{mac:012X}"
|
||||
return ":".join([a+b for a, b in zip(hex[::2], hex[1::2])])
|
||||
|
||||
def __parse_input(self, input: str) -> Tuple[int, int]:
|
||||
if "X" in input:
|
||||
string = f"{input.replace('X', '0')}-{input.replace('X', 'F')}"
|
||||
else:
|
||||
string = input
|
||||
if "-" in string:
|
||||
start, end = string.split("-")
|
||||
start, end = self.mac2num(start), self.mac2num(end)
|
||||
if start > end:
|
||||
raise ValueError(f"Invalid MAC range {start}-{end}")
|
||||
return start, end - start + 1
|
||||
if "^" in string:
|
||||
base, count = string.split("^")
|
||||
return self.mac2num(base), int(count)
|
||||
return self.mac2num(input), 1
|
||||
|
||||
|
||||
class Message:
|
||||
TEMPLATE_FILE = "./data/message_template.json"
|
||||
GROUP_ADD = "add_group"
|
||||
GROUP_DEL = "del_group"
|
||||
DEV_TO_GROUP = "add_to_group"
|
||||
DEV_FROM_GROUP = "del_from_group"
|
||||
TO_DEVICE = "message_device"
|
||||
GROUP_ID = "infra_group_id"
|
||||
GROUP_NAME = "infra_name"
|
||||
SHARD_ID = "infra_shard_id"
|
||||
DEV_LIST = "infra_group_infra_devices"
|
||||
MAC = "mac"
|
||||
DATA = "msg"
|
||||
MSG_UUID = "uuid"
|
||||
|
||||
def __init__(self) -> None:
|
||||
with open(self.TEMPLATE_FILE) as f:
|
||||
self.templates = json.loads(f.read())
|
||||
|
||||
def group_create(self, id: str, shard_id: int, name: str) -> bytes:
|
||||
msg = copy.copy(self.templates[self.GROUP_ADD])
|
||||
msg[self.GROUP_ID] = id
|
||||
msg[self.SHARD_ID] = shard_id
|
||||
msg[self.GROUP_NAME] = name
|
||||
msg[self.MSG_UUID] = str(uuid.uuid1())
|
||||
return json.dumps(msg).encode('utf-8')
|
||||
|
||||
def group_delete(self, id: str) -> bytes:
|
||||
msg = copy.copy(self.templates[self.GROUP_DEL])
|
||||
msg[self.GROUP_ID] = id
|
||||
msg[self.MSG_UUID] = str(uuid.uuid1())
|
||||
return json.dumps(msg).encode('utf-8')
|
||||
|
||||
def add_dev_to_group(self, id: str, mac_range: MacRange) -> bytes:
|
||||
msg = copy.copy(self.templates[self.DEV_TO_GROUP])
|
||||
msg[self.GROUP_ID] = id
|
||||
msg[self.DEV_LIST] = list(mac_range)
|
||||
msg[self.MSG_UUID] = str(uuid.uuid1())
|
||||
return json.dumps(msg).encode('utf-8')
|
||||
|
||||
def remove_dev_from_group(self, id: str, mac_range: MacRange) -> bytes:
|
||||
msg = copy.copy(self.templates[self.DEV_FROM_GROUP])
|
||||
msg[self.GROUP_ID] = id
|
||||
msg[self.DEV_LIST] = list(mac_range)
|
||||
msg[self.MSG_UUID] = str(uuid.uuid1())
|
||||
return json.dumps(msg).encode('utf-8')
|
||||
|
||||
def to_device(self, id: str, mac: str, data, sequence: int = 0):
|
||||
msg = copy.copy(self.templates[self.TO_DEVICE])
|
||||
msg[self.GROUP_ID] = id
|
||||
msg[self.MAC] = mac
|
||||
if type(data) is dict:
|
||||
msg[self.DATA] = data
|
||||
else:
|
||||
msg[self.DATA] = {"data": data}
|
||||
msg[self.MSG_UUID] = str(uuid.uuid1(node=MacRange.mac2num(mac), clock_seq=sequence))
|
||||
return json.dumps(msg).encode('utf-8')
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
add_groups: List[Tuple[str, int, str]]
|
||||
del_groups: List[str]
|
||||
assign_to_group: List[Tuple[str, MacRange]]
|
||||
remove_from_group: List[Tuple[str, MacRange]]
|
||||
topic: str
|
||||
db: str
|
||||
message: dict
|
||||
count: int
|
||||
time_to_send_s: float
|
||||
interval_s: float
|
||||
group_id: int
|
||||
send_to_macs: MacRange
|
||||
Reference in New Issue
Block a user