cgw: initial commit

Signed-off-by: Paul White <paul@shasta.cloud>
This commit is contained in:
Paul White
2024-03-17 20:33:16 -07:00
parent aa54d00765
commit 6a0344aa87
38 changed files with 4083 additions and 0 deletions

36
Cargo.toml Normal file
View File

@@ -0,0 +1,36 @@
# Manifest of the `ucentral-cgw` crate.
[package]
name = "ucentral-cgw"
version = "0.0.1"
edition = "2021"
[dependencies]
serde = { version = "1.0.144", features = ["derive"] }
serde_json = "1.0.85"
env_logger = "0.5.0"
log = "0.4.20"
tokio = { version = "1.34.0", features = ["full"] }
# NOTE(review): the "*" requirements below accept any published version, so
# the manifest itself does not pin these crates - consider explicit versions.
tokio-stream = { version = "*", features = ["full"] }
tokio-tungstenite = { version = "*", features = ["native-tls"] }
tokio-native-tls = "*"
tokio-rustls = "*"
tokio-postgres = { version = "0.7.10", features = ["with-eui48-1"]}
tokio-pg-mapper = "0.2.0"
# NOTE(review): wildcard version, see the note above.
tungstenite = { version = "*"}
mio = "0.6.10"
# NOTE(review): wildcard version, see the note above.
native-tls = "*"
futures-util = { version = "0.3.0", default-features = false }
futures-channel = "0.3.0"
futures-executor = { version = "0.3.0", optional = true }
futures = "0.3.0"
rlimit = "0.10.1"
tonic = "0.10.2"
prost = "0.12"
rdkafka = "0.35.0"
clap = { version = "4.4.8", features = ["derive"] }
eui48 = "1.1.0"
uuid = { version = "1.6.1", features = ["serde"] }
redis-async = "0.16.1"
# Crates used only by build.rs (protobuf code generation).
[build-dependencies]
tonic-build = "0.10"
prost-build = "0.12"

1
README.md Normal file
View File

@@ -0,0 +1 @@
# openlan-cgw

4
build.rs Normal file
View File

@@ -0,0 +1,4 @@
/// Build-script entry point: hands `src/proto/cgw.proto` to
/// `tonic_build::compile_protos` and forwards any error it reports, which
/// makes the build fail.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let proto_path = "src/proto/cgw.proto";
    match tonic_build::compile_protos(proto_path) {
        Ok(()) => Ok(()),
        Err(err) => Err(err.into()),
    }
}

View File

@@ -0,0 +1,531 @@
use crate::cgw_connection_server::{
CGWConnectionServer,
CGWConnectionServerReqMsg,
};
use tokio::{
sync::{
mpsc::{
UnboundedReceiver,
unbounded_channel,
},
},
net::{
TcpStream,
},
time::{
sleep,
Duration,
Instant,
},
};
use tokio_native_tls::TlsStream;
use tokio_tungstenite::{
WebSocketStream,
tungstenite::{
protocol::{
Message,
},
},
};
use tungstenite::Message::{
Close,
Text,
Ping,
};
use futures_util::{
StreamExt,
SinkExt,
FutureExt,
stream::{
SplitSink,
SplitStream
},
};
use std::{
net::SocketAddr,
sync::{
Arc,
},
};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
type CGWUcentralJRPCMessage = Map<String, Value>;
type SStream = SplitStream<WebSocketStream<TlsStream<TcpStream>>>;
type SSink = SplitSink<WebSocketStream<TlsStream<TcpStream>>, Message>;
/// Requests delivered to a connection processor through its mailbox
/// (the receiving half of the channel created in `CGWConnectionProcessor::start`).
#[derive(Debug)]
pub enum CGWConnectionProcessorReqMsg {
// We got green light from server to process this connection on
AddNewConnectionAck,
// Handled in `process_sink_mbox_rx_msg`: makes the processor leave its
// processing loop with the `IsForcedToClose` state.
AddNewConnectionShouldClose,
// (mac, payload): the payload is written to the device as a websocket
// text message; the mac is only used for logging by the processor.
SinkRequestToDevice(String, String),
}
/// Outcome of handling a single wakeup in the connection processing loop;
/// `process_connection` uses it to decide whether to keep looping.
#[derive(Debug)]
enum CGWConnectionState {
// Keep processing the connection.
IsActive,
// The server asked this processor to stop (AddNewConnectionShouldClose).
IsForcedToClose,
// NOTE(review): not produced by any code visible in this part of the file.
IsDead,
// No websocket activity for longer than the staleness timeout.
IsStale,
// The remote peer sent a websocket Close frame.
ClosedGracefully,
}
/// `params` of a device `log` request, filled in by `cgw_parse_jrpc_event`.
#[derive(Deserialize, Serialize, Debug, Default)]
struct CGWEventLogParams {
serial: String,
log: String,
severity: i64,
}
/// A device `log` request.
#[derive(Deserialize, Serialize, Debug, Default)]
struct CGWEventLog {
params: CGWEventLogParams,
}
/// `params.capabilities` of a device `connect` request.
#[derive(Deserialize, Serialize, Debug, Default)]
struct CGWEventConnectParamsCaps {
compatible: String,
model: String,
platform: String,
label_macaddr: String,
}
/// `params` of a device `connect` request.
#[derive(Deserialize, Serialize, Debug, Default)]
struct CGWEventConnectParams {
serial: String,
firmware: String,
// NOTE(review): cgw_parse_jrpc_event currently always stores the constant 1
// here rather than a value taken from the message.
uuid: u64,
capabilities: CGWEventConnectParamsCaps,
}
/// A device `connect` request.
#[derive(Deserialize, Serialize, Debug, Default)]
struct CGWEventConnect {
params: CGWEventConnectParams,
}
/// Typed view of a JSON-RPC request received from a device.
#[derive(Deserialize, Serialize, Debug)]
enum CGWEvent {
Connect(CGWEventConnect),
Log(CGWEventLog),
// Returned by cgw_parse_jrpc_event for every method other than
// `connect` / `log`.
Empty,
}
fn cgw_parse_jrpc_event(map: &Map<String, Value>, method: String) -> CGWEvent {
if method == "log" {
let params = map.get("params").expect("Params are missing");
return CGWEvent::Log(CGWEventLog {
params: CGWEventLogParams {
serial: params["serial"].to_string(),
log: params["log"].to_string(),
severity: serde_json::from_value(params["severity"].clone()).unwrap(),
},
});
} else if method == "connect" {
let params = map.get("params").expect("Params are missing");
return CGWEvent::Connect(CGWEventConnect {
params: CGWEventConnectParams {
serial: params["serial"].to_string(),
firmware: params["firmware"].to_string(),
uuid: 1,
capabilities: CGWEventConnectParamsCaps {
compatible: params["capabilities"]["compatible"].to_string(),
model: params["capabilities"]["model"].to_string(),
platform: params["capabilities"]["platform"].to_string(),
label_macaddr: params["capabilities"]["label_macaddr"].to_string(),
},
},
});
}
CGWEvent::Empty
}
/// Placeholder for per-event processing; currently accepts every event and
/// always returns `Ok(())`.
async fn cgw_process_jrpc_event(event: &CGWEvent) -> Result<(), String> {
    match event {
        CGWEvent::Connect(_connect) => {
            // TODO: react to a connect event. An earlier sketch built a
            // JSON-RPC "reboot" request for the connecting device and sent it
            // back; that requires access to the websocket sink, which this
            // function does not have.
        }
        CGWEvent::Log(_) | CGWEvent::Empty => {}
    }
    Ok(())
}
// TODO: heavy rework to enum-variant struct-based
/// Validates a websocket message as a uCentral JSON-RPC message and returns
/// its top-level JSON object.
///
/// Returns `Err` with a short description when:
/// * the message cannot be converted to text or is not a JSON object,
/// * the `jsonrpc` member is missing,
/// * a `method` is present without `params`, is not a JSON string, or is
///   not one of the implemented methods,
/// * the message is a `result` (handling is not implemented yet).
///
/// A message that has `jsonrpc` but neither `method` nor `result` is
/// returned as-is.
async fn cgw_process_jrpc_message(message: Message) -> Result<CGWUcentralJRPCMessage, String> {
    let msg = match message.into_text() {
        Ok(text) => text,
        Err(_) => return Err("Message to string cast failed".to_string()),
    };

    let map: CGWUcentralJRPCMessage = match serde_json::from_str(&msg) {
        Ok(m) => m,
        Err(e) => {
            error!("Failed to parse input json {e}");
            return Err("Failed to parse input json".to_string());
        }
    };

    if !map.contains_key("jsonrpc") {
        warn!("Received malformed JSONRPC msg");
        return Err("JSONRPC field is missing in message".to_string());
    }

    if let Some(method_value) = map.get("method") {
        if !map.contains_key("params") {
            warn!("Received JRPC <method> without params.");
            return Err("Received JRPC <method> without params".to_string());
        }

        // The message comes from a remote device: a non-string `method`
        // must be rejected instead of panicking the connection task.
        let method = match method_value.as_str() {
            Some(name) => name,
            None => {
                warn!("Received JRPC <method> that is not a string");
                return Err("Received JRPC <method> that is not a string".to_string());
            }
        };

        let event: CGWEvent = cgw_parse_jrpc_event(&map, method.to_string());

        match &event {
            CGWEvent::Log(l) => {
                debug!(
                    "Received LOG evt from device {}: {}",
                    l.params.serial, l.params.log
                );
            }
            CGWEvent::Connect(c) => {
                debug!(
                    "Received connect evt from device {}: type {}, fw {}",
                    c.params.serial, c.params.capabilities.platform, c.params.firmware
                );
            }
            _ => {
                warn!("received not yet implemented method {}", method);
                return Err(format!("received not yet implemented method {}", method));
            }
        };

        if let Err(e) = cgw_process_jrpc_event(&event).await {
            warn!("Failed to process jrpc event (unmatched) {}", method);
            return Err(e);
        }
        // TODO
    } else if map.contains_key("result") {
        info!("Processing <result> JSONRPC msg");
        info!("{:?}", map);
        return Err("Result handling is not yet implemented".to_string());
    }

    Ok(map)
}
/// Per-connection worker: owns one device websocket connection and exchanges
/// messages with the shared `CGWConnectionServer`.
pub struct CGWConnectionProcessor {
// Server this processor registers with and reports to.
cgw_server: Arc<CGWConnectionServer>,
// Device serial; `None` until the first message of the connection has been
// parsed in `start`.
pub serial: Option<String>,
// Remote socket address, used in log messages.
pub addr: SocketAddr,
// Connection index supplied by the creator of the processor.
pub idx: i64,
}
impl CGWConnectionProcessor {
pub fn new(server: Arc<CGWConnectionServer>, conn_idx: i64, addr: SocketAddr) -> Self {
let conn_processor : CGWConnectionProcessor = CGWConnectionProcessor {
cgw_server: server,
serial: None,
addr: addr,
idx: conn_idx,
};
conn_processor
}
pub async fn start(mut self, tls_stream: TlsStream<TcpStream>) {
let ws_stream = tokio_tungstenite::accept_async(tls_stream)
.await
.expect("error during the websocket handshake occurred");
let (sink, mut stream) = ws_stream.split();
// check if we have any pending msgs (we expect connect at this point, protocol-wise)
// TODO: rework to ignore any WS-related frames untill we get a connect message,
// however there's a caveat: we can miss some events logs etc from underlying device
// rework should consider all the options
let msg = tokio::select! {
_val = stream.next() => {
match _val {
Some(m) => m,
None => {
error!("no connect message received from {}, closing connection", self.addr);
return;
}
}
}
// TODO: configurable duration (upon server creation)
_val = sleep(Duration::from_millis(30000)) => {
error!("no message received from {}, closing connection", self.addr);
return;
}
};
// we have a next() result, but it still may be undelying io error: check for it
// break connection if we can't work with underlying ws connection (pror err etc)
let message = match msg {
Ok(m) => m,
Err(e) => {
error!("established connection with device, but failed to receive any messages\n{e}");
return;
}
};
let map = match cgw_process_jrpc_message(message).await {
Ok(val) => val,
Err(e) => {
error!("failed to recv connect message from {}, closing connection", self.addr);
return;
}
};
let serial = map["params"]["serial"].as_str().unwrap();
self.serial = Some(serial.to_string());
// TODO: we accepted tls stream and split the WS into RX TX part,
// now we have to ASK cgw_connection_server's permission whether
// we can proceed on with this underlying connection.
// cgw_connection_server has an authorative decision whether
// we can proceed.
let (mbox_tx, mut mbox_rx) = unbounded_channel::<CGWConnectionProcessorReqMsg>();
let msg = CGWConnectionServerReqMsg::AddNewConnection(serial.to_string(), mbox_tx);
self.cgw_server.enqueue_mbox_message_to_cgw_server(msg).await;
let ack = mbox_rx.recv().await;
if let Some(m) = ack {
match m {
CGWConnectionProcessorReqMsg::AddNewConnectionAck => {
debug!("websocket connection established: {} {}", self.addr, serial);
},
_ => panic!("Unexpected response from server, expected ACK/NOT ACK)"),
}
} else {
info!("connection server declined connection, websocket connection {} {} cannot be established",
self.addr, serial);
return;
}
self.process_connection(stream, sink, mbox_rx).await;
}
async fn process_wss_rx_msg(&self, msg: Result<Message, tungstenite::error::Error>) -> Result<CGWConnectionState, &'static str> {
match msg {
Ok(msg) => {
match msg {
Close(_t) => {
return Ok(CGWConnectionState::ClosedGracefully);
},
Text(payload) => {
self.cgw_server.enqueue_mbox_message_from_device_to_nb_api_c(self.serial.clone().unwrap(), payload);
return Ok(CGWConnectionState::IsActive);
},
Ping(_t) => {
return Ok(CGWConnectionState::IsActive);
},
_ => {
}
}
}
Err(e) => {
match e {
tungstenite::error::Error::AlreadyClosed => {
return Err("Underlying connection's been closed");
},
_ => {
}
}
}
}
Ok(CGWConnectionState::IsActive)
}
async fn process_sink_mbox_rx_msg(&self,sink: &mut SSink, val: Option<CGWConnectionProcessorReqMsg>) -> Result<CGWConnectionState, &str> {
if let Some(msg) = val {
let processor_mac = self.serial.clone().unwrap();
match msg {
CGWConnectionProcessorReqMsg::AddNewConnectionShouldClose => {
debug!("MBOX_IN: AddNewConnectionShouldClose, processor (mac:{processor_mac}) (ACK OK)");
return Ok(CGWConnectionState::IsForcedToClose);
},
CGWConnectionProcessorReqMsg::SinkRequestToDevice(mac, pload) => {
debug!("MBOX_IN: SinkRequestToDevice, processor (mac:{processor_mac}) req for (mac:{mac}) payload:{pload}");
sink.send(Message::text(pload)).await.ok();
},
_ => panic!("Unexpected message received {:?}", msg),
}
}
Ok(CGWConnectionState::IsActive)
}
/// Staleness check: reports `IsStale` (and logs a warning) when more than
/// 70 seconds have passed since `last_contact`, `IsActive` otherwise.
async fn process_stale_connection_msg(&self, last_contact: Instant) -> Result<CGWConnectionState, &str> {
    // TODO: configurable duration (upon server creation)
    let stale_after = Duration::from_secs(70);
    let idle_for = Instant::now().duration_since(last_contact);

    if idle_for <= stale_after {
        return Ok(CGWConnectionState::IsActive);
    }

    warn!("Closing connection {} (idle for too long, stale)", self.addr);
    Ok(CGWConnectionState::IsStale)
}
// Main loop of an established connection.
//
// Each iteration picks at most one pending item - either from the device
// websocket stream or from this processor's mailbox - handles it, and then
// decides from the resulting CGWConnectionState whether to keep looping.
// When neither source has anything ready, the loop sleeps for one second and
// runs the staleness check instead.
//
// The loop is left in two ways:
// - `return` on IsForcedToClose: the server is not notified;
// - `break` on a handler error, a graceful close or a stale connection:
//   a ConnectionClosed message is sent to the server afterwards.
//
// Panics if `self.serial` is None when the connection is closed or logged.
async fn process_connection(
self,
mut stream: SStream,
mut sink: SSink,
mut mbox_rx: UnboundedReceiver<CGWConnectionProcessorReqMsg>) {
// Why the current loop iteration has something to do.
#[derive(Debug)]
enum WakeupReason {
Unspecified,
WSSRxMsg(Result<Message, tungstenite::error::Error>),
MboxRx(Option<CGWConnectionProcessorReqMsg>),
Stale,
}
let mut last_contact = Instant::now();
let mut poll_wss_first = true;
// Get underlying wakeup reason and do initial parsing, like:
// - check if WSS stream has a message or an (recv) error
// - check if sinkmbox has a message or an (recv) error
// - check if connection's been stale for X time
//
// TODO: try_next instead of sync .next? could potentially
// skyrocket CPU usage.
loop {
let mut wakeup_reason: WakeupReason = WakeupReason::Unspecified;
// TODO: refactor
// Round-robin selection of stream to process:
// first, get single message from WSS, then get a single msg from RX MBOX
// It's done to ensure we process WSS and RX MBOX equally with same priority
// Also, we have to make sure we don't sleep-wait for any of the streams to
// make sure we don't cancel futures that are used for stream processing,
// especially TCP stream, which is not cancel-safe
//
// `now_or_never()` polls the future once and yields None when it is not
// immediately ready, so neither source is ever awaited here.
if poll_wss_first {
if let Some(val) = stream.next().now_or_never() {
if let Some(res) = val {
if let Ok(msg) = res {
wakeup_reason = WakeupReason::WSSRxMsg(Ok(msg));
} else if let Err(msg) = res {
wakeup_reason = WakeupReason::WSSRxMsg(Result::Err(msg));
}
} else if let None = val {
// End of the websocket stream is reported as an AlreadyClosed error.
wakeup_reason = WakeupReason::WSSRxMsg(Result::Err(tungstenite::error::Error::AlreadyClosed));
}
} else if let Some(val) = mbox_rx.recv().now_or_never() {
wakeup_reason = WakeupReason::MboxRx(val)
}
poll_wss_first = !poll_wss_first;
} else {
// Same as above with the polling order swapped: mailbox first.
if let Some(val) = mbox_rx.recv().now_or_never() {
wakeup_reason = WakeupReason::MboxRx(val)
} else if let Some(val) = stream.next().now_or_never() {
if let Some(res) = val {
if let Ok(msg) = res {
wakeup_reason = WakeupReason::WSSRxMsg(Ok(msg));
} else if let Err(msg) = res {
wakeup_reason = WakeupReason::WSSRxMsg(Result::Err(msg));
}
} else if let None = val {
wakeup_reason = WakeupReason::WSSRxMsg(Result::Err(tungstenite::error::Error::AlreadyClosed));
}
}
poll_wss_first = !poll_wss_first;
}
// TODO: somehow workaround the sleeping?
// Both WSS and RX MBOX are empty: chill for a while
// NOTE(review): a mailbox whose sending half was dropped presumably makes
// `recv()` resolve immediately with None, which is handled as IsActive; in
// that case this sleep (and the staleness check) looks unreachable and the
// loop would spin - confirm against the server's handling of senders.
if let WakeupReason::Unspecified = wakeup_reason {
sleep(Duration::from_millis(1000)).await;
wakeup_reason = WakeupReason::Stale;
}
let rc = match wakeup_reason {
WakeupReason::WSSRxMsg(res) => {
// Any websocket wakeup (including a receive error) counts as contact.
last_contact = Instant::now();
self.process_wss_rx_msg(res).await
},
WakeupReason::MboxRx(mbox_message) => {
self.process_sink_mbox_rx_msg(&mut sink, mbox_message).await
},
WakeupReason::Stale => {
self.process_stale_connection_msg(last_contact.clone()).await
},
_ => {
// Unspecified was replaced by Stale above, so this arm is not expected
// to be reached.
panic!("Failed to get wakeup reason for {} conn", self.addr);
},
};
match rc {
Err(e) => {
warn!("{}", e);
break;
},
Ok(state) => {
// A state not listed below (IsDead) falls through and the loop goes on.
if let CGWConnectionState::IsActive = state {
continue;
} else if let CGWConnectionState::IsForcedToClose = state {
// Return, because server already closed our mbox tx counterpart (rx),
// hence we don't need to send ConnectionClosed message. Server
// already knows we're closed.
return;
} else if let CGWConnectionState::ClosedGracefully = state {
warn!("Remote client {} closed connection gracefully", self.serial.clone().unwrap());
break;
} else if let CGWConnectionState::IsStale = state {
warn!("Remote client {} closed due to inactivity", self.serial.clone().unwrap());
break;
}
},
}
}
// Reached only via `break`: tell the server this connection is gone.
let mac = self.serial.clone().unwrap();
let msg = CGWConnectionServerReqMsg::ConnectionClosed(self.serial.unwrap());
self.cgw_server.enqueue_mbox_message_to_cgw_server(msg).await;
debug!("MBOX_OUT: ConnectionClosed, processor (mac:{})", mac);
}
}

View File

@@ -0,0 +1,878 @@
use crate::{
AppArgs,
};
use crate::cgw_nb_api_listener::{
CGWNBApiClient,
};
use crate::cgw_connection_processor::{
CGWConnectionProcessor,
CGWConnectionProcessorReqMsg,
};
use crate::cgw_db_accessor::{
CGWDBInfrastructureGroup,
};
use crate::cgw_remote_discovery::{
CGWRemoteDiscovery,
};
use tokio::{
sync::{
RwLock,
mpsc::{
UnboundedSender,
UnboundedReceiver,
unbounded_channel,
},
},
time::{
sleep,
Duration,
},
net::{
TcpStream,
},
runtime::{
Builder,
Runtime,
},
};
use std::{
collections::HashMap,
net::SocketAddr,
sync::{
Arc,
},
};
use std::sync::atomic::{
AtomicUsize,
Ordering,
};
use serde_json::{
Map,
Value,
};
use serde::{
Deserialize,
Serialize,
};
use uuid::{
Uuid,
};
type DeviceSerial = String;
type CGWConnmapType = Arc<RwLock<HashMap<String, UnboundedSender<CGWConnectionProcessorReqMsg>>>>;
/// Owner of the shared map from a `String` key to the mailbox sender of a
/// connection processor.
#[derive(Debug)]
struct CGWConnMap {
    map: CGWConnmapType,
}

impl CGWConnMap {
    /// Creates an empty connection map.
    pub fn new() -> Self {
        Self {
            map: Arc::new(RwLock::new(HashMap::new())),
        }
    }
}
type CGWConnectionServerMboxRx = UnboundedReceiver<CGWConnectionServerReqMsg>;
type CGWConnectionServerMboxTx = UnboundedSender<CGWConnectionServerReqMsg>;
type CGWConnectionServerNBAPIMboxTx = UnboundedSender<CGWConnectionNBAPIReqMsg>;
type CGWConnectionServerNBAPIMboxRx = UnboundedReceiver<CGWConnectionNBAPIReqMsg>;
// The following pair used internally by server itself to bind
// Processor's Req/Res
/// Requests sent by connection processors to the server's internal mailbox.
#[derive(Debug)]
pub enum CGWConnectionServerReqMsg {
// Connection-related messages
// (device serial, sender through which the server talks back to the
// processor that issued the request)
AddNewConnection(DeviceSerial, UnboundedSender<CGWConnectionProcessorReqMsg>),
// (device serial) - the processor finished handling this connection
ConnectionClosed(DeviceSerial),
}
/// Where an NB API request entered this CGW instance. Requests marked
/// `FromRemoteCGW` were already relayed once and are not relayed again.
#[derive(Debug)]
pub enum CGWConnectionNBAPIReqMsgOrigin {
FromNBAPI,
FromRemoteCGW,
}
/// Requests placed into the server's NB API mailbox.
#[derive(Debug)]
pub enum CGWConnectionNBAPIReqMsg {
// Enqueue (key, request payload, origin of the request)
EnqueueNewMessageFromNBAPIListener(String, String, CGWConnectionNBAPIReqMsgOrigin),
}
/// Central server object shared (behind `Arc`) by connection processors and
/// by the mailbox-processing tasks spawned in `CGWConnectionServer::new`.
pub struct CGWConnectionServer {
// Identifier of this CGW instance; compared against the owner id of an
// infrastructure group to decide between local handling and relaying.
local_cgw_id: i32,
// CGWConnectionServer writes into this mailbox,
// and the corresponding Server task reads the RX counterpart
mbox_internal_tx: CGWConnectionServerMboxTx,
// Object that owns underlying mac:connection map
connmap: CGWConnMap,
// Runtime that schedules all the WSS-messages related tasks
wss_rx_tx_runtime: Arc<Runtime>,
// Dedicated runtime (threadpool) for handling internal mbox:
// ACK/nACK connection, handle duplicates (clone/open) etc.
mbox_internal_runtime_handle: Arc<Runtime>,
// Dedicated runtime (threadpool) for handling NB-API mbox:
// RX NB-API requests, parse, relay (if needed)
mbox_nb_api_runtime_handle: Arc<Runtime>,
// Dedicated runtime (threadpool) for handling NB-API TX routine:
// TX NB-API requests (if async send is needed)
mbox_nb_api_tx_runtime_handle: Arc<Runtime>,
// Dedicated runtime (threadpool) for handling (relaying) msgs:
// relay-task is spawned inside it, and the produced stream of
// remote-cgw messages is being relayed inside this context
mbox_relay_msg_runtime_handle: Arc<Runtime>,
// CGWConnectionServer writes into this mailbox,
// and the corresponding NB API client is responsible for doing an RX over
// receive handle counterpart
nb_api_client: Arc<CGWNBApiClient>,
// Interface used to access all discovered CGW instances
// (used for relaying non-local CGW requests from NB-API to target CGW)
cgw_remote_discovery: Arc<CGWRemoteDiscovery>,
// Handler that helps this object to wrap relayed NB-API messages
// dedicated for this particular local CGW instance
mbox_relayed_messages_handle: CGWConnectionServerNBAPIMboxTx,
}
/*
* TODO: split into base struct + enum type field
* this requires a lot of refactoring in places where msg is used
* this is needed to always have uuid of malformed / discarded msg.
* e.g.
* struct CGWNBApiParsedMsg {
* uuid: Uuid,
* gid: i32,
* type: enum CGWNBApiParsedMsgType,
* }
*
* enum CGWNBApiParsedMsgType {
* InfrastructureGroupCreate,
* ...
* InfrastructureGroupInfraMsg(DeviceSerial, String),
* }
*/
// Typed form of an NB API request, produced by `parse_nbapi_msg`.
// In every variant the first field is the request `uuid` and the second one
// is the numeric infrastructure group id.
enum CGWNBApiParsedMsg {
// TODO: fix kafka_simulator to provide reserved_size
InfrastructureGroupCreate(Uuid, i32),
InfrastructureGroupDelete(Uuid, i32),
// Third field: device list taken from `infra_group_infra_devices`.
InfrastructureGroupInfraAdd(Uuid, i32, Vec<DeviceSerial>),
InfrastructureGroupInfraDel(Uuid, i32, Vec<DeviceSerial>),
// Third field: target device (`mac`); fourth: the `msg` object re-serialized
// to a JSON string.
InfrastructureGroupInfraMsg(Uuid, i32, DeviceSerial, String),
}
impl CGWConnectionServer {
pub async fn new(app_args: &AppArgs) -> Arc<Self> {
let wss_runtime_handle = Arc::new(Builder::new_multi_thread()
.worker_threads(app_args.wss_t_num)
.thread_name_fn(|| {
static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
format!("cgw-wss-t-{}", id)
})
.thread_stack_size(3 * 1024 * 1024)
.enable_all()
.build()
.unwrap());
let internal_mbox_runtime_handle = Arc::new(Builder::new_multi_thread()
.worker_threads(1)
.thread_name("cgw-mbox")
.thread_stack_size(1 * 1024 * 1024)
.enable_all()
.build()
.unwrap());
let nb_api_mbox_runtime_handle = Arc::new(Builder::new_multi_thread()
.worker_threads(1)
.thread_name("cgw-mbox-nbapi")
.thread_stack_size(1 * 1024 * 1024)
.enable_all()
.build()
.unwrap());
let relay_msg_mbox_runtime_handle = Arc::new(Builder::new_multi_thread()
.worker_threads(1)
.thread_name("cgw-relay-mbox-nbapi")
.thread_stack_size(1 * 1024 * 1024)
.enable_all()
.build()
.unwrap());
let nb_api_mbox_tx_runtime_handle = Arc::new(Builder::new_multi_thread()
.worker_threads(1)
.thread_name("cgw-mbox-nbapi-tx")
.thread_stack_size(1 * 1024 * 1024)
.enable_all()
.build()
.unwrap());
let (internal_tx, internal_rx) = unbounded_channel::<CGWConnectionServerReqMsg>();
let (nb_api_tx, nb_api_rx) = unbounded_channel::<CGWConnectionNBAPIReqMsg>();
// Give NB API client a handle where it can do a TX (CLIENT -> CGW_SERVER)
// RX is handled in internal_mbox of CGW_Server
let nb_api_c = CGWNBApiClient::new(app_args, &nb_api_tx);
let server = Arc::new(CGWConnectionServer {
local_cgw_id: app_args.cgw_id,
connmap: CGWConnMap::new(),
wss_rx_tx_runtime: wss_runtime_handle,
mbox_internal_runtime_handle: internal_mbox_runtime_handle,
mbox_nb_api_runtime_handle: nb_api_mbox_runtime_handle,
mbox_nb_api_tx_runtime_handle: nb_api_mbox_tx_runtime_handle,
mbox_internal_tx: internal_tx,
nb_api_client: nb_api_c,
cgw_remote_discovery: Arc::new(CGWRemoteDiscovery::new(app_args).await),
mbox_relayed_messages_handle: nb_api_tx,
mbox_relay_msg_runtime_handle: relay_msg_mbox_runtime_handle,
});
let server_clone = server.clone();
// Task for processing mbox_internal_rx, task owns the RX part
server.mbox_internal_runtime_handle.spawn(async move {
server_clone.process_internal_mbox(internal_rx).await;
});
let server_clone = server.clone();
server.mbox_nb_api_runtime_handle.spawn(async move {
server_clone.process_internal_nb_api_mbox(nb_api_rx).await;
});
server
}
/// Puts `req` into the server's internal mailbox. A failed send is ignored.
pub async fn enqueue_mbox_message_to_cgw_server(&self, req: CGWConnectionServerReqMsg) {
    self.mbox_internal_tx.send(req).ok();
}
/// Forwards a message received from a device to the NB API client.
///
/// The hand-over runs in a freshly spawned task and its result is ignored.
/// `mac` is currently unused: every message is enqueued under the
/// placeholder key `TBD_INFRA_GROUP`.
pub fn enqueue_mbox_message_from_device_to_nb_api_c(&self, mac: DeviceSerial, req: String) {
    let client = self.nb_api_client.clone();
    // TODO: device (mac) -> group id matching
    let key = "TBD_INFRA_GROUP".to_string();
    tokio::spawn(async move {
        let _ = client.enqueue_mbox_message_from_cgw_server(key, req).await;
    });
}
/// Sends `req` to the NB API client, keyed by the decimal form of `gid`.
///
/// The hand-over is spawned on the dedicated NB API TX runtime and its
/// result is ignored.
pub fn enqueue_mbox_message_from_cgw_to_nb_api(&self, gid: i32, req: String) {
    let client = self.nb_api_client.clone();
    let tx_runtime = &self.mbox_nb_api_tx_runtime_handle;
    tx_runtime.spawn(async move {
        let key = gid.to_string();
        let _ = client.enqueue_mbox_message_from_cgw_server(key, req).await;
    });
}
/// Places a request relayed by a remote CGW into this server's NB API
/// mailbox, tagged with the `FromRemoteCGW` origin. A failed send is ignored.
pub async fn enqueue_mbox_relayed_message_to_cgw_server(&self, key: String, req: String) {
    let origin = CGWConnectionNBAPIReqMsgOrigin::FromRemoteCGW;
    let relayed = CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, req, origin);
    self.mbox_relayed_messages_handle.send(relayed).ok();
}
fn parse_nbapi_msg(&self, pload: &String) -> Option<CGWNBApiParsedMsg> {
#[derive(Debug, Serialize, Deserialize)]
struct InfraGroupCreate {
r#type: String,
infra_group_id: String,
infra_name: String,
infra_shard_id: i32,
uuid: Uuid,
}
#[derive(Debug, Serialize, Deserialize)]
struct InfraGroupDelete {
r#type: String,
infra_group_id: String,
uuid: Uuid,
}
#[derive(Debug, Serialize, Deserialize)]
struct InfraGroupInfraAdd {
r#type: String,
infra_group_id: String,
infra_group_infra_devices: Vec<String>,
uuid: Uuid,
}
#[derive(Debug, Serialize, Deserialize)]
struct InfraGroupInfraDel {
r#type: String,
infra_group_id: String,
infra_group_infra_devices: Vec<String>,
uuid: Uuid,
}
#[derive(Debug, Serialize, Deserialize)]
struct InfraGroupMsgJSON {
r#type: String,
infra_group_id: String,
mac: String,
msg: Map<String, Value>,
uuid: Uuid,
}
let rc = serde_json::from_str(pload);
if let Err(e) = rc {
error!("{e}\n{pload}");
return None;
}
let map: Map<String, Value> = rc.unwrap();
let rc = map.get(&String::from("type"));
if let None = rc {
error!("No msg_type found in\n{pload}");
return None;
}
let rc = rc.unwrap();
let msg_type = rc.as_str().unwrap();
let rc = map.get(&String::from("infra_group_id"));
if let None = rc {
error!("No infra_group_id found in\n{pload}");
return None;
}
let rc = rc.unwrap();
let group_id: i32 = rc.as_str().unwrap().parse().unwrap();
//debug!("Got msg {msg_type}, infra {group_id}");
match msg_type {
"infrastructure_group_create" => {
let json_msg: InfraGroupCreate = serde_json::from_str(&pload).unwrap();
//debug!("{:?}", json_msg);
return Some(CGWNBApiParsedMsg::InfrastructureGroupCreate(json_msg.uuid, group_id));
},
"infrastructure_group_delete" => {
let json_msg: InfraGroupDelete = serde_json::from_str(&pload).unwrap();
//debug!("{:?}", json_msg);
return Some(CGWNBApiParsedMsg::InfrastructureGroupDelete(json_msg.uuid, group_id));
},
"infrastructure_group_device_add" => {
let json_msg: InfraGroupInfraAdd = serde_json::from_str(&pload).unwrap();
//debug!("{:?}", json_msg);
return Some(CGWNBApiParsedMsg::InfrastructureGroupInfraAdd(json_msg.uuid, group_id, json_msg.infra_group_infra_devices));
}
"infrastructure_group_device_del" => {
let json_msg: InfraGroupInfraDel = serde_json::from_str(&pload).unwrap();
//debug!("{:?}", json_msg);
return Some(CGWNBApiParsedMsg::InfrastructureGroupInfraDel(json_msg.uuid, group_id, json_msg.infra_group_infra_devices));
}
"infrastructure_group_device_message" => {
let json_msg: InfraGroupMsgJSON = serde_json::from_str(&pload).unwrap();
//debug!("{:?}", json_msg);
return Some(CGWNBApiParsedMsg::InfrastructureGroupInfraMsg(json_msg.uuid, group_id, json_msg.mac, serde_json::to_string(&json_msg.msg).unwrap()));
},
&_ => {
debug!("Unknown type {msg_type} received");
}
}
None
}
async fn process_internal_nb_api_mbox(self: Arc<Self>, mut rx_mbox: CGWConnectionServerNBAPIMboxRx) {
debug!("process_nb_api_mbox entry");
let buf_capacity = 2000;
let mut buf: Vec<CGWConnectionNBAPIReqMsg> = Vec::with_capacity(buf_capacity);
let mut num_of_msg_read = 0;
// As of now, expect at max 100 CGWS remote instances without buffers realloc
// This only means that original capacity of all buffers is allocated to <100>,
// it can still increase on demand or need automatically (upon insert, push_back etc)
let cgw_buf_prealloc_size = 100;
let mut local_parsed_cgw_msg_buf: Vec<CGWNBApiParsedMsg> = Vec::with_capacity(buf_capacity);
loop {
if num_of_msg_read < buf_capacity {
// Try to recv_many, but don't sleep too much
// in case if no messaged pending and we have
// TODO: rework to pull model (pull on demand),
// compared to curr impl: push model (nb api listener forcefully
// pushed all fetched data from kafka).
// Currently recv_many may sleep if previous read >= 1,
// but no new messages pending
//
// It's also possible that this logic staggers the processing,
// in case when every new message is received <=9 ms for example:
// Single message received, waiting for new up to 10 ms.
// New received on 9th ms. Repeat.
// And this could repeat up untill buffer is full, or no new messages
// appear on the 10ms scale.
// Highly unlikly scenario, but still possible.
let rd_num = tokio::select! {
v = rx_mbox.recv_many(&mut buf, buf_capacity - num_of_msg_read) => {
v
}
v = sleep(Duration::from_millis(10)) => {
0
}
};
num_of_msg_read += rd_num;
// We read some messages, try to continue and read more
// If none read - break from recv, process all buffers that've
// been filled-up so far (both local and remote).
// Upon done - repeat.
if rd_num >= 1 {
continue;
} else {
if num_of_msg_read == 0 {
continue;
}
}
}
debug!("Received {num_of_msg_read} messages from NB API, processing...");
// We rely on this map only for a single iteration of received messages:
// say, we receive 10 messages but 20 in queue, this means that gid->cgw_id
// cache is clear at first, the filled up when processing first 10 messages,
// the clear/reassigned again for next 10 msgs (10->20).
// This is done to ensure that we don't fallback for redis too much,
// but still somewhat fully rely on it.
//
self.cgw_remote_discovery.sync_gid_to_cgw_map().await;
local_parsed_cgw_msg_buf.clear();
// TODO: rework to avoid re-allocating these buffers on each loop iteration
// (get mut slice of vec / clear when done?)
let mut relayed_cgw_msg_buf: Vec<(i32, CGWConnectionNBAPIReqMsg)> = Vec::with_capacity(num_of_msg_read + 1);
let mut local_cgw_msg_buf: Vec<CGWConnectionNBAPIReqMsg> = Vec::with_capacity(num_of_msg_read + 1);
while ! buf.is_empty() {
let msg = buf.remove(0);
if let CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, origin) = msg {
let gid_numeric = match key.parse::<i32>() {
Err(e) => {
warn!("Invalid KEY received from KAFKA bus message, ignoring\n{e}");
continue;
},
Ok(v) => v
};
let parsed_msg = match self.parse_nbapi_msg(&payload) {
Some(val) => val,
None => {
warn!("Failed to parse recv msg with key {key}, discarded");
continue;
}
};
// The one shard that received add/del is responsible for
// handling it at place.
// Any other msg is either relayed / handled locally later.
// The reason for this is following: current shard is responsible
// for assignment of GID to shard, thus it has to make
// assignment as soon as possible to deduce relaying action in
// the following message pool that is being handled.
// Same for delete.
if let CGWNBApiParsedMsg::InfrastructureGroupCreate(uuid, gid) = parsed_msg {
// DB stuff - create group for remote shards to be aware of change
let group = CGWDBInfrastructureGroup {
id: gid,
reserved_size: 1000i32,
actual_size: 0i32,
};
match self.cgw_remote_discovery.create_infra_group(&group).await {
Ok(dst_cgw_id) => {
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Group has been created successfully gid {gid}"));
},
Err(e) => {
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Failed to create new group (duplicate create?), gid {gid}"));
warn!("Create group gid {gid} received, but it already exists");
}
}
// This type of msg is handled in place, not added to buf
// for later processing.
continue;
} else if let CGWNBApiParsedMsg::InfrastructureGroupDelete(uuid, gid) = parsed_msg {
match self.cgw_remote_discovery.destroy_infra_group(gid).await {
Ok(()) => {
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Group has been destroyed successfully gid {gid}, uuid {uuid}"));
},
Err(e) => {
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Failed to destroy group (doesn't exist?), gid {gid}, uuid {uuid}"));
warn!("Destroy group gid {gid} received, but it does not exist");
}
}
// This type of msg is handled in place, not added to buf
// for later processing.
continue;
}
// We received NB API msg, check origin:
// If it's a relayed message, we must not relay it further
// If msg originated from Kafka originally, it's safe to relay it (if needed)
if let CGWConnectionNBAPIReqMsgOrigin::FromRemoteCGW = origin {
local_parsed_cgw_msg_buf.push(parsed_msg);
continue;
}
match self.cgw_remote_discovery.get_infra_group_owner_id(key.parse::<i32>().unwrap()).await {
Some(dst_cgw_id) => {
if dst_cgw_id == self.local_cgw_id {
local_cgw_msg_buf.push(CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, origin));
} else {
relayed_cgw_msg_buf.push((dst_cgw_id, CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, origin)));
}
},
None => {
warn!("Received msg for gid {gid_numeric}, while this group is unassigned to any of CGWs: rejecting");
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid_numeric,
format!("Received message for unknown group {gid_numeric} - unassigned?"));
}
}
}
}
let discovery_clone = self.cgw_remote_discovery.clone();
let self_clone = self.clone();
// Future to Handle (relay) messages for remote CGW
let relay_task_hdl = self.mbox_relay_msg_runtime_handle.spawn(async move {
let mut remote_cgws_map: HashMap<String, (i32, Vec<(String, String)>)> = HashMap::with_capacity(cgw_buf_prealloc_size);
while ! relayed_cgw_msg_buf.is_empty() {
let msg = relayed_cgw_msg_buf.remove(0);
if let (dst_cgw_id, CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, _origin)) = msg {
debug!("Received MSG for remote CGW k:{}, local id {} relaying msg to remote...", key, self_clone.local_cgw_id);
if let Some(v) = remote_cgws_map.get_mut(&key) {
v.1.push((key, payload));
} else {
let mut tmp_vec: Vec<(String, String)> = Vec::with_capacity(num_of_msg_read);
tmp_vec.push((key.clone(), payload));
remote_cgws_map.insert(key, (dst_cgw_id, tmp_vec));
}
}
}
for value in remote_cgws_map.into_values() {
let discovery_clone = discovery_clone.clone();
let cgw_id = value.0;
let msg_stream = value.1;
let self_clone = self_clone.clone();
tokio::spawn(async move {
if let Err(()) = discovery_clone.relay_request_stream_to_remote_cgw(cgw_id, msg_stream).await {
self_clone.enqueue_mbox_message_from_cgw_to_nb_api(
-1,
format!("Failed to relay MSG stream to remote CGW{cgw_id}, UUIDs: not implemented (TODO)"));
}
});
}
});
// Handle messages for local CGW
// Parse all messages first, then process
// TODO: try to parallelize at least parsing of msg:
// iterate each msg, get index, spawn task that would
// write indexed parsed msg into output parsed msg buf.
let connmap_clone = self.connmap.map.clone();
while ! local_cgw_msg_buf.is_empty() {
let msg = local_cgw_msg_buf.remove(0);
if let CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, _origin) = msg {
let gid_numeric: i32 = key.parse::<i32>().unwrap();
debug!("Received message for local CGW k:{key}, local id {}", self.local_cgw_id);
let msg = self.parse_nbapi_msg(&payload);
if let None = msg {
error!("Failed to parse msg from NBAPI (malformed?)");
continue;
}
match msg.unwrap() {
CGWNBApiParsedMsg::InfrastructureGroupInfraAdd(uuid, gid, mac_list) => {
if let None = self.cgw_remote_discovery.get_infra_group_owner_id(gid_numeric).await {
warn!("Unexpected: tried to add infra list to nonexisting group (gid {gid}, uuid {uuid}");
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Failed to insert MACs from infra list, gid {gid}, uuid {uuid}: group does not exist."));
}
match self.cgw_remote_discovery.create_ifras_list(gid, mac_list).await {
Ok(()) => {
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Infra list has been created successfully gid {gid}, uuid {uuid}"));
},
Err(macs) => {
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Failed to insert few MACs from infra list, gid {gid}, uuid {uuid}; List of failed MACs:{}",
macs.iter().map(|x| x.to_string() + ",").collect::<String>()));
warn!("Failed to create few MACs from infras list (partial create)");
continue;
}
}
},
CGWNBApiParsedMsg::InfrastructureGroupInfraDel(uuid, gid, mac_list) => {
if let None = self.cgw_remote_discovery.get_infra_group_owner_id(gid_numeric).await {
warn!("Unexpected: tried to delete infra list from nonexisting group (gid {gid}, uuid {uuid}");
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Failed to delete MACs from infra list, gid {gid}, uuid {uuid}: group does not exist."));
}
match self.cgw_remote_discovery.destroy_ifras_list(gid, mac_list).await {
Ok(()) => {
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Infra list has been destroyed successfully gid {gid}, uuid {uuid}"));
},
Err(macs) => {
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Failed to destroy few MACs from infra list (not created?), gid {gid}, uuid {uuid}; List of failed MACs:{}",
macs.iter().map(|x| x.to_string() + ",").collect::<String>()));
warn!("Failed to destroy few MACs from infras list (partial delete)");
continue;
}
}
},
CGWNBApiParsedMsg::InfrastructureGroupInfraMsg(uuid, gid, mac, msg) => {
if let None = self.cgw_remote_discovery.get_infra_group_owner_id(gid_numeric).await {
warn!("Unexpected: tried to sink down msg to device of nonexisting group (gid {gid}, uuid {uuid}");
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Failed to sink down msg to device of nonexisting group, gid {gid}, uuid {uuid}: group does not exist."));
}
debug!("Sending msg to device {mac}");
let rd_lock = connmap_clone.read().await;
let rc = rd_lock.get(&mac);
if let None = rc {
error!("Cannot find suitable connection for {mac}");
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Failed to send msg (device not connected?), gid {gid}, uuid {uuid}"));
continue;
}
let proc_mbox_tx = rc.unwrap();
if let Err(e) = proc_mbox_tx.send(CGWConnectionProcessorReqMsg::SinkRequestToDevice(mac, msg)) {
error!("Failed to send message to remote device (msg uuid({uuid}))");
self.enqueue_mbox_message_from_cgw_to_nb_api(
gid,
format!("Failed to send msg, gid {gid}, uuid {uuid}"));
}
},
_ => {
debug!("Received unimplemented/unexpected group create/del msg, ignoring");
}
}
}
}
// Do not proceed parsing local / remote msgs untill previous relaying has been
// finished
tokio::join!(relay_task_hdl);
buf.clear();
num_of_msg_read = 0;
}
panic!("RX or TX counterpart of nb_api channel part destroyed, while processing task is still active");
}
async fn process_internal_mbox(self: Arc<Self>, mut rx_mbox: CGWConnectionServerMboxRx) {
debug!("process_internal_mbox entry");
let buf_capacity = 1000;
let mut buf: Vec<CGWConnectionServerReqMsg> = Vec::with_capacity(buf_capacity);
let mut num_of_msg_read = 0;
loop {
if num_of_msg_read < buf_capacity {
// Try to recv_many, but don't sleep too much
// in case if no messaged pending and we have
// TODO: rework?
// Currently recv_many may sleep if previous read >= 1,
// but no new messages pending
let rd_num = tokio::select! {
v = rx_mbox.recv_many(&mut buf, buf_capacity - num_of_msg_read) => {
v
}
v = sleep(Duration::from_millis(10)) => {
0
}
};
num_of_msg_read += rd_num;
// We read some messages, try to continue and read more
// If none read - break from recv, process all buffers that've
// been filled-up so far (both local and remote).
// Upon done - repeat.
if rd_num >= 1 {
if num_of_msg_read < 100 {
continue;
}
} else {
if num_of_msg_read == 0 {
continue;
}
}
}
let mut connmap_w_lock = self.connmap.map.write().await;
while ! buf.is_empty() {
let msg = buf.remove(0);
if let CGWConnectionServerReqMsg::AddNewConnection(serial, conn_processor_mbox_tx) = msg {
// if connection is unique: simply insert new conn
//
// if duplicate exists: notify server about such incident.
// it's up to server to notify underlying task that it should
// drop the connection.
// from now on simply insert new connection into hashmap and proceed on
// processing it.
let serial_clone: DeviceSerial = serial.clone();
if let Some(c) = connmap_w_lock.remove(&serial_clone) {
tokio::spawn(async move {
warn!("Duplicate connection (mac:{}) detected, closing OLD connection in favor of NEW", serial_clone);
let msg: CGWConnectionProcessorReqMsg = CGWConnectionProcessorReqMsg::AddNewConnectionShouldClose;
c.send(msg).unwrap();
});
}
// clone a sender handle, as we still have to send ACK back using underlying
// tx mbox handle
let conn_processor_mbox_tx_clone = conn_processor_mbox_tx.clone();
info!("connmap: connection with {} established, new num_of_connections:{}", serial, connmap_w_lock.len() + 1);
connmap_w_lock.insert(serial, conn_processor_mbox_tx);
tokio::spawn(async move {
let msg: CGWConnectionProcessorReqMsg = CGWConnectionProcessorReqMsg::AddNewConnectionAck;
conn_processor_mbox_tx_clone.send(msg).unwrap();
});
} else if let CGWConnectionServerReqMsg::ConnectionClosed(serial) = msg {
info!("connmap: removed {} serial from connmap, new num_of_connections:{}", serial, connmap_w_lock.len() - 1);
connmap_w_lock.remove(&serial);
}
}
buf.clear();
num_of_msg_read = 0;
}
panic!("RX or TX counterpart of mbox_internal channel part destroyed, while processing task is still active");
}
/// Accepts a freshly connected TCP socket on the WSS runtime.
///
/// Only the TLS handshake is performed here; on success a
/// `CGWConnectionProcessor` is created and started for the stream. Whether the
/// connection is finally kept or dropped is decided later through the
/// internal mailbox (see `process_internal_mbox`).
pub async fn ack_connection(self: Arc<Self>, socket: TcpStream, tls_acceptor: tokio_native_tls::TlsAcceptor, addr: SocketAddr, conn_idx: i64) {
    let server = self.clone();

    let handshake_and_run = async move {
        // Accept the TLS connection; a failed handshake only gets logged.
        match tls_acceptor.accept(socket).await {
            Ok(tls_stream) => {
                let conn_processor = CGWConnectionProcessor::new(server, conn_idx, addr);
                conn_processor.start(tls_stream).await;
            }
            Err(e) => {
                warn!("Err {e}");
            }
        }
    };

    self.wss_rx_tx_runtime.spawn(handshake_and_run);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample JSON-RPC `connect` message.
    fn get_connect_json_msg() -> &'static str {
        r#"
{
"jsonrpc": "2.0",
"method": "connect",
"params": {
"serial": "00000000ca4b",
"firmware": "SONiC-OS-4.1.0_vs_daily_221213_1931_422-campus",
"uuid": 1,
"capabilities": {
"compatible": "+++x86_64-kvm_x86_64-r0",
"model": "DellEMC-S5248f-P-25G-DPB",
"platform": "switch",
"label_macaddr": "00:00:00:00:ca:4b"
}
}
}"#
    }

    /// Sample JSON-RPC `log` message.
    fn get_log_json_msg() -> &'static str {
        r#"
{
"jsonrpc": "2.0",
"method": "log",
"params": {
"serial": "00000000ca4b",
"log": "uc-client: connection error: Unable to connect",
"severity": 3
}
}"#
    }

    /// Parses `raw` as JSON and runs it through `cgw_parse_jrpc_event`,
    /// using the message's own `method` field.
    fn parse_event(raw: &str) -> CGWEvent {
        let map: Map<String, Value> =
            serde_json::from_str(raw).expect("Failed to parse input json");
        let method = map["method"].as_str().unwrap();
        cgw_parse_jrpc_event(&map, method.to_string())
    }

    #[test]
    fn can_parse_connect_event() {
        let event = parse_event(get_connect_json_msg());
        assert!(
            matches!(event, CGWEvent::Connect(_)),
            "Expected event to be of <Connect> type"
        );
    }

    #[test]
    fn can_parse_log_event() {
        let event = parse_event(get_log_json_msg());
        assert!(
            matches!(event, CGWEvent::Log(_)),
            "Expected event to be of <Log> type"
        );
    }
}

205
src/cgw_db_accessor.rs Normal file
View File

@@ -0,0 +1,205 @@
use crate::{
AppArgs,
};
use eui48::{
MacAddress,
};
use tokio_postgres::{
Client,
NoTls,
row::{
Row,
},
};
/// A single infrastructure (device) record, mirroring one row of the
/// `infras` table.
#[derive(Clone, Debug)]
pub struct CGWDBInfra {
    /// MAC address of the device (`mac` column).
    pub mac: MacAddress,
    /// Id of the infrastructure group the device is assigned to
    /// (`infra_group_id` column).
    pub infra_group_id: i32,
}
/// An infrastructure group record, mirroring one row of the
/// `infrastructure_groups` table.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CGWDBInfrastructureGroup {
    /// Group id (`id` column).
    pub id: i32,
    /// Value of the `reserved_size` column.
    pub reserved_size: i32,
    /// Value of the `actual_size` column.
    pub actual_size: i32,
}
impl From<Row> for CGWDBInfra {
fn from(row: Row) -> Self {
let serial: MacAddress = row.get("mac");
let gid: i32 = row.get("infra_group_id");
Self {
mac: serial,
infra_group_id: gid,
}
}
}
impl From<Row> for CGWDBInfrastructureGroup {
fn from(row: Row) -> Self {
let infra_id: i32 = row.get("id");
let res_size: i32 = row.get("reserved_size");
let act_size: i32 = row.get("actual_size");
Self {
id: infra_id,
reserved_size: res_size,
actual_size: act_size,
}
}
}
/// Accessor for the infrastructure-group / infra tables, backed by a single
/// `tokio_postgres` client.
pub struct CGWDBAccessor {
// Client handle used by every query issued through this accessor.
cl: Client,
}
impl CGWDBAccessor {
pub async fn new(app_args: &AppArgs) -> Self {
let conn_str = format!("host={host} port={port} user={user} dbname={db} password={pass}",
host = app_args.db_ip,
port = app_args.db_port,
user = app_args.db_username,
db = app_args.db_name,
pass = app_args.db_password);
debug!("Trying to connect to remote db ({}:{})...",
app_args.db_ip.to_string(),
app_args.db_port.to_string());
debug!("Conn args {conn_str}");
let (client, connection) =
tokio_postgres::connect(&conn_str, NoTls).await.unwrap();
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
info!("Connected to remote DB");
CGWDBAccessor {
cl: client,
}
}
/*
* INFRA_GROUP db API uses the following table decl
* TODO: id = int, not varchar; requires kafka simulator changes
CREATE TABLE infrastructure_groups (
id VARCHAR (50) PRIMARY KEY,
reserved_size INT,
actual_size INT
);
*
*/
pub async fn insert_new_infra_group(&self, g: &CGWDBInfrastructureGroup) -> Result<(), &'static str> {
let q = self.cl.prepare("INSERT INTO infrastructure_groups (id, reserved_size, actual_size) VALUES ($1, $2, $3)").await.unwrap();
let res = self.cl.execute(&q, &[&g.id, &g.reserved_size, &g.actual_size]).await;
match res {
Ok(n) => return Ok(()),
Err(e) => {
error!("Failed to insert a new infra group {}: {:?}", g.id, e.to_string());
return Err("Insert new infra group failed");
}
}
}
pub async fn delete_infra_group(&self, gid: i32) -> Result<(), &'static str> {
// TODO: query-base approach instead of static string
let req = self.cl.prepare("DELETE FROM infrastructure_groups WHERE id = $1").await.unwrap();
let res = self.cl.execute(&req, &[&gid]).await;
match res {
Ok(n) => {
if n > 0 {
return Ok(());
} else {
return Err("Failed to delete group from DB: gid does not exist");
}
},
Err(e) => {
error!("Failed to delete an infra group {gid}: {:?}", e.to_string());
return Err("Delete infra group failed");
}
}
}
pub async fn get_all_infra_groups(&self) -> Option<Vec<CGWDBInfrastructureGroup>> {
let mut list: Vec<CGWDBInfrastructureGroup> = Vec::with_capacity(1000);
let res = self.cl.query("SELECT * from infrastructure_groups", &[]).await;
match res {
Ok(r) => {
for x in r {
let infra_group = CGWDBInfrastructureGroup::from(x);
list.push(infra_group);
}
return Some(list);
}
Err(e) => {
return None;
}
}
}
pub async fn get_infra_group(&self, gid: i32) -> Option<CGWDBInfrastructureGroup> {
let q = self.cl.prepare("SELECT * from infrastructure_groups WHERE id = $1").await.unwrap();
let row = self.cl.query_one(&q, &[&gid]).await;
match row {
Ok(r) => return Some(CGWDBInfrastructureGroup::from(r)),
Err(e) => {
return None;
}
}
}
/*
* INFRA db API uses the following table decl
CREATE TABLE infras (
mac MACADDR PRIMARY KEY,
infra_group_id INT,
FOREIGN KEY(infra_group_id) REFERENCES infrastructure_groups(id) ON DELETE CASCADE
);
*/
pub async fn insert_new_infra(&self, infra: &CGWDBInfra) -> Result<(), &'static str> {
let q = self.cl.prepare("INSERT INTO infras (mac, infra_group_id) VALUES ($1, $2)").await.unwrap();
let res = self.cl.execute(
&q,
&[&infra.mac, &infra.infra_group_id]).await;
match res {
Ok(n) => return Ok(()),
Err(e) => {
error!("Failed to insert a new infra: {:?}", e.to_string());
return Err("Insert new infra failed");
}
}
}
pub async fn delete_infra(&self, serial: MacAddress) -> Result<(), &'static str> {
let q = self.cl.prepare("DELETE FROM infras WHERE mac = $1").await.unwrap();
let res = self.cl.execute(
&q,
&[&serial]).await;
match res {
Ok(n) => {
if n > 0 {
return Ok(());
} else {
return Err("Failed to delete infra from DB: MAC does not exist");
}
},
Err(e) => {
error!("Failed to delete infra: {:?}", e.to_string());
return Err("Delete infra failed");
}
}
}
}

260
src/cgw_nb_api_listener.rs Normal file
View File

@@ -0,0 +1,260 @@
use crate::{
AppArgs,
};
use crate::cgw_connection_server::{
CGWConnectionNBAPIReqMsg,
CGWConnectionNBAPIReqMsgOrigin,
};
use std::{
sync::{
Arc,
},
};
use tokio::{
sync::{
mpsc::{
UnboundedSender,
},
},
time::{
Duration,
},
runtime::{
Builder,
Runtime,
},
};
use futures::stream::{TryStreamExt};
use rdkafka::client::ClientContext;
use rdkafka::config::{ClientConfig, RDKafkaLogLevel};
use rdkafka::{
consumer::{
Consumer,
ConsumerContext,
Rebalance,
stream_consumer::{
StreamConsumer,
},
},
producer::{
FutureProducer,
FutureRecord,
},
};
use rdkafka::error::KafkaResult;
use rdkafka::message::{Message};
use rdkafka::topic_partition_list::TopicPartitionList;
/// Sender half of the mailbox through which NB API requests are handed to the
/// connection server.
type CGWConnectionServerMboxTx = UnboundedSender<CGWConnectionNBAPIReqMsg>;
/// Kafka stream consumer parameterized with [`CustomContext`].
type CGWCNCConsumerType = StreamConsumer<CustomContext>;
/// Kafka producer type used by this module.
type CGWCNCProducerType = FutureProducer;
/// Stateless Kafka client context; its consumer callbacks only emit log lines.
struct CustomContext;
// No client-level callbacks are overridden: the trait defaults apply.
impl ClientContext for CustomContext {}
impl ConsumerContext for CustomContext {
fn pre_rebalance(&self, rebalance: &Rebalance) {
let mut part_list = String::new();
if let rdkafka::consumer::Rebalance::Assign(partitions) = rebalance {
for x in partitions.elements() {
part_list += &(x.partition().to_string() + " ");
}
debug!("pre_rebalance callback, assigned partition(s): {part_list}");
}
part_list.clear();
if let rdkafka::consumer::Rebalance::Revoke(partitions) = rebalance {
for x in partitions.elements() {
part_list += &(x.partition().to_string() + " ");
}
debug!("pre_rebalance callback, revoked partition(s): {part_list}");
}
}
fn post_rebalance(&self, rebalance: &Rebalance) {
let mut part_list = String::new();
if let rdkafka::consumer::Rebalance::Assign(partitions) = rebalance {
for x in partitions.elements() {
part_list += &(x.partition().to_string() + " ");
}
debug!("post_rebalance callback, assigned partition(s): {part_list}");
}
part_list.clear();
if let rdkafka::consumer::Rebalance::Revoke(partitions) = rebalance {
for x in partitions.elements() {
part_list += &(x.partition().to_string() + " ");
}
debug!("post_rebalance callback, revoked partition(s): {part_list}");
}
}
fn commit_callback(&self, result: KafkaResult<()>, _offsets: &TopicPartitionList) {
let mut part_list = String::new();
for x in _offsets.elements() {
part_list += &(x.partition().to_string() + " ");
}
debug!("commit_callback callback, partition(s): {part_list}");
debug!("Consumer callback: commited offset");
}
}
// Kafka consumer group id; also used as the prefix of the per-instance
// `client.id`.
static GROUP_ID: &'static str = "CGW";
// Topics the consumer subscribes to.
const CONSUMER_TOPICS: [&'static str; 1] = ["CnC"];
// Topic that messages coming from the CGW server are produced to.
const PRODUCER_TOPICS: &'static str = "CnC_Res";
// Wrapper owning the Kafka producer handle.
struct CGWCNCProducer {
p: CGWCNCProducerType,
}
// Wrapper owning the Kafka consumer handle.
struct CGWCNCConsumer {
c: CGWCNCConsumerType,
}
impl CGWCNCConsumer {
    /// Creates a consumer configured from `app_args` and subscribed to
    /// `CONSUMER_TOPICS`.
    ///
    /// # Panics
    ///
    /// Panics if the consumer cannot be created or the subscription fails.
    pub fn new(app_args: &AppArgs) -> Self {
        CGWCNCConsumer {
            c: Self::create_consumer(app_args),
        }
    }

    /// Builds the underlying stream consumer and subscribes it.
    fn create_consumer(app_args: &AppArgs) -> CGWCNCConsumerType {
        let context = CustomContext;

        debug!("Trying to connect to kafka broker ({}:{})...",
               app_args.kafka_ip.to_string(),
               app_args.kafka_port.to_string());

        let consumer: CGWCNCConsumerType = ClientConfig::new()
            .set("group.id", GROUP_ID)
            .set("client.id", GROUP_ID.to_string() + &app_args.cgw_id.to_string())
            .set("group.instance.id", app_args.cgw_id.to_string())
            .set("bootstrap.servers", app_args.kafka_ip.to_string() + ":" + &app_args.kafka_port.to_string())
            .set("enable.partition.eof", "false")
            .set("session.timeout.ms", "6000")
            .set("enable.auto.commit", "true")
            //.set("statistics.interval.ms", "30000")
            //.set("auto.offset.reset", "smallest")
            .set_log_level(RDKafkaLogLevel::Debug)
            .create_with_context(context)
            .expect("Consumer creation failed");

        // `expect` takes a plain string (no interpolation), so build the panic
        // message explicitly to include the topic list and the actual error.
        if let Err(e) = consumer.subscribe(&CONSUMER_TOPICS) {
            panic!("Failed to subscribe to {CONSUMER_TOPICS:?} topics: {e}");
        }

        info!("Connected to kafka broker");

        consumer
    }
}
impl CGWCNCProducer {
pub fn new() -> Self {
let prod: CGWCNCProducerType = Self::create_producer();
CGWCNCProducer {
p: prod,
}
}
fn create_producer() -> CGWCNCProducerType {
let producer: FutureProducer = ClientConfig::new()
.set("bootstrap.servers", "172.20.10.136:9092")
.set("message.timeout.ms", "5000")
.create()
.expect("Producer creation error");
producer
}
}
pub struct CGWNBApiClient {
working_runtime_handle: Runtime,
cgw_server_tx_mbox: CGWConnectionServerMboxTx,
prod: CGWCNCProducer,
// TBD: stplit different implementators through a defined trait,
// that implements async R W operations?
}
impl CGWNBApiClient {
pub fn new(app_args: &AppArgs, cgw_tx: &CGWConnectionServerMboxTx) -> Arc<Self> {
let working_runtime_h = Builder::new_multi_thread()
.worker_threads(1)
.thread_name("cgw-nb-api-l")
.thread_stack_size(1 * 1024 * 1024)
.enable_all()
.build()
.unwrap();
let cl = Arc::new(CGWNBApiClient {
working_runtime_handle: working_runtime_h,
cgw_server_tx_mbox: cgw_tx.clone(),
prod: CGWCNCProducer::new(),
});
let cl_clone = cl.clone();
let consumer: CGWCNCConsumer = CGWCNCConsumer::new(app_args);
cl.working_runtime_handle.spawn(async move {
loop {
let cl_clone = cl_clone.clone();
let stream_processor = consumer.c.stream().try_for_each(|borrowed_message| {
let cl_clone = cl_clone.clone();
async move {
// Process each message
// Borrowed messages can't outlive the consumer they are received from, so they need to
// be owned in order to be sent to a separate thread.
//record_owned_message_receipt(&owned_message).await;
let owned = borrowed_message.detach();
let key = match owned.key_view::<str>() {
None => "",
Some(Ok(s)) => s,
Some(Err(e)) => {
warn!("Error while deserializing message payload: {:?}", e);
""
}
};
let payload = match owned.payload_view::<str>() {
None => "",
Some(Ok(s)) => s,
Some(Err(e)) => {
warn!("Error while deserializing message payload: {:?}", e);
""
}
};
cl_clone.enqueue_mbox_message_to_cgw_server(key.to_string(), payload.to_string()).await;
Ok(())
}
});
stream_processor.await.expect("stream processing failed");
}
});
cl
}
pub async fn enqueue_mbox_message_from_cgw_server(&self, key: String, payload: String) {
let produce_future = self.prod.p.send(
FutureRecord::to(&PRODUCER_TOPICS)
.key(&key)
.payload(&payload),
Duration::from_secs(0),
);
match produce_future.await {
Err((e, _)) => println!("Error: {:?}", e),
_ => {}
}
}
async fn enqueue_mbox_message_to_cgw_server(&self, key: String, payload: String) {
debug!("MBOX_OUT: EnqueueNewMessageFromNBAPIListener, k:{key}");
let msg = CGWConnectionNBAPIReqMsg::EnqueueNewMessageFromNBAPIListener(key, payload, CGWConnectionNBAPIReqMsgOrigin::FromNBAPI);
let _ = self.cgw_server_tx_mbox.send(msg);
}
}

75
src/cgw_remote_client.rs Normal file
View File

@@ -0,0 +1,75 @@
use crate::{
AppArgs,
};
// gRPC types and client code generated at build time for the `cgw.remote`
// protobuf package.
pub mod cgw_remote {
tonic::include_proto!("cgw.remote");
}
use tonic::{
transport::{
Uri,
channel::{
Channel,
},
},
};
use cgw_remote::{
EnqueueRequest,
remote_client::{
RemoteClient,
},
};
use tokio::{
time::{
Duration,
},
};
/// gRPC client used to relay NB API requests to a remote CGW instance.
#[derive(Clone)]
pub struct CGWRemoteClient {
// Generated tonic client over a lazily connected channel.
remote_client: RemoteClient<Channel>,
}
impl CGWRemoteClient {
    /// Creates a client for the remote CGW reachable at `hostname` (a full
    /// URI such as `http://host:port`).
    ///
    /// The channel is created lazily: no connection attempt is made until the
    /// first request. Both the request and the connect timeout are 20 seconds.
    ///
    /// # Panics
    ///
    /// Panics if `hostname` is not a valid URI.
    pub fn new(hostname: String) -> Self {
        let uri = Uri::from_maybe_shared(hostname)
            .expect("remote CGW hostname must be a valid URI");
        let r_channel = Channel::builder(uri)
            .timeout(Duration::from_secs(20))
            .connect_timeout(Duration::from_secs(20))
            .connect_lazy();

        CGWRemoteClient {
            remote_client: RemoteClient::new(r_channel),
        }
    }

    /// Relays a batch of `(key, request)` pairs to the remote CGW as a single
    /// client-side stream.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if the remote call fails; the error is logged.
    pub async fn relay_request_stream(&self, stream: Vec<(String, String)>) -> Result<(), ()> {
        // The generated client needs `&mut self`, so work on a cheap clone.
        let mut cl_clone = self.remote_client.clone();

        // Convert by value: the batch is handed to the request without an
        // extra copy.
        let messages: Vec<EnqueueRequest> = stream
            .into_iter()
            .map(|(key, req)| EnqueueRequest { key, req })
            .collect();

        let rq = tonic::Request::new(tokio_stream::iter(messages));
        match cl_clone.enqueue_nbapi_request_stream(rq).await {
            Ok(_) => Ok(()),
            Err(e) => {
                error!("Failed to relay req: {:?}", e);
                Err(())
            }
        }
    }
}

559
src/cgw_remote_discovery.rs Normal file
View File

@@ -0,0 +1,559 @@
use crate::{
AppArgs,
cgw_remote_client::{
CGWRemoteClient,
},
cgw_db_accessor::{
CGWDBAccessor,
CGWDBInfrastructureGroup,
CGWDBInfra,
},
};
use std::{
collections::{
HashMap,
},
net::{
Ipv4Addr,
IpAddr,
SocketAddr,
},
sync::{
Arc,
},
};
use redis_async::{
resp_array,
};
use eui48::{
MacAddress,
};
use tokio::{
sync::{
RwLock,
}
};
// Used in remote lookup
// Prefix of the REDIS hash keys that describe CGW shards ("shard_id_<id>").
static REDIS_KEY_SHARD_ID_PREFIX:&'static str = "shard_id_";
// Minimum number of entries (field names + values) expected in a shard hash reply.
static REDIS_KEY_SHARD_ID_FIELDS_NUM: usize = 12;
// Shard hash field that holds the number of groups assigned to the shard.
static REDIS_KEY_SHARD_VALUE_ASSIGNED_G_NUM:&'static str = "assigned_groups_num";
// Used in group assign / reassign
// Prefix of the REDIS hash keys that describe group ownership ("group_id_<gid>").
static REDIS_KEY_GID:&'static str = "group_id_";
// Group hash field that holds the group id.
static REDIS_KEY_GID_VALUE_GID:&'static str = "gid";
// Group hash field that holds the id of the owning shard.
static REDIS_KEY_GID_VALUE_SHARD_ID:&'static str = "shard_id";
#[derive(Clone, Debug)]
pub struct CGWREDISDBShard {
id: i32,
server_ip: IpAddr,
server_port: u16,
assigned_groups_num: i32,
capacity: i32,
threshold: i32,
}
impl From<Vec<String>> for CGWREDISDBShard {
fn from(values: Vec<String>) -> Self {
assert!(values.len() >= REDIS_KEY_SHARD_ID_FIELDS_NUM,
"Unexpected size of parsed vector: at least {REDIS_KEY_SHARD_ID_FIELDS_NUM} expected");
assert!(values[0] == "id", "redis.res[0] != id, unexpected.");
assert!(values[2] == "server_ip", "redis.res[2] != server_ip, unexpected.");
assert!(values[4] == "server_port", "redis.res[4] != server_port, unexpected.");
assert!(values[6] == "assigned_groups_num", "redis.res[6] != assigned_groups_num, unexpected.");
assert!(values[8] == "capacity", "redis.res[8] != capacity, unexpected.");
assert!(values[10] == "threshold", "redis.res[10] != threshold, unexpected.");
CGWREDISDBShard {
id: values[1].parse::<i32>().unwrap(),
server_ip: values[3].parse::<IpAddr>().unwrap(),
server_port: values[5].parse::<u16>().unwrap(),
assigned_groups_num: values[7].parse::<i32>().unwrap(),
capacity: values[9].parse::<i32>().unwrap(),
threshold: values[11].parse::<i32>().unwrap(),
}
}
}
impl Into<Vec<String>> for CGWREDISDBShard {
fn into(self) -> Vec<String> {
vec!["id".to_string(), self.id.to_string(),
"server_ip".to_string(), self.server_ip.to_string(),
"server_port".to_string(), self.server_port.to_string(),
"assigned_groups_num".to_string(), self.assigned_groups_num.to_string(),
"capacity".to_string(), self.capacity.to_string(),
"threshold".to_string(), self.threshold.to_string()]
}
}
/// Addressing information of a remote CGW instance.
#[derive(Clone)]
pub struct CGWRemoteConfig {
    /// Id of the remote CGW.
    pub remote_id: i32,
    /// IPv4 address of the remote server.
    pub server_ip: Ipv4Addr,
    /// Port of the remote server.
    pub server_port: u16,
}

impl CGWRemoteConfig {
    /// Creates a config for remote `id` reachable at `ip_conf:port`.
    pub fn new(id: i32, ip_conf: Ipv4Addr, port: u16) -> Self {
        Self {
            remote_id: id,
            server_ip: ip_conf,
            server_port: port,
        }
    }

    /// Returns the remote's address as a (v4) socket address.
    pub fn to_socket_addr(&self) -> SocketAddr {
        SocketAddr::from((self.server_ip, self.server_port))
    }
}
/// A known CGW shard together with the gRPC client used to reach it.
#[derive(Clone)]
pub struct CGWRemoteIface {
// Shard description as read from REDIS.
pub shard: CGWREDISDBShard,
// Client created for the shard's `http://<server_ip>:<server_port>` endpoint.
client: CGWRemoteClient,
}
/// Keeps track of CGW shards and of which shard owns which infrastructure
/// group, using REDIS and the SQL database as backing stores.
#[derive(Clone)]
pub struct CGWRemoteDiscovery {
// Accessor for the SQL database holding groups and infras.
db_accessor: Arc<CGWDBAccessor>,
// Connection to the REDIS instance holding shard and group-owner records.
redis_client: redis_async::client::paired::PairedConnection,
// Local cache: group id -> id of the owning CGW shard.
gid_to_cgw_cache: Arc<RwLock::<HashMap<i32, i32>>>,
// Local cache: shard id -> shard description and client.
remote_cgws_map: Arc<RwLock::<HashMap<i32, CGWRemoteIface>>>,
// Id of the shard this process runs as.
local_shard_id: i32,
}
impl CGWRemoteDiscovery {
/// Connects to the SQL database and REDIS, loads the group-owner and shard
/// caches, and registers the local shard in REDIS if it is not known yet.
///
/// # Panics
///
/// Panics if the REDIS connection cannot be established, if `grpc_port` does
/// not fit into `u16`, or if the local shard record cannot be written.
pub async fn new(app_args: &AppArgs) -> Self {
    let rc = CGWRemoteDiscovery {
        db_accessor: Arc::new(CGWDBAccessor::new(app_args).await),
        redis_client: redis_async::client::paired::paired_connect(
            app_args.redis_db_ip.to_string(),
            app_args.redis_db_port).await.unwrap(),
        gid_to_cgw_cache: Arc::new(RwLock::new(HashMap::new())),
        local_shard_id: app_args.cgw_id,
        remote_cgws_map: Arc::new(RwLock::new(HashMap::new())),
    };

    rc.sync_gid_to_cgw_map().await;
    let _ = rc.sync_remote_cgw_map().await;

    // Evaluate the lookup in its own statement so the read guard is released
    // right away: keeping it alive across the block below (as an `if let`
    // scrutinee would) deadlocks against the write lock taken by
    // `sync_remote_cgw_map`.
    let local_shard_registered = rc.remote_cgws_map
        .read()
        .await
        .contains_key(&rc.local_shard_id);

    if !local_shard_registered {
        let redisdb_shard_info = CGWREDISDBShard {
            id: app_args.cgw_id,
            server_ip: std::net::IpAddr::V4(app_args.grpc_ip.clone()),
            server_port: u16::try_from(app_args.grpc_port).unwrap(),
            assigned_groups_num: 0i32,
            capacity: 1000i32,
            threshold: 50i32,
        };
        let redis_req_data: Vec<String> = redisdb_shard_info.into();

        // Drop any stale/partial record before writing the fresh one.
        let _ = rc.redis_client.send::<i32>(
            resp_array!["DEL",
            format!("{REDIS_KEY_SHARD_ID_PREFIX}{}", app_args.cgw_id)]).await;

        if let Err(e) = rc.redis_client.send::<String>(
            resp_array!["HSET", format!("{REDIS_KEY_SHARD_ID_PREFIX}{}", app_args.cgw_id)]
            .append(redis_req_data)).await {
            panic!("Failed to create record about shard in REDIS, e:{e}");
        }

        // Reload so the local shard shows up in the map.
        let _ = rc.sync_remote_cgw_map().await;
    }

    {
        let remote_cgws = rc.remote_cgws_map.read().await;
        // The map normally contains the local shard as well; saturate so an
        // empty map (failed sync) does not underflow.
        debug!("Found {} remote CGWs:", remote_cgws.len().saturating_sub(1));
        for val in remote_cgws.values() {
            if val.shard.id == rc.local_shard_id {
                continue;
            }
            debug!("Shard #{}, IP {}:{}", val.shard.id, val.shard.server_ip, val.shard.server_port);
        }
    }

    rc
}
/// Rebuilds the local group-id -> shard-id cache from the `group_id_*` hashes
/// stored in REDIS.
///
/// The cache is cleared first and the write lock is held for the whole
/// refresh. Entries whose `gid` or `shard_id` cannot be fetched or parsed are
/// skipped with a warning.
///
/// # Panics
///
/// Panics if the list of group keys cannot be fetched from REDIS.
pub async fn sync_gid_to_cgw_map(&self) {
    let mut cache = self.gid_to_cgw_cache.write().await;

    // Start from scratch
    cache.clear();

    let group_keys: Vec<String> = self.redis_client
        .send::<Vec<String>>(resp_array!["KEYS", format!("{}*", REDIS_KEY_GID)])
        .await
        .unwrap_or_else(|e| panic!("Failed to get KEYS list from REDIS, e:{e}"));

    for key in group_keys {
        let gid_reply = self.redis_client
            .send::<String>(resp_array!["HGET", &key, REDIS_KEY_GID_VALUE_GID])
            .await;
        let gid: i32 = match gid_reply.map(|raw| raw.parse::<i32>()) {
            Ok(Ok(parsed)) => parsed,
            Ok(Err(e)) => {
                warn!("Found proper key '{key}' entry, but failed to parse GID from it:\n{e}");
                continue;
            }
            Err(e) => {
                warn!("Found proper key '{key}' entry, but failed to fetch GID from it:\n{e}");
                continue;
            }
        };

        let shard_reply = self.redis_client
            .send::<String>(resp_array!["HGET", &key, REDIS_KEY_GID_VALUE_SHARD_ID])
            .await;
        let shard_id: i32 = match shard_reply.map(|raw| raw.parse::<i32>()) {
            Ok(Ok(parsed)) => parsed,
            Ok(Err(e)) => {
                warn!("Found proper key '{key}' entry, but failed to parse SHARD_ID from it:\n{e}");
                continue;
            }
            Err(e) => {
                warn!("Found proper key '{key}' entry, but failed to fetch SHARD_ID from it:\n{e}");
                continue;
            }
        };

        debug!("Found group {key}, gid: {gid}, shard_id: {shard_id}");

        // The cache was just cleared, so a replaced value means two keys
        // resolved to the same gid.
        if cache.insert(gid, shard_id).is_some() {
            warn!("Populated gid_to_cgw_map with previous value being alerady set, unexpected");
        }
    }

    debug!("Found total {} groups with their respective owners", cache.len());
}
/// Rebuilds the local shard map from the `shard_id_*` hashes stored in REDIS,
/// creating a remote client for every shard found.
///
/// The map is cleared first and the write lock is held for the whole refresh.
/// Shards whose record cannot be fetched or converted are skipped with a
/// warning.
///
/// # Errors
///
/// Returns an error if the list of shard keys cannot be fetched from REDIS.
async fn sync_remote_cgw_map(&self) -> Result<(), &'static str> {
    let mut remotes = self.remote_cgws_map.write().await;

    // Start from scratch
    remotes.clear();

    let shard_keys: Vec<String> = self.redis_client
        .send::<Vec<String>>(resp_array!["KEYS", format!("{}*", REDIS_KEY_SHARD_ID_PREFIX)])
        .await
        .map_err(|e| {
            warn!("Failed to get cgw shard KEYS list from REDIS, e:{e}");
            "Remote CGW Shards list fetch from REDIS failed"
        })?;

    for key in shard_keys {
        let fields = match self.redis_client
            .send::<Vec<String>>(resp_array!["HGETALL", &key])
            .await
        {
            Ok(fields) => fields,
            Err(e) => {
                warn!("Found proper key '{key}' entry, but failed to fetch Shard info from it:\n{e}");
                continue;
            }
        };

        let shrd: CGWREDISDBShard = match CGWREDISDBShard::try_from(fields) {
            Ok(v) => v,
            Err(_) => {
                warn!("Failed to parse CGWREDISDBShard, {key}");
                continue;
            }
        };

        let endpoint_str = format!("http://{}:{}", shrd.server_ip, shrd.server_port);
        let shard_id = shrd.id;
        remotes.insert(shard_id, CGWRemoteIface {
            shard: shrd,
            client: CGWRemoteClient::new(endpoint_str),
        });
    }

    Ok(())
}
/// Returns the id of the CGW shard that owns group `gid`, if any.
///
/// The local cache is consulted first; on a miss the cache is rebuilt from
/// REDIS and checked once more.
pub async fn get_infra_group_owner_id(&self, gid: i32) -> Option<i32> {
    // Fast path: internal cache
    let cached = self.gid_to_cgw_cache.read().await.get(&gid).copied();
    if cached.is_some() {
        return cached;
    }

    // Slow path: refresh from redis, then look again
    self.sync_gid_to_cgw_map().await;
    self.gid_to_cgw_cache.read().await.get(&gid).copied()
}
/// Increments (by one) the `assigned_groups_num` counter of shard `id` in
/// REDIS.
///
/// # Errors
///
/// Returns an error if the REDIS request fails; the failure is logged.
async fn increment_cgw_assigned_groups_num(&self, id: i32) -> Result<(), &'static str> {
    debug!("inc {id}");

    let reply = self.redis_client.send::<i32>(
        resp_array!["HINCRBY",
        format!("{}{id}", REDIS_KEY_SHARD_ID_PREFIX),
        REDIS_KEY_SHARD_VALUE_ASSIGNED_G_NUM,
        "1"]).await;

    reply
        .map(|v| debug!("ret {:?}", v))
        .map_err(|e| {
            warn!("Failed to increment CGW{id} assigned group num count, e:{e}");
            "Failed to increment assigned group num count"
        })
}
/// Decrements (by one) the `assigned_groups_num` counter of shard `id` in
/// REDIS.
///
/// # Errors
///
/// Returns an error if the REDIS request fails; the failure is logged.
async fn decrement_cgw_assigned_groups_num(&self, id: i32) -> Result<(), &'static str> {
    let reply = self.redis_client.send::<i32>(
        resp_array!["HINCRBY",
        format!("{}{id}", REDIS_KEY_SHARD_ID_PREFIX),
        REDIS_KEY_SHARD_VALUE_ASSIGNED_G_NUM,
        "-1"]).await;

    match reply {
        Ok(_) => Ok(()),
        Err(e) => {
            warn!("Failed to decrement CGW{id} assigned group num count, e:{e}");
            Err("Failed to decrement assigned group num count")
        }
    }
}
/// Picks the CGW shard a new infrastructure group should be assigned to.
///
/// Known shards are examined in descending order of `assigned_groups_num`;
/// the first one that can take one more group without exceeding
/// `capacity + threshold` is chosen. If none qualifies, the shard with the
/// fewest assigned groups is used.
///
/// # Errors
///
/// Returns an error if no shard is known at all.
async fn get_infra_group_cgw_assignee(&self) -> Result<i32, &'static str> {
    let shards = self.remote_cgws_map.read().await;

    // Stable sort, most loaded shards first.
    let mut candidates: Vec<&CGWRemoteIface> = shards.values().collect();
    candidates.sort_by_key(|iface| std::cmp::Reverse(iface.shard.assigned_groups_num));

    for iface in candidates {
        let shard = &iface.shard;
        debug!("id_{} capacity {} t {} assigned {}",
               shard.id, shard.capacity,
               shard.threshold,
               shard.assigned_groups_num);
        let max_capacity: i32 = shard.capacity + shard.threshold;
        if shard.assigned_groups_num + 1 <= max_capacity {
            debug!("Found CGW shard to assign group to (id {})", shard.id);
            return Ok(shard.id);
        }
    }

    warn!("Every available CGW is exceeding capacity+threshold limit, using least loaded one...");

    match shards.values().min_by_key(|iface| iface.shard.assigned_groups_num) {
        Some(least_loaded_cgw) => {
            warn!("Found least loaded CGW id: {}", least_loaded_cgw.shard.id);
            Ok(least_loaded_cgw.shard.id)
        }
        None => Err("Unexpected: Failed to find the least loaded CGW shard"),
    }
}
/// Assigns group `gid` to a CGW shard and records the ownership in REDIS and
/// in the local cache.
///
/// Any existing ownership record for `gid` is removed first. Returns the id
/// of the shard the group was assigned to.
///
/// # Errors
///
/// Returns an error if no shard can be selected or if the REDIS update fails.
async fn assign_infra_group_to_cgw(&self, gid: i32) -> Result<i32, &'static str> {
    // Delete key (if exists), recreate with new owner
    let _ = self.deassign_infra_group_to_cgw(gid).await;

    let dst_cgw_id: i32 = self.get_infra_group_cgw_assignee().await.map_err(|e| {
        warn!("Failed to assign {gid} to any shard, reason:{e}");
        e
    })?;

    self.redis_client.send::<String>(
        resp_array!["HSET",
        format!("{REDIS_KEY_GID}{gid}"),
        REDIS_KEY_GID_VALUE_GID,
        gid.to_string(),
        REDIS_KEY_GID_VALUE_SHARD_ID,
        dst_cgw_id.to_string()]).await
        .map_err(|e| {
            error!("Failed to update REDIS gid{gid} owner to shard{dst_cgw_id}, e:{e}");
            "Hot-cache (REDIS DB) update owner failed"
        })?;

    let mut cache = self.gid_to_cgw_cache.write().await;
    cache.insert(gid, dst_cgw_id);

    debug!("REDIS: assigned gid{gid} to shard{dst_cgw_id}");

    Ok(dst_cgw_id)
}
pub async fn deassign_infra_group_to_cgw(&self, gid: i32) -> Result<(), &'static str> {
if let Err(e) = self.redis_client.send::<i64>(
resp_array!["DEL", format!("{REDIS_KEY_GID}{gid}")]).await {
error!("Failed to deassigned REDIS gid{gid} owner, e:{e}");
return Err("Hot-cache (REDIS DB) deassign owner failed");
}
debug!("REDIS: deassigned gid{gid} from controlled CGW");
let mut lock = self.gid_to_cgw_cache.write().await;
lock.remove(&gid);
Ok(())
}
/// Creates infra group `g` and assigns it to a CGW shard.
///
/// Steps: insert the group through the DB accessor, assign it to a shard
/// (the DB record is deleted again if no assignment can be made), then bump
/// the shard's assigned-groups counter. Returns the id of the owning shard.
pub async fn create_infra_group(&self, g: &CGWDBInfrastructureGroup) -> Result<i32, &'static str> {
    //TODO: transaction-based insert/assigned_group_num update (DB)
    self.db_accessor.insert_new_infra_group(g).await?;

    let shard_id: i32 = match self.assign_infra_group_to_cgw(g.id).await {
        Ok(id) => id,
        Err(_) => {
            // Roll the DB insert back; the outcome of the delete is ignored
            let _ = self.db_accessor.delete_infra_group(g.id).await;
            return Err("Assign group to CGW shard failed");
        }
    };

    self.increment_cgw_assigned_groups_num(shard_id).await?;

    Ok(shard_id)
}
/// Destroys infra group `gid`.
///
/// When the group currently has an owning shard, the owner record is dropped
/// and that shard's assigned-groups counter is decremented; failures of
/// these two steps are ignored. The group is then deleted through the DB
/// accessor, whose error (if any) is returned.
pub async fn destroy_infra_group(&self, gid: i32) -> Result<(), &'static str> {
    if let Some(owner_id) = self.get_infra_group_owner_id(gid).await {
        let _ = self.deassign_infra_group_to_cgw(gid).await;
        let _ = self.decrement_cgw_assigned_groups_num(owner_id).await;
    }

    //TODO: transaction-based insert/assigned_group_num update (DB)
    self.db_accessor.delete_infra_group(gid).await?;

    Ok(())
}
pub async fn create_ifras_list(&self, gid: i32, infras: Vec<String>) -> Result<(), Vec<String>> {
// TODO: assign list to shards; currently - only created bulk, no assignment
let mut futures = Vec::with_capacity(infras.len());
// Results store vec of MACs we failed to add
let mut failed_infras: Vec<String> = Vec::with_capacity(futures.len());
for x in infras.iter() {
let db_accessor_clone = self.db_accessor.clone();
let infra = CGWDBInfra {
mac: MacAddress::parse_str(&x).unwrap(),
infra_group_id: gid,
};
futures.push(
tokio::spawn(
async move {
if let Err(_) = db_accessor_clone.insert_new_infra(&infra).await {
Err(infra.mac.to_string(eui48::MacAddressFormat::HexString))
} else {
Ok(())
}
}));
}
for (i, future) in futures.iter_mut().enumerate() {
match future.await {
Ok(res) => {
if let Err(mac) = res {
failed_infras.push(mac);
}
},
Err(_) => {
failed_infras.push(infras[i].clone());
}
}
}
if failed_infras.len() > 0 {
return Err(failed_infras);
}
Ok(())
}
pub async fn destroy_ifras_list(&self, gid: i32, infras: Vec<String>) -> Result<(), Vec<String>> {
let mut futures = Vec::with_capacity(infras.len());
// Results store vec of MACs we failed to add
let mut failed_infras: Vec<String> = Vec::with_capacity(futures.len());
for x in infras.iter() {
let db_accessor_clone = self.db_accessor.clone();
let mac = MacAddress::parse_str(&x).unwrap();
futures.push(
tokio::spawn(
async move {
if let Err(_) = db_accessor_clone.delete_infra(mac).await {
Err(mac.to_string(eui48::MacAddressFormat::HexString))
} else {
Ok(())
}
}));
}
for (i, future) in futures.iter_mut().enumerate() {
match future.await {
Ok(res) => {
if let Err(mac) = res {
failed_infras.push(mac);
}
},
Err(_) => {
failed_infras.push(infras[i].clone());
}
}
}
if failed_infras.len() > 0 {
return Err(failed_infras);
}
Ok(())
}
/// Relays a batch of string pairs to the remote CGW instance `shard_id`.
///
/// The locally cached remote-CGW map is consulted first. On a miss the map
/// is re-synced and the lookup is retried once; if the shard is still
/// unknown, an error is logged and `Err(())` is returned.
pub async fn relay_request_stream_to_remote_cgw(
    &self,
    shard_id: i32,
    stream: Vec<(String, String)>)
    -> Result<(), ()> {
    // try to use internal cache first
    if let Some(remote) = self.remote_cgws_map.read().await.get(&shard_id) {
        return match remote.client.relay_request_stream(stream).await {
            Ok(_) => Ok(()),
            Err(()) => Err(()),
        };
    }

    // then try to use redis
    let _ = self.sync_remote_cgw_map().await;

    if let Some(remote) = self.remote_cgws_map.read().await.get(&shard_id) {
        return match remote.client.relay_request_stream(stream).await {
            Ok(_) => Ok(()),
            Err(()) => Err(()),
        };
    }

    error!("No suitable CGW instance #{shard_id} was discovered, cannot relay msg");
    Err(())
}
}

93
src/cgw_remote_server.rs Normal file
View File

@@ -0,0 +1,93 @@
use crate::{
AppArgs,
};
// Rust bindings generated by tonic for the `cgw.remote` protobuf package
// (build.rs compiles src/proto/cgw.proto).
pub mod cgw_remote {
tonic::include_proto!("cgw.remote");
}
use tonic::{transport::Server, Request, Response, Status};
use cgw_remote::{
EnqueueRequest,
EnqueueResponse,
remote_server::{
RemoteServer,
Remote,
},
};
use tokio::{
time::{
Duration,
},
};
use crate::cgw_remote_discovery::{
CGWRemoteConfig,
};
use crate::cgw_connection_server:: {
CGWConnectionServer,
};
use tokio_stream::StreamExt;
use std::{
sync::{
Arc,
},
};
/// gRPC service state: implements the generated `Remote` trait and hands
/// every relayed request over to the connection server.
struct CGWRemote {
// Connection server that receives the relayed (key, request) messages
cgw_srv: Arc<CGWConnectionServer>,
}
#[tonic::async_trait]
impl Remote for CGWRemote {
async fn enqueue_nbapi_request_stream(&self, request: Request<tonic::Streaming<EnqueueRequest>>, ) -> Result<Response<EnqueueResponse>, Status> {
let mut rq_stream = request.into_inner();
while let Some(rq) = rq_stream.next().await {
let rq = rq.unwrap();
self.cgw_srv.enqueue_mbox_relayed_message_to_cgw_server(rq.key, rq.req).await;
}
let reply = cgw_remote::EnqueueResponse {
ret: 0,
msg: "DONE".to_string(),
};
Ok(Response::new(reply))
}
}
/// gRPC server front-end of this CGW instance.
pub struct CGWRemoteServer {
// Remote id plus the IP/port the gRPC server listens on (built from AppArgs)
cfg: CGWRemoteConfig,
}
impl CGWRemoteServer {
    /// Builds the server configuration from the CGW id and the gRPC
    /// IP/port given in the application arguments.
    pub fn new(app_args: &AppArgs) -> Self {
        CGWRemoteServer {
            cfg: CGWRemoteConfig::new(app_args.cgw_id, app_args.grpc_ip, app_args.grpc_port),
        }
    }

    /// Serves the `Remote` gRPC service on the configured address.
    ///
    /// Awaits the server future; when it resolves, the result is logged at
    /// error level and the function returns.
    pub async fn start(&self, srv: Arc<CGWConnectionServer>) {
        // GRPC server
        // TODO: use CGWRemoteServerConfig wrap
        let keepalive = Duration::from_secs(7200);
        let icgw_serv = CGWRemote { cgw_srv: srv };

        let grpc_srv = Server::builder()
            .tcp_keepalive(Some(keepalive))
            .http2_keepalive_timeout(Some(keepalive))
            .http2_keepalive_interval(Some(keepalive))
            .http2_max_pending_accept_reset_streams(Some(1000))
            .add_service(RemoteServer::new(icgw_serv));

        info!("Starting GRPC server id {} - listening at {}:{}", self.cfg.remote_id, self.cfg.server_ip, self.cfg.server_port);

        let res = grpc_srv.serve(self.cfg.to_socket_addr()).await;
        error!("grpc server returned {:?}", res);
        // end of GRPC server build / start declaration
    }
}

24
src/localhost.crt Executable file
View File

@@ -0,0 +1,24 @@
-----BEGIN CERTIFICATE-----
MIIEADCCAmigAwIBAgICAcgwDQYJKoZIhvcNAQELBQAwLDEqMCgGA1UEAwwhcG9u
eXRvd24gUlNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTE2MDgxMzE2MDcwNFoX
DTIyMDIwMzE2MDcwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wggEiMA0G
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCpVhh1/FNP2qvWenbZSghari/UThwe
dynfnHG7gc3JmygkEdErWBO/CHzHgsx7biVE5b8sZYNEDKFojyoPHGWK2bQM/FTy
niJCgNCLdn6hUqqxLAml3cxGW77hAWu94THDGB1qFe+eFiAUnDmob8gNZtAzT6Ky
b/JGJdrEU0wj+Rd7wUb4kpLInNH/Jc+oz2ii2AjNbGOZXnRz7h7Kv3sO9vABByYe
LcCj3qnhejHMqVhbAT1MD6zQ2+YKBjE52MsQKU/xhUpu9KkUyLh0cxkh3zrFiKh4
Vuvtc+n7aeOv2jJmOl1dr0XLlSHBlmoKqH6dCTSbddQLmlK7dms8vE01AgMBAAGj
gb4wgbswDAYDVR0TAQH/BAIwADALBgNVHQ8EBAMCBsAwHQYDVR0OBBYEFMeUzGYV
bXwJNQVbY1+A8YXYZY8pMEIGA1UdIwQ7MDmAFJvEsUi7+D8vp8xcWvnEdVBGkpoW
oR6kHDAaMRgwFgYDVQQDDA9wb255dG93biBSU0EgQ0GCAXswOwYDVR0RBDQwMoIO
dGVzdHNlcnZlci5jb22CFXNlY29uZC50ZXN0c2VydmVyLmNvbYIJbG9jYWxob3N0
MA0GCSqGSIb3DQEBCwUAA4IBgQBsk5ivAaRAcNgjc7LEiWXFkMg703AqDDNx7kB1
RDgLalLvrjOfOp2jsDfST7N1tKLBSQ9bMw9X4Jve+j7XXRUthcwuoYTeeo+Cy0/T
1Q78ctoX74E2nB958zwmtRykGrgE/6JAJDwGcgpY9kBPycGxTlCN926uGxHsDwVs
98cL6ZXptMLTR6T2XP36dAJZuOICSqmCSbFR8knc/gjUO36rXTxhwci8iDbmEVaf
BHpgBXGU5+SQ+QM++v6bHGf4LNQC5NZ4e4xvGax8ioYu/BRsB/T3Lx+RlItz4zdU
XuxCNcm3nhQV2ZHquRdbSdoyIxV5kJXel4wCmOhWIq7A2OBKdu5fQzIAzzLi65EN
RPAKsKB4h7hGgvciZQ7dsMrlGw0DLdJ6UrFyiR5Io7dXYT/+JP91lP5xsl6Lhg9O
FgALt7GSYRm2cZdgi9pO9rRr83Br1VjQT1vHz6yoZMXSqc4A2zcN2a2ZVq//rHvc
FZygs8miAhWPzqnpmgTj1cPiU1M=
-----END CERTIFICATE-----

275
src/main.rs Normal file
View File

@@ -0,0 +1,275 @@
#![warn(rust_2018_idioms)]
mod cgw_connection_server;
mod cgw_nb_api_listener;
mod cgw_connection_processor;
mod cgw_remote_discovery;
mod cgw_remote_server;
mod cgw_remote_client;
mod cgw_db_accessor;
#[macro_use]
extern crate log;
use tokio::{
net::{
TcpListener,
},
time::{
sleep,
Duration,
},
runtime::{
Builder,
Handle,
Runtime,
},
};
use native_tls::Identity;
use std::{
net::{
Ipv4Addr,
SocketAddr,
},
sync::{
Arc,
},
};
use rlimit::{
setrlimit,
Resource,
};
use cgw_connection_server::{
CGWConnectionServer,
};
use cgw_remote_server::{
CGWRemoteServer,
};
use clap::{
ValueEnum,
Parser,
};
// Log verbosity selectable on the command line (see `AppArgs::log_level`).
// NOTE: the `///` comments on the variants are picked up by the clap derive
// as help text, so they are intentionally left as they are.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
enum AppCoreLogLevel {
/// Print debug-level messages and above
Debug,
/// Print info-level messages and above
///
///
Info,
}
/// CGW server
// NOTE: the `///` comments below double as the `--help` text generated by the
// clap derive, so they must describe the actual field types and defaults.
#[derive(Parser, Clone)]
#[command(version, about, long_about = None)]
pub struct AppArgs {
    /// CGW unique identifier (i32)
    #[arg(short, long, default_value_t = 0)]
    cgw_id: i32,
    /// Number of threads in the thread pool dedicated for handling secure websocket connections
    #[arg(short, long, default_value_t = 4)]
    wss_t_num: usize,
    /// Loglevel of application
    #[arg(value_enum, default_value_t = AppCoreLogLevel::Debug)]
    log_level: AppCoreLogLevel,
    /// IP to listen for incoming WSS connection
    #[arg(long, default_value_t = Ipv4Addr::new(0, 0, 0, 0))]
    wss_ip: Ipv4Addr,
    /// PORT to listen for incoming WSS connection
    #[arg(long, default_value_t = 15002)]
    wss_port: u16,
    /// IP to listen for incoming GRPC connection
    #[arg(long, default_value_t = Ipv4Addr::new(0, 0, 0, 0))]
    grpc_ip: Ipv4Addr,
    /// PORT to listen for incoming GRPC connection
    #[arg(long, default_value_t = 50051)]
    grpc_port: u16,
    /// IP to connect to KAFKA broker
    #[arg(long, default_value_t = Ipv4Addr::new(127, 0, 0, 1))]
    kafka_ip: Ipv4Addr,
    /// PORT to connect to KAFKA broker
    #[arg(long, default_value_t = 9092)]
    kafka_port: u16,
    /// KAFKA topic from where to consume messages
    #[arg(long, default_value_t = String::from("CnC"))]
    kafka_consume_topic: String,
    /// KAFKA topic where to produce messages
    #[arg(long, default_value_t = String::from("CnC_Res"))]
    kafka_produce_topic: String,
    /// IP to connect to DB (PSQL)
    #[arg(long, default_value_t = Ipv4Addr::new(127, 0, 0, 1))]
    db_ip: Ipv4Addr,
    /// PORT to connect to DB (PSQL)
    #[arg(long, default_value_t = 5432)]
    db_port: u16,
    /// DB name to connect to in DB (PSQL)
    #[arg(long, default_value_t = String::from("cgw"))]
    db_name: String,
    /// DB user name use with connection to in DB (PSQL)
    #[arg(long, default_value_t = String::from("cgw"))]
    db_username: String,
    /// DB user password use with connection to in DB (PSQL)
    #[arg(long, default_value_t = String::from("123"))]
    db_password: String,
    /// IP to connect to DB (REDIS)
    #[arg(long, default_value_t = Ipv4Addr::new(127, 0, 0, 1))]
    redis_db_ip: Ipv4Addr,
    /// PORT to connect to DB (REDIS)
    #[arg(long, default_value_t = 6379)]
    redis_db_port: u16,
}
/// Application-wide state shared (behind an `Arc`) by the server tasks.
pub struct AppCore {
// Connection server; handles accepted connections and relayed messages
cgw_server: Arc<CGWConnectionServer>,
// Handle of the runtime `AppCore::new` was called on; runs the accept loop
main_runtime_handle: Arc<Handle>,
// Dedicated runtime the gRPC server is spawned on
grpc_server_runtime_handle: Arc<Runtime>,
// Dedicated runtime the per-connection `ack_connection` tasks are spawned on
conn_ack_runtime_handle: Arc<Runtime>,
// Parsed command-line arguments
args: AppArgs,
}
impl AppCore {
async fn new(app_args: AppArgs) -> Self {
Self::setup_app(&app_args);
let current_runtime = Arc::new(Handle::current());
let c_ack_runtime_handle = Arc::new(Builder::new_multi_thread()
.worker_threads(1)
.thread_name("cgw-c-ack")
.thread_stack_size(1 * 1024 * 1024)
.enable_all()
.build()
.unwrap());
let rpc_runtime_handle = Arc::new(Builder::new_multi_thread()
.worker_threads(1)
.thread_name("grpc-recv-t")
.thread_stack_size(1 * 1024 * 1024)
.enable_all()
.build()
.unwrap());
let app_core = AppCore {
cgw_server: CGWConnectionServer::new(&app_args).await,
main_runtime_handle: current_runtime,
conn_ack_runtime_handle: c_ack_runtime_handle,
args: app_args,
grpc_server_runtime_handle: rpc_runtime_handle
};
app_core
}
fn setup_app(args: &AppArgs) {
let nofile_rlimit = Resource::NOFILE.get().unwrap();
println!("{:?}", nofile_rlimit);
let nofile_hard_limit = nofile_rlimit.1;
assert!(setrlimit(Resource::NOFILE, nofile_hard_limit, nofile_hard_limit).is_ok());
let nofile_rlimit = Resource::NOFILE.get().unwrap();
println!("{:?}", nofile_rlimit);
match args.log_level {
AppCoreLogLevel::Debug => ::std::env::set_var("RUST_LOG", "ucentral_cgw=debug"),
AppCoreLogLevel::Info => ::std::env::set_var("RUST_LOG", "ucentral_cgw=info"),
}
env_logger::init();
}
async fn run(self: Arc<AppCore>) {
let main_runtime_handle = self.main_runtime_handle.clone();
let core_clone = self.clone();
let cgw_remote_server = CGWRemoteServer::new(&self.args);
let cgw_srv_clone = self.cgw_server.clone();
self.grpc_server_runtime_handle.spawn(async move {
debug!("cgw_remote_server.start entry");
cgw_remote_server.start(cgw_srv_clone).await;
debug!("cgw_remote_server.start exit");
});
main_runtime_handle.spawn(async move {
server_loop(core_clone).await
});
// TODO:
// Add signal processing and etcetera app-related handlers.
loop {
sleep(Duration::from_millis(5000)).await;
}
}
}
// TODO: a method of an object (TlsAcceptor? CGWConnectionServer?), not a plain function
/// Binds the WSS listener, builds the TLS acceptor from the certificate and
/// key embedded at compile time, and accepts connections forever.
///
/// Every accepted socket is handed to `ack_connection` on the dedicated
/// connection-ack runtime together with a running connection index.
/// Panics if the listener cannot be bound or the TLS identity/acceptor
/// cannot be created.
async fn server_loop(app_core: Arc<AppCore>) -> () {
    debug!("sever_loop entry");

    debug!("Starting WSS server, listening at {}:{}", app_core.args.wss_ip, app_core.args.wss_port);

    // Bind the server's socket
    let sockaddraddr = SocketAddr::new(std::net::IpAddr::V4(app_core.args.wss_ip), app_core.args.wss_port);
    let listener: Arc<TcpListener> = Arc::new(
        TcpListener::bind(sockaddraddr)
            .await
            .unwrap_or_else(|e| panic!("listener bind failed {e}")));

    info!("Started WSS server.");

    // Create the TLS acceptor.
    // TODO: custom acceptor
    let der = include_bytes!("localhost.crt");
    let key = include_bytes!("localhost.key");
    let cert = Identity::from_pkcs8(der, key)
        .unwrap_or_else(|e| panic!("Cannot create SSL identity from supplied cert\n{e}"));

    let native_acceptor = native_tls::TlsAcceptor::builder(cert)
        .build()
        .unwrap_or_else(|e| panic!("Cannot create SSL-acceptor from supplied cert\n{e}"));
    let tls_acceptor = tokio_native_tls::TlsAcceptor::from(native_acceptor);

    // Spawn explicitly in main thread: created task accepts connection,
    // but handling is spawned inside another threadpool runtime
    let app_core_clone = app_core.clone();
    let accept_task = app_core.main_runtime_handle.spawn(async move {
        let mut conn_idx: i64 = 0;
        loop {
            // Asynchronously wait for an inbound socket.
            let (socket, remote_addr) = match listener.accept().await {
                Ok(conn) => conn,
                Err(e) => {
                    error!("Failed to Accept conn {e}\n");
                    continue;
                }
            };

            let cgw_server_clone = app_core_clone.cgw_server.clone();
            let tls_acceptor_clone = tls_acceptor.clone();

            // TODO: we control tls_acceptor, thus at this stage it's our responsibility
            // to provide underlying certificates common name inside the ack_connection func
            // (CN == mac address of device)
            app_core_clone.conn_ack_runtime_handle.spawn(async move {
                cgw_server_clone.ack_connection(socket, tls_acceptor_clone, remote_addr, conn_idx).await;
            });

            conn_idx += 1;
        }
    });

    let _ = accept_task.await;
}
/// Process entry point: parses the command line, builds the application
/// core on a current-thread runtime and runs it.
#[tokio::main(flavor = "current_thread")]
async fn main() {
    let app_core = AppCore::new(AppArgs::parse()).await;

    Arc::new(app_core).run().await;
}

17
src/proto/cgw.proto Normal file
View File

@@ -0,0 +1,17 @@
// gRPC API used to relay requests between CGW instances.
syntax = "proto3";
package cgw.remote;
service Remote {
// Client-streaming call: the caller sends any number of EnqueueRequest
// messages and receives a single EnqueueResponse when the stream ends.
rpc EnqueueNBAPIRequestStream(stream EnqueueRequest) returns (EnqueueResponse);
}
// One relayed request; the server enqueues it as a (key, req) pair.
message EnqueueRequest {
string req = 1;
string key = 2;
}
// Result of the whole stream; the server replies with ret = 0 and msg = "DONE".
message EnqueueResponse {
uint64 ret = 1;
string msg = 2;
}

1
utils/cert_generator/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
certs/

View File

@@ -0,0 +1,18 @@
# Generate Certificates
```
# generate CA certs
$ ./generate_certs.sh -a
# generate server certs
$ ./generate_certs.sh -s
# generate 10 client certs
$ ./generate_certs.sh -c 10
# generate 10 client certs with MAC addr AA:*
$ ./generate_certs.sh -c 10 -m AA:XX:XX:XX:XX:XX
```
The certificates will be available in `./certs/`.

View File

@@ -0,0 +1,19 @@
# OpenSSL configuration used by generate_certs.sh.
# The [ca], [client] and [server] sections are selected by the script
# through the "-extensions" option, depending on the certificate type.
[req]
distinguished_name=req
# Extensions for the root CA certificate
[ca]
basicConstraints=CA:true
keyUsage=keyCertSign,cRLSign
subjectKeyIdentifier=hash
# Extensions for client (device) certificates
[client]
basicConstraints=CA:FALSE
extendedKeyUsage=clientAuth
keyUsage=digitalSignature,keyEncipherment
# Extensions for the server (gateway) certificate
[server]
basicConstraints=CA:FALSE
extendedKeyUsage=serverAuth
keyUsage=digitalSignature,keyEncipherment
subjectAltName=@sans
# Subject alternative names referenced by the [server] section
[sans]
DNS.1=localhost
IP.1=127.0.0.1

View File

@@ -0,0 +1,161 @@
#!/bin/bash
# Generates a CA, a server certificate and client certificates (one per
# random MAC address) into ./certs using openssl.

# MAC template: every X/x is replaced with a random hex digit (see rand_mac);
# can be overridden with -m
TEMPLATE="02:XX:XX:XX:XX:XX"
HEXCHARS="0123456789ABCDEF"
# openssl config providing the ca/client/server extension sections
CONF_FILE=ca.conf
# output layout
CERT_DIR=certs
CA_DIR=$CERT_DIR/ca
CA_CERT=$CA_DIR/ca.crt
CA_KEY=$CA_DIR/ca.key
SERVER_DIR=$CERT_DIR/server
CLIENT_DIR=$CERT_DIR/client
# y: generate client certificates in parallel batches; n: one by one
METHOD_FAST=y
# Print the command-line help to stdout.
usage()
{
    printf 'Usage: %s [options]\n\n' "$0"
    printf 'options:\n'
    printf -- '-h\n'
    printf '\tprint this help\n'
    printf -- '-a\n'
    printf '\tgenerate CA key and certificate\n'
    printf -- '-s\n'
    printf '\tgenerate server key and certificate; sign the certificate\n'
    printf '\tusing the CA certificate\n'
    printf -- '-c NUMBER\n'
    printf '\tgenerate *NUMBER* of client keys and certificates; sign\n'
    printf '\tall of the certificates using the CA certificate\n'
    printf -- '-m MASK\n'
    printf '\tspecify custom MAC addr mask\n'
    printf -- '-o\n'
    printf '\tslow mode\n'
}
# Print a MAC address built from $TEMPLATE: every X/x placeholder is replaced,
# left to right, with a random character taken from $HEXCHARS.
rand_mac()
{
    local mac=$TEMPLATE
    local idx
    until [[ $mac != *[xX]* ]]
    do
        idx=$(( RANDOM % 16 ))
        mac=${mac/[xX]/${HEXCHARS:idx:1}}
    done
    echo $mac
}
# Generate a private key and a certificate signed by the CA.
#
# Arguments:
#   $1 - extension section of $CONF_FILE to use (ca / client / server)
#   $2 - common name (CN) of the certificate
#   $3 - output path of the private key
#   $4 - output path of the certificate (signed cert followed by the CA cert)
#
# On signing failure a message is printed to stderr and the key is removed.
gen_cert()
{
    # declared separately from the assignment so mktemp's status is not masked
    local req_file pem
    req_file=$(mktemp)
    pem=$(mktemp)
    local type=$1
    local cn=$2
    local key=$3
    local cert=$4

    # generate key and request to sign
    openssl req -config "$CONF_FILE" -x509 -nodes -newkey rsa:4096 -sha512 -days 365 \
        -extensions "$type" -subj "/CN=$cn" -out "$req_file" -keyout "$key" &> /dev/null

    # sign certificate
    if openssl x509 -extfile "$CONF_FILE" -CA "$CA_CERT" -CAkey "$CA_KEY" -CAcreateserial -sha512 -days 365 \
        -in "$req_file" -out "$pem"
    then
        cat "$pem" "$CA_CERT" > "$cert"
    else
        >&2 echo Failed to generate certificate
        # -f: the key may not exist if key generation itself failed
        rm -f "$key"
    fi

    rm -f "$req_file" "$pem"
}
# Generate $1 client certificates in parallel.
#
# Certificates are generated as background jobs; the function synchronizes
# with them after every tenth of the total and once more at the end, so it
# returns only when every job has finished. Progress is printed as a
# percentage of the total.
gen_client_batch()
{
    local total=$1
    local batch=$(( total / 100 ))
    local sync_count=$(( total / 10 ))
    local c mac

    # small totals would otherwise lead to a modulo by zero below
    (( batch > 0 )) || batch=1
    (( sync_count > 0 )) || sync_count=1

    for (( c=1; c<=total; c++ ))
    do
        mac=$(rand_mac)
        gen_cert client "$mac" "$CLIENT_DIR/$mac.key" "$CLIENT_DIR/$mac.crt" &

        if (( c % batch == 0 ))
        then
            echo "$(( c * 100 / total ))%"
        fi

        if (( c % sync_count == 0 ))
        then
            # block until all jobs started so far are done
            wait
        fi
    done

    # jobs started after the last synchronization point
    wait
}
# Generate $1 client certificates one after another, printing the running
# index of each certificate before it is generated.
gen_client()
{
    for (( x = 1; x <= $1; x++ ))
    do
        echo $x
        mac=$(rand_mac)
        gen_cert client $mac $CLIENT_DIR/$mac.key $CLIENT_DIR/$mac.crt
    done
}
# Parse command-line options; the actions themselves run after parsing,
# always in the order CA -> server -> clients.
while getopts "ac:shm:o" arg; do
case $arg in
a)
GEN_CA=y
;;
s)
GEN_SER=y
;;
c)
GEN_CLI=y
NUM_CERTS=$OPTARG
# fewer than 100 certificates are always generated one by one
if [ $NUM_CERTS -lt 100 ]
then
METHOD_FAST=n
fi
;;
m)
TEMPLATE=$OPTARG
;;
o)
METHOD_FAST=n
;;
h)
usage
exit 0
;;
*)
usage
exit 1
;;
esac
done
# Root CA: self-signed certificate using the [ca] extensions
if [ "$GEN_CA" == "y" ]
then
echo Generating root CA certificate
mkdir -p $CA_DIR
openssl req -config $CONF_FILE -x509 -nodes -newkey rsa:4096 -sha512 -days 365 \
-extensions ca -subj "/CN=CA" -out $CA_CERT -keyout $CA_KEY &> /dev/null
fi
# Server certificate with CN=localhost, signed by the CA
if [ "$GEN_SER" == "y" ]
then
mkdir -p $SERVER_DIR
echo Generating server certificate
gen_cert server localhost $SERVER_DIR/gw.key $SERVER_DIR/gw.crt
fi
# Client certificates, one per random MAC address
if [ "$GEN_CLI" == "y" ]
then
echo Generating $NUM_CERTS client certificates
mkdir -p $CLIENT_DIR
if [ $METHOD_FAST == "y" ]
then
# because of race condition some of the certificates might fail to generate
# but this is ~15 times faster than generating certificates one by one
gen_client_batch $NUM_CERTS
else
gen_client $NUM_CERTS
fi
fi
echo done

6
utils/client_simulator/.gitignore vendored Normal file
View File

@@ -0,0 +1,6 @@
bin/
lib/
lib64
include/
*.cfg
__pycache__

View File

@@ -0,0 +1,10 @@
FROM ubuntu
COPY ./requirements.txt /tmp/requirements.txt
WORKDIR /opt/client_simulator
RUN apt-get update -q -y && apt-get -q -y --no-install-recommends install \
python3 \
python3-pip
RUN python3 -m pip install -r /tmp/requirements.txt

View File

@@ -0,0 +1,34 @@
IMG_NAME=cgw-client-sim
CONTAINER_NAME=cgw_client_sim
MAC?=XX:XX:XX:XX:XX:XX
COUNT?=1000
URL=wss://localhost:15002
CA_CERT_PATH?=$(PWD)/../cert_generator/certs/ca
CLIENT_CERT_PATH?=$(PWD)/../cert_generator/certs/client
MSG_INTERVAL?=10
MSG_SIZE?=1000
.PHONY: build spawn stop start
build:
docker build -t ${IMG_NAME} .
spawn:
docker run --name "${CONTAINER_NAME}_${COUNT}_$(subst :,-,$(MAC))" \
-d --rm --network host \
-v $(PWD):/opt/client_simulator \
-v ${CA_CERT_PATH}:/etc/ca \
-v ${CLIENT_CERT_PATH}:/etc/certs \
${IMG_NAME} \
python3 main.py -M ${MAC} -N ${COUNT} -s ${URL} \
--ca-cert /etc/ca/ca.crt \
--client-certs-path /etc/certs \
--msg-interval ${MSG_INTERVAL} \
--payload-size ${MSG_SIZE} \
--wait-for-signal
stop:
docker stop $$(docker ps -q -f name=$(CONTAINER_NAME))
start:
docker kill -s SIGUSR1 $$(docker ps -q -f name=$(CONTAINER_NAME))

View File

@@ -0,0 +1,39 @@
# Run client simulation
```
# run 10 concurrent client simulations
$ ./main.py -s wss://localhost:50001 -N 10
# use only specified MAC addrs
$ ./main.py -s wss://localhost:15002 -N 10 -M AA:XX:XX:XX:XX:XX
# run 10 concurrent simulations with MAC AA:* and 10 concurrent simulations with MAC BB:*
# AA:* and BB:* simulations are run in separate processes
$ ./main.py -s wss://localhost:15002 -N 10 -M AA:XX:XX:XX:XX:XX -M BB:XX:XX:XX:XX:XX
```
To stop the simulation use `Ctrl+C`.
# Run simulation in docker
```
$ make
$ make spawn
$ make start
$ make stop
# specify mac addr range
$ make spawn MAC=11:22:AA:BB:XX:XX
# specify number of client connections (default is 1000)
$ make spawn COUNT=100
# specify server url
$ make spawn URL=wss://localhost:15002
# tell all running containers to start connecting to the server
$ make start
# stop all containers
$ make stop
```

View File

@@ -0,0 +1 @@
../cert_generator/certs/

View File

@@ -0,0 +1,5 @@
{"connect": {"jsonrpc":"2.0","method":"connect","params":{"serial":"MAC","firmware":"Rel 1.6 build 1","uuid":1692198868,"capabilities":{"compatible":"x86_64-kvm_x86_64-r0","model":"DellEMC-S5248f-P-25G-DPB","platform":"switch","label_macaddr":"MAC"}}},
"state": {"jsonrpc": "2.0", "method": "state", "params": {"serial": "MAC","uuid": 1692198868, "request_uuid": null, "state": {}}},
"reboot_response": {"jsonrpc": "2.0", "result": {"serial": "MAC", "status": {"error": 0, "text": "", "when": 0}, "id": "ID"}},
"log": {"jsonrpc": "2.0", "method": "log", "params": {"serial": "MAC", "log": "", "severity": 7, "data": {}}}
}

8
utils/client_simulator/main.py Executable file
View File

@@ -0,0 +1,8 @@
#!/usr/bin/env python3
# Command-line entry point of the client simulator.
from src.utils import parse_args
from src.simulation_runner import main
if __name__ == "__main__":
    # Parse the command line, then hand the options over to the runner
    args = parse_args()
    main(args)

View File

@@ -0,0 +1 @@
websockets==12.0

View File

View File

@@ -0,0 +1,57 @@
import logging
# Custom level placed just below DEBUG
TRACE_LEVEL = logging.DEBUG - 5
TRACE_NAME = "TRACE"


class ColoredFormatter(logging.Formatter):
    """Formatter that wraps each record in an ANSI color chosen by its level.

    The per-level formatters are built once, at class creation, instead of on
    every record. Records with a level that has no color assigned are rendered
    with the same layout, just without color codes.
    """

    # ANSI escape sequences
    grey = "\x1b[38;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    white = "\x1b[37;20m"
    cyan = "\x1b[36;20m"
    reset = "\x1b[0m"

    # Record layout ("{"-style); shared by all levels
    base_format = "{asctime}|{levelname}|{threadName}|{funcName}:{lineno}\t{message}"

    FORMATS = {
        logging.DEBUG: grey + base_format + reset,
        logging.INFO: cyan + base_format + reset,
        logging.WARNING: yellow + base_format + reset,
        logging.ERROR: red + base_format + reset,
        logging.CRITICAL: bold_red + base_format + reset,
        TRACE_LEVEL: white + base_format + reset
    }

    # One ready-to-use formatter per known level, plus an uncolored fallback
    _FORMATTERS = {level: logging.Formatter(fmt, style="{") for level, fmt in FORMATS.items()}
    _FALLBACK = logging.Formatter(base_format, style="{")

    def format(self, record):
        """Return ``record`` rendered with the formatter matching its level."""
        formatter = self._FORMATTERS.get(record.levelno, self._FALLBACK)
        return formatter.format(record)
def __trace(self, msg, *args, **kwargs):
    """
    Log 'msg % args' with severity 'TRACE'.

    Meant to be attached to the logger class as a method:

    logger.trace("periodic log")
    """
    if not self.isEnabledFor(TRACE_LEVEL):
        return
    self._log(TRACE_LEVEL, msg, args, **kwargs)
# Root logger configuration
logging.basicConfig(level=logging.INFO, datefmt="%H:%M:%S")
# Register the custom TRACE level: logging.TRACE and Logger.trace()
logging.addLevelName(TRACE_LEVEL, TRACE_NAME)
setattr(logging, TRACE_NAME, TRACE_LEVEL)
setattr(logging.getLoggerClass(), TRACE_NAME.lower(), __trace)
# Console handler with colored output; NOTSET lets every level through the handler
console = logging.StreamHandler()
console.setLevel(logging.NOTSET)
console.setFormatter(ColoredFormatter())
# Simulator logger: writes only to the colored console handler
# (propagate = False keeps records away from the root logger's handlers)
logger = logging.getLogger(__name__)
logger.propagate = False
logger.addHandler(console)
logging.getLogger('websockets.client').setLevel(logging.INFO)

View File

@@ -0,0 +1,203 @@
#!/usr/bin/env python3
from .utils import get_msg_templates, Args
from .log import logger
from websockets.sync import client
from websockets.exceptions import ConnectionClosedOK, ConnectionClosedError, ConnectionClosed
from typing import List
import multiprocessing
import threading
import resource
import string
import random
import signal
import copy
import time
import json
import ssl
import os
import re
class Message:
    """Pre-rendered JSON messages of one simulated device.

    Every message template is serialized once, with the "MAC" placeholder
    replaced by the device's MAC address.
    """

    def __init__(self, mac: str, size: int):
        """
        :param mac: MAC address substituted for the "MAC" placeholder
        :param size: number of random characters in the log message payload
        """
        self.templates = get_msg_templates()
        self.connect = json.dumps(self.templates["connect"]).replace("MAC", mac)
        self.state = json.dumps(self.templates["state"]).replace("MAC", mac)
        self.reboot_response = json.dumps(self.templates["reboot_response"]).replace("MAC", mac)
        # Substitute the placeholder BEFORE adding the random payload: the payload
        # is made of upper-case letters and may itself contain "MAC", which must
        # not be rewritten (it would corrupt the payload and change its size).
        log = json.loads(json.dumps(self.templates["log"]).replace("MAC", mac))
        log["params"]["data"] = {"msg": ''.join(random.choices(string.ascii_uppercase + string.digits, k=size))}
        self.log = json.dumps(log)

    @staticmethod
    def to_json(msg) -> str:
        """Serialize ``msg`` to a JSON string."""
        return json.dumps(msg)

    @staticmethod
    def from_json(msg) -> dict:
        """Parse a JSON document."""
        return json.loads(msg)
class Device:
    """One simulated device: connects to the gateway over TLS, sends a connect
    message, then periodically sends log messages and answers reboot requests.
    """

    def __init__(self, mac: str, server: str, ca_cert: str,
                 msg_interval: int, msg_size: int,
                 client_cert: str, client_key: str,
                 start_event: multiprocessing.Event,
                 stop_event: multiprocessing.Event):
        """
        :param mac: MAC address of the simulated device
        :param server: URL of the gateway to connect to
        :param ca_cert: path to the CA certificate used to verify the gateway
        :param msg_interval: seconds between log messages; also the receive timeout
        :param msg_size: size of the log message payload
        :param client_cert: path to the client certificate
        :param client_key: path to the client private key
        :param start_event: the simulation starts once this event is set
        :param stop_event: the simulation stops once this event is set
        """
        self.mac = mac
        self.interval = msg_interval
        self.messages = Message(self.mac, msg_size)
        self.server_addr = server
        self.start_event = start_event
        self.stop_event = stop_event
        # how long the device stays disconnected when asked to reboot
        self.reboot_time_s = 10
        self._socket = None

        self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.ssl_context.load_cert_chain(client_cert, client_key, "")
        self.ssl_context.load_verify_locations(ca_cert)
        self.ssl_context.verify_mode = ssl.CERT_REQUIRED

    def send_hello(self, socket: client.ClientConnection):
        """Send the connect message."""
        logger.debug(self.messages.connect)
        socket.send(self.messages.connect)

    def send_log(self, socket: client.ClientConnection):
        """Send the log message."""
        socket.send(self.messages.log)

    def handle_messages(self, socket: client.ClientConnection):
        """Wait up to ``self.interval`` seconds for one message and handle it.

        Only "reboot" requests are supported; anything else (including
        messages without a "method" field) is reported as an error.
        Re-raises if the connection is closed.
        """
        try:
            msg = socket.recv(self.interval)
            msg = self.messages.from_json(msg)
            logger.info(msg)
            # the message comes from the network: do not assume it carries "method"
            method = msg.get("method") if isinstance(msg, dict) else None
            if method == "reboot":
                self.handle_reboot(socket, msg)
            else:
                logger.error(f"Unknown method {method}")
        except TimeoutError:  # no messages
            pass
        except (ConnectionClosedOK, ConnectionClosedError, ConnectionClosed):
            logger.critical("Did not expect socket to be closed")
            raise

    def handle_reboot(self, socket: client.ClientConnection, msg: dict):
        """Answer a reboot request, then simulate the reboot: disconnect, wait
        ``self.reboot_time_s`` seconds, reconnect and send the connect message.
        """
        resp = self.messages.from_json(self.messages.reboot_response)
        if "id" in msg:
            resp["result"]["id"] = msg["id"]
        else:
            del resp["result"]["id"]
            logger.warning("Reboot request is missing 'id' field")
        socket.send(self.messages.to_json(resp))
        self.disconnect()
        time.sleep(self.reboot_time_s)
        self.connect()
        self.send_hello(self._socket)

    def connect(self):
        """Open the connection to the gateway unless it is already open."""
        if self._socket is None:
            self._socket = client.connect(self.server_addr, ssl_context=self.ssl_context, open_timeout=7200)
        return self._socket

    def disconnect(self):
        """Close the connection to the gateway if it is open."""
        if self._socket is not None:
            self._socket.close()
            self._socket = None

    def job(self):
        """Thread body: wait for the start event, then run the simulation until
        the stop event is set. The connection is always closed on exit.
        """
        logger.debug("waiting for start trigger")
        self.start_event.wait()
        # the start event is also set on shutdown to release waiting threads
        if self.stop_event.is_set():
            return
        logger.debug("starting simulation")
        self.connect()
        start = time.time()
        try:
            self.send_hello(self._socket)
            while not self.stop_event.is_set():
                if self._socket is None:
                    logger.error("Connection to GW is lost. Trying to reconnect...")
                    self.connect()
                if time.time() - start > self.interval:
                    logger.info("Sent log")
                    self.send_log(self._socket)
                    start = time.time()
                self.handle_messages(self._socket)
        finally:
            self.disconnect()
        logger.debug("simulation done")
def get_avail_mac_addrs(path, mask="XX:XX:XX:XX:XX:XX"):
    """Return the MAC addresses that have a certificate in ``path``.

    A certificate is a file named ``<MAC>.crt``. Only MACs matching ``mask``
    are returned; every ``X``/``x`` in the mask stands for one hex digit, any
    other character must match literally (letters are upper-cased first).

    :param path: directory with the client certificates
    :param mask: MAC address mask
    :return: sorted list of matching MAC addresses (deterministic order)
    """
    pattern = re.compile(
        "".join("[0-9a-fA-F]" if c == "X" else re.escape(c) for c in mask.upper()))
    macs = set()
    for name in os.listdir(path):
        stem, ext = os.path.splitext(name)
        # the whole file name (minus the extension) must be a MAC matching the mask
        if ext == ".crt" and pattern.fullmatch(stem):
            macs.add(stem)
    return sorted(macs)
def update_fd_limit():
    """Raise the soft limit on open file descriptors to the hard limit.

    Logs the resulting limits; re-raises ``ValueError`` if the limit cannot
    be changed.
    """
    _, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
    except ValueError:
        logger.critical("Failed to update fd limit")
        raise
    limits = resource.getrlimit(resource.RLIMIT_NOFILE)
    logger.warning(f"changed fd limit {limits}")
def process(args: Args, mask: str, start_event: multiprocessing.Event, stop_event: multiprocessing.Event):
    """Body of one worker process: simulates the devices whose MAC matches ``mask``.

    One thread is started per device, at most ``args.number_of_connections``
    devices; the function returns when all device threads have finished.
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)  # ignore Ctrl+C in child processes
    threading.current_thread().name = mask
    logger.info("process started")
    macs = get_avail_mac_addrs(args.cert_path, mask)
    if len(macs) < args.number_of_connections:
        logger.warning(f"expected {args.number_of_connections} certificates, but only found {len(macs)} "
                       f"({mask = })")
    update_fd_limit()
    # zip() caps the number of devices at number_of_connections
    devices = [Device(mac, args.server, args.ca_path, args.msg_interval, args.msg_size,
                      os.path.join(args.cert_path, f"{mac}.crt"),
                      os.path.join(args.cert_path, f"{mac}.key"),
                      start_event, stop_event)
               for mac, _ in zip(macs, range(args.number_of_connections))]
    threads = [threading.Thread(target=d.job, name=d.mac) for d in devices]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
def verify_cert_availability(cert_path: str, masks: List[str], count: int):
    """Check that every mask has at least ``count`` certificates in ``cert_path``.

    Fails with an ``AssertionError`` for the first mask that has too few.
    """
    for mask in masks:
        found = len(get_avail_mac_addrs(cert_path, mask))
        assert found >= count, \
            f"Simulation requires {count} certificates, but only found {found}"
def trigger_start(evt):
    """Build a signal handler that logs the signal and sets ``evt``."""
    def _on_signal(signum, frame):
        logger.info("Signal received, starting simulation...")
        evt.set()

    return _on_signal
def main(args: Args):
    """Run the simulation: one worker process per MAC mask.

    Unless ``args.wait_for_sig`` is set, the devices start right away;
    otherwise they wait until the process receives SIGUSR1. Ctrl+C stops all
    workers and waits for them to exit.
    """
    verify_cert_availability(args.cert_path, args.masks, args.number_of_connections)
    stop_event = multiprocessing.Event()
    start_event = multiprocessing.Event()
    if not args.wait_for_sig:
        start_event.set()
    signal.signal(signal.SIGUSR1, trigger_start(start_event))
    processes = [multiprocessing.Process(target=process, args=(args, mask, start_event, stop_event))
                 for mask in args.masks]
    try:
        for p in processes:
            p.start()
        time.sleep(1)
        logger.info(f"Started {len(processes)} processes")
        if args.wait_for_sig:
            logger.info("Waiting for SIGUSR1...")
        while True:
            time.sleep(100)
    except KeyboardInterrupt:
        logger.warning("Stopping all processes...")
        stop_event.set()
        # also release the devices that are still waiting for the start trigger
        start_event.set()
        for p in processes:
            p.join()

View File

@@ -0,0 +1,119 @@
from dataclasses import dataclass
from typing import List
import argparse
import random
import json
import re
import os
TEMPLATE_LOCATION = "./data/message_templates.json"
@dataclass
class Args:
    """Options of a simulation run."""
    # number of simulated devices per MAC mask
    number_of_connections: int
    # MAC address masks; one worker process is started per mask
    masks: List[str]
    # path to the CA certificate
    ca_path: str
    # directory with the client certificates and keys
    cert_path: str
    # size of the log message payload
    msg_size: int
    # seconds between client messages
    msg_interval: int
    # wait for SIGUSR1 before starting the simulation
    wait_for_sig: bool
    # components of the server URL
    server_proto: str = "ws"
    server_address: str = "localhost"
    server_port: int = 50001

    @property
    def server(self):
        """Server URL assembled as ``proto://address:port``."""
        return "{}://{}:{}".format(self.server_proto, self.server_address, self.server_port)
def parse_msg_size(input: str) -> int:
    """Convert a size such as ``"500"``, ``"1k"`` or ``"2M"`` to an integer.

    The optional suffix multiplies the number: ``k``/``K`` by 1000 and
    ``m``/``M`` by 1000000.

    :raises ValueError: if ``input`` is not a number with an optional suffix
    """
    parsed = re.match(r"^(\d+)([kKmM]?)$", input)
    if parsed is None:
        raise ValueError(f"Unable to parse message size \"{input}\"")
    digits, suffix = parsed.groups()
    multipliers = {"": 1, "k": 1000, "m": 1000000}
    return int(digits) * multipliers[suffix.lower()]
def parse_args():
    """Parse the simulator command line and return a populated ``Args``.

    The ``--server`` value has the form ``[ws://|wss://]ADDRESS[:PORT]``; any
    part that is left out keeps the default declared on ``Args``.

    Returns:
        Args: the parsed settings. ``masks`` always contains at least one
        mask (``"XX:XX:XX:XX:XX:XX"`` when none was given).

    Raises:
        ValueError: if the payload size or the server address cannot be parsed.
        SystemExit: raised by argparse on invalid command-line usage.
    """
    parser = argparse.ArgumentParser(
        description="Used to simulate multiple clients that connect to a single server.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Not marked as required: a required option never falls back to its
    # default, which made the default declared here unreachable.
    parser.add_argument("-s", "--server", metavar="ADDRESS",
                        default="ws://localhost:50001",
                        help="server address")
    parser.add_argument("-N", "--number-of-connections", metavar="NUMBER", type=int,
                        default=1,
                        help="number of concurrent connections per thread pool")
    parser.add_argument("-M", "--mac-mask", metavar="XX:XX:XX:XX:XX:XX", action="append",
                        default=[],
                        help="the mask determines what MAC addresses will be used by clients "
                             "in the thread pool. Specifying multiple masks will increase "
                             "the number of thread pools used.")
    parser.add_argument("-a", "--ca-cert", metavar="CERT",
                        default="./certs/ca/ca.crt",
                        help="path to CA certificate")
    parser.add_argument("-c", "--client-certs-path", metavar="PATH",
                        default="./certs/client",
                        help="path to client certificates directory")
    parser.add_argument("-t", "--msg-interval", metavar="SECONDS", type=int,
                        default=10,
                        help="time between client messages to gw")
    parser.add_argument("-p", "--payload-size", metavar="SIZE", type=str,
                        default="1k",
                        help="size of each client message")
    parser.add_argument("-w", "--wait-for-signal", action="store_true",
                        help="wait for SIGUSR1 before running simulation")
    parsed_args = parser.parse_args()

    args = Args(number_of_connections=parsed_args.number_of_connections,
                masks=parsed_args.mac_mask,
                ca_path=parsed_args.ca_cert,
                cert_path=parsed_args.client_certs_path,
                msg_interval=parsed_args.msg_interval,
                msg_size=parse_msg_size(parsed_args.payload_size),
                wait_for_sig=parsed_args.wait_for_signal)
    if len(args.masks) == 0:
        args.masks.append("XX:XX:XX:XX:XX:XX")

    # PROTO :// ADDRESS : PORT
    # fullmatch() rejects trailing garbage instead of silently accepting a
    # prefix (e.g. "http://x" used to be read as address "http"), and "-" is
    # allowed so that hyphenated host names are not truncated.
    match = re.fullmatch(r"(?:(wss?)://)?([\w.\-]+)(?::(\d+))?/?", parsed_args.server)
    if match is None:
        raise ValueError(f"Unable to parse server address {parsed_args.server}")
    proto, addr, port = match.groups()
    if proto is not None:
        args.server_proto = proto
    if addr is not None:
        args.server_address = addr
    if port is not None:
        # Args.server_port is declared as int; the regex group is a str.
        args.server_port = int(port)
    return args
def rand_mac(mask="02:xx:xx:xx:xx:xx"):
    """Return ``mask`` lower-cased with every ``x``/``X`` replaced by a random hex digit.

    All other characters are copied through (lower-cased).
    """
    pieces = []
    for char in mask:
        # A digit is drawn for every character, not only for placeholders, so
        # the sequence consumed from the random generator stays the same.
        digit = format(random.randint(0, 15), "x")
        lowered = char.lower()
        pieces.append(digit if lowered == "x" else lowered)
    return "".join(pieces)
def get_msg_templates():
    """Read the file at ``TEMPLATE_LOCATION`` and return its parsed JSON content."""
    with open(TEMPLATE_LOCATION, "r") as template_file:
        raw_json = template_file.read()
    return json.loads(raw_json)
def gen_certificates(mask: str, count: int = 1):
    """Run ``../cert_generator/generate_certs.sh`` to create client certificates.

    The working directory is switched to ``../cert_generator`` for the
    duration of the call and always restored afterwards.

    Args:
        mask: MAC mask passed to the script's ``-m`` option.
        count: number of certificates passed to the script's ``-c`` option.
            The original signature read ``count=int``, which made the *type*
            ``int`` the default value; it is now a real ``int`` parameter that
            defaults to 1.

    Raises:
        AssertionError: if the script exits with a non-zero status.
        ValueError, TypeError: if ``count`` cannot be converted to an int.
    """
    import shlex  # local import: only needed to quote the shell argument

    count = int(count)
    cwd = os.getcwd()
    os.chdir("../cert_generator")
    try:
        # The command goes through the shell, so quote the mask instead of
        # interpolating it between literal double quotes.
        rc = os.system(f"./generate_certs.sh -c {count} -m {shlex.quote(mask)} -o")
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if rc != 0:
            raise AssertionError("Generating certificates failed")
    finally:
        os.chdir(cwd)

7
utils/kafka_producer/.gitignore vendored Normal file
View File

@@ -0,0 +1,7 @@
bin/
lib/
lib64
include/
share/
*.cfg
__pycache__

View File

@@ -0,0 +1,33 @@
{
"add_group": {
"type": "infrastructure_group_create",
"infra_group_id": "key",
"infra_name": "name",
"infra_shard_id": 0,
"uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff"
},
"del_group": {
"type": "infrastructure_group_delete",
"infra_group_id": "key",
"uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff"
},
"add_to_group": {
"type": "infrastructure_group_device_add",
"infra_group_id": "key",
"infra_group_infra_devices": [],
"uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff"
},
"del_from_group": {
"type": "infrastructure_group_device_del",
"infra_group_id": "key",
"infra_group_infra_devices": [],
"uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff"
},
"message_device": {
"type": "infrastructure_group_device_message",
"infra_group_id": "key",
"mac": "mac",
"msg": {},
"uuid": "290d06b6-8eba-11ee-8005-aabbccddeeff"
}
}

23
utils/kafka_producer/main.py Executable file
View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python3
from src.cli_parser import parse_args, Args
from src.producer import Producer
from src.log import logger
def main(args: Args):
    """Send the Kafka messages requested on the command line.

    Depending on which options were given, this creates/deletes groups,
    assigns/removes devices, and sends a message to a range of devices, in
    that order.

    Args:
        args: parsed command-line options.
    """
    producer = Producer(args.db, args.topic)
    if args.add_groups or args.del_groups:
        producer.handle_group_creation(args.add_groups, args.del_groups)
    if args.assign_to_group or args.remove_from_group:
        producer.handle_device_assignment(args.assign_to_group, args.remove_from_group)
    # Compare against None rather than truthiness: a valid but falsy JSON
    # payload such as {} or 0 must still be sent when --send-message is given.
    if args.message is not None:
        producer.handle_device_messages(args.message, args.group_id, args.send_to_macs,
                                        args.count, args.time_to_send_s, args.interval_s)
# Script entry point: parse the command line and run; Ctrl-C exits quietly.
if __name__ == "__main__":
    try:
        args = parse_args()
        main(args)
    except KeyboardInterrupt:
        # Logger.warn() is a deprecated alias; warning() is the supported spelling.
        logger.warning("exiting...")

View File

@@ -0,0 +1 @@
kafka-python==2.0.2

View File

View File

@@ -0,0 +1,116 @@
from .utils import MacRange, Args
import argparse
import json
def time(input: str) -> float:
    """
    Convert a duration string to a number of seconds.

    Accepted formats:
    *number*
    *number*s
    *number*m
    *number*h

    The result is a float. Examples:
    100 -> 100.0
    100s -> 100.0
    5m -> 300.0
    2h -> 7200.0

    Raises ValueError when the numeric part cannot be parsed as a float.
    """
    if not any(unit in input for unit in "smh"):
        # No unit letter anywhere in the string: a plain number of seconds.
        return float(input)
    magnitude = float(input[:-1])
    unit = input[-1:]
    if unit == "m":
        return magnitude * 60
    # Any suffix other than "h" or "m" leaves the value unscaled (seconds).
    return magnitude * 60 * 60 if unit == "h" else magnitude
def parse_args():
    """Parse the producer command line and return a populated ``Args``.

    Validation failures are reported through ``parser.error``, which prints
    the message and exits the program.

    The send count is resolved as follows: the explicit ``--send-count`` if
    given; otherwise 0 when ``--send-for`` is given; otherwise 1.
    """
    parser = argparse.ArgumentParser(description="Creates entries in kafka.")
    parser.add_argument("-g", "--new-group", metavar=("GROUP-ID", "SHARD-ID", "NAME"),
                        nargs=3, action="append",
                        help="create a new group")
    parser.add_argument("-G", "--rm-group", metavar=("GROUP-ID"),
                        nargs=1, action="append",
                        help="delete an existing group")
    parser.add_argument("-d", "--assign-to-group", metavar=("GROUP-ID", "MAC-RANGE"),
                        nargs=2, action="append",
                        help="add a range of mac addrs to a group")
    parser.add_argument("-D", "--remove-from-group", metavar=("GROUP-ID", "MAC-RANGE"),
                        nargs=2, action="append",
                        help="remove mac addrs from a group")
    parser.add_argument("-T", "--topic", default="CnC",
                        help="kafka topic (default: \"CnC\")")
    parser.add_argument("-s", "--bootstrap-server", metavar="ADDRESS", default="172.20.10.136:9092",
                        help="kafka address (default: \"172.20.10.136:9092\")")
    parser.add_argument("-m", "--send-message", metavar="JSON", type=str,
                        help="this message will be sent down from the GW to all devices "
                             "specified in the --send-to-mac")
    parser.add_argument("-c", "--send-count", metavar="COUNT", type=int,
                        help="how many messages will be sent (per mac address, default: 1)")
    parser.add_argument("-t", "--send-for", metavar="TIME", type=time,
                        help="how long to send the messages for")
    parser.add_argument("-i", "--send-interval", metavar="INTERVAL", type=time, default="1",
                        help="time between messages (default: \"1.0s\")")
    parser.add_argument("-p", "--send-to-group", metavar="GROUP-ID", type=str)
    parser.add_argument("-r", "--send-to-mac", metavar="MAC-RANGE", type=MacRange,
                        help="range of mac addrs that will be receiving the messages")
    parsed_args = parser.parse_args()

    # A message needs both a target group and a target MAC range.
    wants_message = parsed_args.send_message is not None
    has_target = (parsed_args.send_to_group is not None
                  and parsed_args.send_to_mac is not None)
    if wants_message and not has_target:
        parser.error("--send-message requires --send-to-group and --send-to-mac")

    message = None
    if wants_message:
        try:
            message = json.loads(parsed_args.send_message)
        except json.JSONDecodeError:
            parser.error("--send-message must be in JSON format")

    if parsed_args.send_count is not None:
        count = parsed_args.send_count
    elif parsed_args.send_for is not None:
        count = 0
    else:
        count = 1

    args = Args(
        add_groups=[],
        del_groups=[],
        assign_to_group=[],
        remove_from_group=[],
        topic=parsed_args.topic,
        db=parsed_args.bootstrap_server,
        message=message,
        count=count,
        time_to_send_s=parsed_args.send_for,
        interval_s=parsed_args.send_interval,
        group_id=parsed_args.send_to_group,
        send_to_macs=parsed_args.send_to_mac,
    )

    # Options declared with action="append" are None when never given.
    for group, shard, name in parsed_args.new_group or ():
        try:
            entry = (group, int(shard), name)
        except ValueError:
            parser.error(f"--new-group: failed to parse shard id \"{shard}\"")
        else:
            args.add_groups.append(entry)

    args.del_groups.extend(group for (group,) in parsed_args.rm_group or ())

    for group, mac in parsed_args.assign_to_group or ():
        try:
            mac_range = MacRange(mac)
        except ValueError:
            parser.error(f"--assign-to-group: failed to parse MAC range \"{mac}\"")
        else:
            args.assign_to_group.append((group, mac_range))

    for group, mac in parsed_args.remove_from_group or ():
        try:
            mac_range = MacRange(mac)
        except ValueError:
            parser.error(f"--remove-from-group: failed to parse MAC range \"{mac}\"")
        else:
            args.remove_from_group.append((group, mac_range))

    return args

View File

@@ -0,0 +1,37 @@
import logging
class ColoredFormatter(logging.Formatter):
    """Log formatter that wraps each line in an ANSI colour chosen by level.

    Records whose level has no entry in ``FORMATS`` are formatted with the
    ``logging.Formatter`` default format and no colour.
    """

    # ANSI escape sequences.
    grey = "\x1b[38;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    white = "\x1b[37;20m"
    cyan = "\x1b[36;20m"
    reset = "\x1b[0m"

    # "{"-style line layout. The name `format` is only used while the class
    # body executes (to build FORMATS); the method defined below then replaces
    # it as the class attribute.
    format = "{asctime}|{levelname}|{threadName}|{funcName}:{lineno}\t{message}"

    # Level number -> complete format string including colour and reset codes.
    FORMATS = {
        logging.DEBUG: grey + format + reset,
        logging.INFO: cyan + format + reset,
        logging.WARNING: yellow + format + reset,
        logging.ERROR: red + format + reset,
        logging.CRITICAL: bold_red + format + reset,
    }

    def format(self, record):
        """Return ``record`` rendered with the format string for its level.

        The per-record formatter is given this formatter's ``datefmt`` so that
        a date format passed to ``ColoredFormatter(datefmt=...)`` is honoured
        instead of being silently ignored. With no ``datefmt`` the output is
        unchanged.
        """
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = logging.Formatter(log_fmt, datefmt=self.datefmt, style="{")
        return formatter.format(record)
# Configure the root logger with an INFO threshold. The datefmt given here
# applies to the handler that basicConfig() installs on the root logger.
logging.basicConfig(level=logging.INFO, datefmt="%H:%M:%S")
# Dedicated console handler with coloured output; NOTSET means the handler
# itself does not filter any record by level.
console = logging.StreamHandler()
console.setLevel(logging.NOTSET)
console.setFormatter(ColoredFormatter())
# Logger shared by this package. Propagation is switched off so records are
# emitted only through `console` and are not passed on to the root logger's
# handlers.
# NOTE(review): because of that, the "%H:%M:%S" datefmt above does not seem to
# reach this logger's output - confirm whether that is intended.
logger = logging.getLogger(__name__)
logger.propagate = False
logger.addHandler(console)

View File

@@ -0,0 +1,74 @@
from .utils import Message, MacRange
from .log import logger
from typing import List, Tuple
import kafka
import time
import uuid
import sys
class Producer:
    """Sends group, device-assignment and device messages to a Kafka topic.

    The object is its own context manager: entering it connects (creating a
    ``kafka.KafkaProducer`` if there is none yet) and leaving it closes the
    connection. Every ``handle_*`` method opens and closes the connection
    itself. All messages are keyed with the UTF-8 encoded group id.
    """

    def __init__(self, db: str, topic: str) -> None:
        """Store the connection settings; no connection is made yet.

        Args:
            db: value passed to ``KafkaProducer(bootstrap_servers=...)``.
            topic: topic every message is sent to.
        """
        self.db = db
        self.conn = None
        self.topic = topic
        self.message = Message()

    def __enter__(self) -> kafka.KafkaProducer:
        return self.connect()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.disconnect()

    def connect(self) -> kafka.KafkaProducer:
        """Return the Kafka producer, creating it on first use."""
        if self.conn is None:
            self.conn = kafka.KafkaProducer(bootstrap_servers=self.db, client_id="producer")
            logger.info("connected to kafka")
        else:
            logger.info("already connected to kafka")
        return self.conn

    def disconnect(self) -> None:
        """Close the Kafka producer if one is open; otherwise do nothing."""
        if self.conn is None:
            return
        self.conn.close()
        logger.info("disconnected from kafka")
        self.conn = None

    def handle_group_creation(self, create: List[Tuple[str, int, str]], delete: List[str]) -> None:
        """Send one create message per ``create`` entry, then one delete message per ``delete`` entry.

        Args:
            create: ``(group id, shard id, name)`` tuples.
            delete: group ids.
        """
        with self as conn:
            for group, shard_id, name in create:
                conn.send(self.topic, self.message.group_create(group, shard_id, name),
                          bytes(group, encoding="utf-8"))
            for group in delete:
                conn.send(self.topic, self.message.group_delete(group),
                          bytes(group, encoding="utf-8"))

    def handle_device_assignment(self, add: List[Tuple[str, MacRange]], remove: List[Tuple[str, MacRange]]) -> None:
        """Send device-add messages for ``add``, then device-remove messages for ``remove``.

        Args:
            add: ``(group id, MAC range)`` pairs to add.
            remove: ``(group id, MAC range)`` pairs to remove.
        """
        with self as conn:
            for group, mac_range in add:
                logger.debug(f"{group = }, {mac_range = }")
                conn.send(self.topic, self.message.add_dev_to_group(group, mac_range),
                          bytes(group, encoding="utf-8"))
            for group, mac_range in remove:
                conn.send(self.topic, self.message.remove_dev_from_group(group, mac_range),
                          bytes(group, encoding="utf-8"))

    def handle_device_messages(self, message: dict, group: str, mac_range: MacRange,
                               count: int, time_s: float, interval_s: float) -> None:
        """Send ``message`` to every MAC in ``mac_range``, ``count`` times over.

        Args:
            message: payload forwarded to ``Message.to_device``.
            group: group id; also used as the Kafka message key.
            mac_range: devices that receive the message.
            count: number of passes over ``mac_range``; a falsy value means
                "no limit" (``sys.maxsize`` passes).
            time_s: stop after this many seconds; a falsy value means no
                time limit. The limit is checked after each full pass.
            interval_s: accepted but not applied.
                NOTE(review): the per-pass sleep was disabled in the original
                code, so messages are still sent back-to-back - confirm
                whether the interval should be enforced.
        """
        # The original computed an end time but never checked it, so a run
        # limited only by time looped for sys.maxsize passes.
        deadline = time.time() + time_s if time_s else None
        if not count:
            count = sys.maxsize
        with self as conn:
            for seq in range(count):
                for mac in mac_range:
                    conn.send(self.topic, self.message.to_device(group, mac, message, seq),
                              bytes(group, encoding="utf-8"))
                if deadline is not None and time.time() > deadline:
                    break

View File

@@ -0,0 +1,152 @@
from dataclasses import dataclass
from typing import List, Tuple
from typing import Tuple
import copy
import json
import uuid
class MacRange:
    """
    Return an object that produces a sequence of MAC addresses from
    START (inclusive) to END (inclusive). START and END are extracted
    from the input string if it is in the format
    "11:22:AA:BB:00:00-11:22:AA:BB:00:05" (where START=11:22:AA:BB:00:00,
    END=11:22:AA:BB:00:05, and the total amount of MACs in the range is 6).

    Examples (all of these are identical):
    00:00:00:00:XX:XX
    00:00:00:00:00:00-00:00:00:00:FF:FF
    00:00:00:00:00:00^65536

    Iterating the object (``for mac in rng``, ``list(rng)``) always yields the
    whole range from its start; every iteration is independent of the others.

    Raises ValueError
    """

    def __init__(self, input: str = "XX:XX:XX:XX:XX:XX") -> None:
        self.__base_as_num, self.__len = self.__parse_input(input.upper())
        # Cursor used only by __next__ (direct next(rng) calls).
        self.__idx = 0

    def __iter__(self):
        # Yield from a fresh generator instead of returning self: sharing one
        # cursor between iterations made an abandoned or nested loop corrupt
        # the next one (it resumed in the middle of the range).
        for offset in range(self.__len):
            yield self.num2mac(self.__base_as_num + offset)

    def __next__(self) -> str:
        """Return the next MAC of the object's own cursor.

        Kept for callers that use ``next(rng)`` directly. When the range is
        exhausted the cursor is reset and StopIteration is raised.
        """
        if self.__idx >= len(self):
            self.__idx = 0
            raise StopIteration()
        mac = self.num2mac(self.__base_as_num + self.__idx)
        self.__idx += 1
        return mac

    def __len__(self) -> int:
        return self.__len

    def __str__(self) -> str:
        return f"MacRange[start={self.base}, " \
               f"end={self.num2mac(self.__base_as_num + len(self) - 1)}]"

    def __repr__(self) -> str:
        return f"MacRange('{self.base}^{len(self)}')"

    @property
    def base(self) -> str:
        """First MAC address of the range."""
        return self.num2mac(self.__base_as_num)

    @staticmethod
    def mac2num(mac: str) -> int:
        """Convert ``"AA:BB:..."`` to an int by reading its hex digits."""
        return int(mac.replace(":", ""), base=16)

    @staticmethod
    def num2mac(mac: int) -> str:
        """Convert an int to an upper-case, colon-separated, 12-digit MAC string."""
        hex = f"{mac:012X}"
        return ":".join([a+b for a, b in zip(hex[::2], hex[1::2])])

    def __parse_input(self, input: str) -> Tuple[int, int]:
        """Return ``(first address as int, number of addresses)`` for ``input``.

        ``input`` is expected in upper case. Raises ValueError when a part is
        not valid hex/decimal or when START is greater than END.
        """
        if "X" in input:
            # A mask is the range from "all X = 0" to "all X = F".
            string = f"{input.replace('X', '0')}-{input.replace('X', 'F')}"
        else:
            string = input
        if "-" in string:
            start, end = string.split("-")
            start, end = self.mac2num(start), self.mac2num(end)
            if start > end:
                raise ValueError(f"Invalid MAC range {start}-{end}")
            return start, end - start + 1
        if "^" in string:
            base, count = string.split("^")
            return self.mac2num(base), int(count)
        return self.mac2num(input), 1
class Message:
    """Builds JSON message payloads from the templates in ``TEMPLATE_FILE``.

    Every builder copies the matching template, overwrites the relevant
    fields, stamps a fresh UUID and returns the UTF-8 encoded JSON text.
    """

    TEMPLATE_FILE = "./data/message_template.json"

    # Template names (top-level keys of the template file).
    GROUP_ADD = "add_group"
    GROUP_DEL = "del_group"
    DEV_TO_GROUP = "add_to_group"
    DEV_FROM_GROUP = "del_from_group"
    TO_DEVICE = "message_device"

    # Field names inside a template.
    GROUP_ID = "infra_group_id"
    GROUP_NAME = "infra_name"
    SHARD_ID = "infra_shard_id"
    DEV_LIST = "infra_group_infra_devices"
    MAC = "mac"
    DATA = "msg"
    MSG_UUID = "uuid"

    def __init__(self) -> None:
        """Load all templates from ``TEMPLATE_FILE``."""
        with open(self.TEMPLATE_FILE) as template_file:
            raw_json = template_file.read()
        self.templates = json.loads(raw_json)

    def _from_template(self, template_name: str) -> dict:
        """Return a shallow copy of the named template."""
        return copy.copy(self.templates[template_name])

    @staticmethod
    def _serialize(payload: dict) -> bytes:
        """Return ``payload`` as UTF-8 encoded JSON."""
        return json.dumps(payload).encode('utf-8')

    def group_create(self, id: str, shard_id: int, name: str) -> bytes:
        """Build the message that creates group ``id`` with the given shard id and name."""
        payload = self._from_template(self.GROUP_ADD)
        payload.update({
            self.GROUP_ID: id,
            self.SHARD_ID: shard_id,
            self.GROUP_NAME: name,
            self.MSG_UUID: str(uuid.uuid1()),
        })
        return self._serialize(payload)

    def group_delete(self, id: str) -> bytes:
        """Build the message that deletes group ``id``."""
        payload = self._from_template(self.GROUP_DEL)
        payload.update({
            self.GROUP_ID: id,
            self.MSG_UUID: str(uuid.uuid1()),
        })
        return self._serialize(payload)

    def add_dev_to_group(self, id: str, mac_range: MacRange) -> bytes:
        """Build the message that adds every MAC in ``mac_range`` to group ``id``."""
        payload = self._from_template(self.DEV_TO_GROUP)
        payload.update({
            self.GROUP_ID: id,
            self.DEV_LIST: list(mac_range),
            self.MSG_UUID: str(uuid.uuid1()),
        })
        return self._serialize(payload)

    def remove_dev_from_group(self, id: str, mac_range: MacRange) -> bytes:
        """Build the message that removes every MAC in ``mac_range`` from group ``id``."""
        payload = self._from_template(self.DEV_FROM_GROUP)
        payload.update({
            self.GROUP_ID: id,
            self.DEV_LIST: list(mac_range),
            self.MSG_UUID: str(uuid.uuid1()),
        })
        return self._serialize(payload)

    def to_device(self, id: str, mac: str, data, sequence: int = 0):
        """Build the message that delivers ``data`` to device ``mac`` in group ``id``.

        ``data`` is used as-is when it is exactly a ``dict``; any other value
        is wrapped as ``{"data": data}``. The UUID is derived from the MAC
        (as node) and ``sequence`` (as clock sequence).
        """
        payload = self._from_template(self.TO_DEVICE)
        body = data if type(data) is dict else {"data": data}
        payload.update({
            self.GROUP_ID: id,
            self.MAC: mac,
            self.DATA: body,
            self.MSG_UUID: str(uuid.uuid1(node=MacRange.mac2num(mac), clock_seq=sequence)),
        })
        return self._serialize(payload)
@dataclass
class Args:
    """Parsed command-line options of the Kafka producer tool."""
    # Groups to create, as (group id, shard id, name) tuples.
    add_groups: List[Tuple[str, int, str]]
    # Ids of the groups to delete.
    del_groups: List[str]
    # (group id, MAC range) pairs to add to a group.
    assign_to_group: List[Tuple[str, MacRange]]
    # (group id, MAC range) pairs to remove from a group.
    remove_from_group: List[Tuple[str, MacRange]]
    # Kafka topic the messages are sent to.
    topic: str
    # Kafka bootstrap server address.
    db: str
    # Payload for device messages.
    # NOTE(review): appears to be None when no message was requested - confirm.
    message: dict
    # Number of messages to send per MAC address.
    count: int
    # How long to keep sending, in seconds.
    # NOTE(review): appears to be None when no duration was given - confirm.
    time_to_send_s: float
    # Time between messages, in seconds.
    interval_s: float
    # Id of the group that device messages are sent to; the id is handled as a
    # string elsewhere in this file (other group ids are typed str).
    group_id: str
    # MAC range that device messages are sent to.
    send_to_macs: MacRange