From 8fe78ad9f932513f27b6f97a705cdfc1ef6a16b9 Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Sat, 7 Jun 2025 15:53:32 +0100 Subject: init commit --- src/config.rs | 69 ++++++++ src/main.rs | 515 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 584 insertions(+) create mode 100644 src/config.rs create mode 100644 src/main.rs (limited to 'src') diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..9000f01 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,69 @@ +use std::{ + fs, io, + path::{Path, PathBuf}, + process, + time::Duration, +}; + +use rumqttc::{Client, Connection, MqttOptions}; +use serde::Deserialize; + +use crate::PROGRAM; + +#[derive(Deserialize)] +pub struct Credentials { + pub username: String, + pub password: String, +} + +fn default_root() -> PathBuf { + option_env!("DEFAULT_ROOT") + .unwrap_or(&format!("/var/run/{PROGRAM}")) + .into() +} + +fn default_host() -> String { + "localhost".to_string() +} + +fn default_port() -> u16 { + 1883 +} + +fn default_id() -> String { + PROGRAM.to_string() +} + +#[derive(Deserialize)] +pub struct Config { + #[serde(default = "default_root")] + pub root: PathBuf, + #[serde(default = "default_host")] + pub host: String, + #[serde(default = "default_port")] + pub port: u16, + pub credentials: Option, + #[serde(default = "default_id")] + pub id: String, +} + +impl Config { + pub fn mqtt_client(&self) -> (Client, Connection) { + let client_id = format!("{}_{}", self.id, process::id()); + let mut options = MqttOptions::new(client_id, &self.host, self.port); + if let Some(credentials) = &self.credentials { + options.set_credentials(&credentials.username, &credentials.password); + } + options.set_keep_alive(Duration::from_secs(5)); + Client::new(options, 10) + } +} + +pub fn load>(path: P) -> anyhow::Result { + let config = match fs::read_to_string(&path) { + Ok(s) => Ok(s), + Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(String::new()), + Err(e) => Err(e), + }?; + Ok(toml::from_str(&config)?) +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..6ccafba --- /dev/null +++ b/src/main.rs @@ -0,0 +1,515 @@ +use std::{ + collections::HashMap, + fs, + io::Write, + os::unix::fs::PermissionsExt, + path::{Path, PathBuf}, + process::{Command, Stdio}, +}; + +use anyhow::Context; +use rumqttc::{Event::Incoming, Packet::Publish, QoS}; + +mod config; + +const PROGRAM: &str = "mqttt"; + +struct Node { + children: HashMap, + executables: Box<[PathBuf]>, +} + +fn is_executable>(path: P) -> std::io::Result { + Ok(path.as_ref().metadata()?.permissions().mode() & 0o111 != 0) +} + +impl Node { + fn build>(path: P) -> std::io::Result { + let mut children = HashMap::new(); + let mut executables = Vec::new(); + + for entry in fs::read_dir(&path)? { + let entry = entry?; + let path = entry.path(); + let name = match entry.file_name().into_string() { + Ok(name) => name, + Err(_) => { + eprintln!("warning: Path '{path:?}' is not valid UTF-8. Skipping..."); + continue; + } + }; + + if path.is_dir() { + let child = Node::build(&path)?; + children.insert(name, child); + } else if path.is_file() && is_executable(&path)? { + executables.push(path); + } + } + + Ok(Node { + children, + executables: executables.into_boxed_slice(), + }) + } + + fn traverse(&self, prefix: &str, f: &mut F) { + f(prefix, &self); + for (name, child) in &self.children { + let sep = if prefix != "" { "/" } else { "" }; + let name = if name != "#empty" { &name } else { "" }; + let path = format!("{prefix}{sep}{name}"); + child.traverse(&path, f); + } + } + + fn publish(&self, path: &str, f: &mut F) { + let path = if path == "" { None } else { Some(path) }; + let is_sys = path.is_some_and(|p| p.starts_with("$")); + self.publish_impl(path, f, is_sys); + } + + fn publish_impl(&self, path: Option<&str>, f: &mut F, is_sys: bool) { + let Some(path) = path else { + f(&self); + if let Some(child) = self.children.get("#") { + f(&child); + } + return; + }; + let (front, rest) = match path.split_once('/') { + Some((front, rest)) => (front, Some(rest)), + None => (path, None), + }; + if let Some(child) = self.children.get(front) { + child.publish_impl(rest, f, false); + } + if !is_sys { + if let Some(child) = self.children.get("+") { + child.publish_impl(rest, f, false); + } + if let Some(child) = self.children.get("#") { + child.publish_impl(None, f, false); + } + } + } +} + +fn main() -> anyhow::Result<()> { + let mut conf_path: PathBuf = option_env!("SYSCONFDIR").unwrap_or("/usr/local/etc").into(); + conf_path.push(format!("{PROGRAM}.toml")); + let conf = config::load(&conf_path) + .with_context(|| format!("Failed to load config: {:?}", &conf_path))?; + let root = Node::build(&conf.root).context("Failed to build tree")?; + let (client, mut connection) = conf.mqtt_client(); + root.traverse("", &mut |path, node| { + if node.executables.len() != 0 { + if let Err(e) = client.subscribe(path, QoS::AtMostOnce) { + eprintln!("warning: Failed to subscribe {path}: {e:?}"); + } else { + println!("Subscribed to {path}"); + } + } + }); + for notification in connection.iter() { + match notification? { + Incoming(Publish(p)) => root.publish(&p.topic, &mut |node| { + for e in &node.executables { + let mut proc = Command::new(e) + .args([&p.topic]) + .stdin(Stdio::piped()) + .spawn() + .unwrap(); + let stdin = proc.stdin.as_mut().unwrap(); + stdin.write_all(&p.payload).unwrap(); + println!("{}", proc.wait().unwrap()); + } + }), + _ => (), + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::Node; + use std::collections::HashMap; + use std::path::PathBuf; + + fn node(id_str: Option<&str>, children_data: Vec<(&str, Node)>) -> Node { + let executables = id_str + .map(|s| vec![PathBuf::from(s)].into_boxed_slice()) + .unwrap_or_default(); + + let children: HashMap = children_data + .into_iter() + .map(|(name, child_node)| (name.to_string(), child_node)) + .collect(); + + Node { + children, + executables, + } + } + + fn assert_publish_ids(root_node: &Node, publish_path: &str, expected_ids_str: &[&str]) { + let mut actual_ids: Vec = Vec::new(); + root_node.publish(publish_path, &mut |n| { + if let Some(id) = n.executables.get(0) { + actual_ids.push(id.clone()); + } + }); + + assert_eq!( + actual_ids.len(), + expected_ids_str.len(), + "Path '{}': Number of called IDs ({}) does not match expected ({}). Actual: {:?}, Expected: {:?}", + publish_path, + actual_ids.len(), + expected_ids_str.len(), + actual_ids, + expected_ids_str + ); + for (i, expected_id_str) in expected_ids_str.iter().enumerate() { + assert_eq!( + actual_ids[i], + PathBuf::from(expected_id_str), + "Path '{}': Called ID at index {} does not match. Actual: {:?}, Expected: {:?}", + publish_path, + i, + actual_ids[i], + expected_id_str + ); + } + } + + #[test] + fn topic_single_segment_exact() { + assert_publish_ids( + &node(None, vec![("a", node(Some("a_id"), vec![]))]), + "a", + &["a_id"], + ); + } + + #[test] + fn topic_multi_segment_exact() { + assert_publish_ids( + &node( + None, + vec![("a", node(None, vec![("b", node(Some("b_id"), vec![]))]))], + ), + "a/b", + &["b_id"], + ); + } + + #[test] + fn topic_single_segment_plus_wildcard() { + assert_publish_ids( + &node(None, vec![("+", node(Some("plus_id"), vec![]))]), + "unknown", + &["plus_id"], + ); + } + + #[test] + fn topic_single_segment_hash_literal_wildcard() { + assert_publish_ids( + &node(None, vec![("#", node(Some("hash_literal_id"), vec![]))]), + "unknown", + &["hash_literal_id"], + ); + } + + #[test] + fn topic_multi_segment_hash_literal_wildcard() { + assert_publish_ids( + &node(None, vec![("#", node(Some("hash_literal_id"), vec![]))]), + "unknown/unknown", + &["hash_literal_id"], + ); + } + + #[test] + fn topic_single_segment_all_match_types() { + assert_publish_ids( + &node( + None, + vec![ + ("data", node(Some("data_id"), vec![])), + ("+", node(Some("plus_id"), vec![])), + ("#", node(Some("hash_literal_id"), vec![])), + ], + ), + "data", + &["data_id", "plus_id", "hash_literal_id"], + ); + } + + #[test] + fn topic_multi_segment_all_match_types() { + assert_publish_ids( + &node( + None, + vec![ + ( + "data", + node( + None, + vec![ + ("data", node(Some("data_id"), vec![])), + ("+", node(Some("plus_id"), vec![])), + ], + ), + ), + ("#", node(Some("hash_literal_id"), vec![])), + ], + ), + "data/data", + &["data_id", "plus_id", "hash_literal_id"], + ); + } + + #[test] + fn topic_path_ends_triggers_base_case_hash_wildcard_child() { + let root = node( + None, + vec![( + "a", + node(Some("a_id"), vec![("#", node(Some("a_hash_id"), vec![]))]), + )], + ); + assert_publish_ids(&root, "a", &["a_id", "a_hash_id"]); + } + + #[test] + fn topic_no_match_for_segment() { + let root = node(None, vec![("known", node(Some("known_id"), vec![]))]); + assert_publish_ids(&root, "unknown", &[]); + } + + #[test] + fn topic_path_deeper_than_tree() { + let root = node( + None, + vec![("a", node(None, vec![("b", node(Some("b_id"), vec![]))]))], + ); + assert_publish_ids(&root, "a/b/c", &[]); + } + + #[test] + fn topic_trailing_slash_maps_to_empty_key_child() { + let root = node( + None, + vec![( + "a", + node(None, vec![("", node(Some("a_empty_key_id"), vec![]))]), + )], + ); + assert_publish_ids(&root, "a/", &["a_empty_key_id"]); + } + + #[test] + fn topic_multi_trailing_slash_maps_to_empty_key_child() { + let root = node( + None, + vec![( + "a", + node( + None, + vec![( + "b", + node(None, vec![("", node(Some("b_empty_key_id"), vec![]))]), + )], + ), + )], + ); + assert_publish_ids(&root, "a/b/", &["b_empty_key_id"]); + } + + #[test] + fn topic_trailing_slash_plus_wildcard_for_empty_key() { + // a/ -> a/(+ -> "") + let root = node( + None, + vec![( + "a", + node( + None, + vec![("+", node(Some("a_plus_for_empty_key"), vec![]))], + ), + )], + ); + assert_publish_ids(&root, "a/", &["a_plus_for_empty_key"]); + } + + #[test] + fn topic_a_double_slash_b() { + let root = node( + None, + vec![( + "a", + node( + None, + vec![("", node(None, vec![("b", node(Some("b_id"), vec![]))]))], + ), + )], + ); + assert_publish_ids(&root, "a//b", &["b_id"]); + } + + #[test] + fn topic_a_triple_slash_b() { + let root = node( + None, + vec![( + "a", + node( + None, + vec![( + "", + node( + None, + vec![("", node(None, vec![("b", node(Some("b_id"), vec![]))]))], + ), + )], + ), + )], + ); + assert_publish_ids(&root, "a///b", &["b_id"]); + } + + #[test] + fn topic_root_single_slash() { + let root = node( + None, + vec![( + "", + node(None, vec![("", node(Some("via_single_slash"), vec![]))]), + )], + ); + assert_publish_ids(&root, "/", &["via_single_slash"]); + } + + #[test] + fn topic_root_double_slash() { + let root = node( + None, + vec![( + "", + node( + None, + vec![( + "", + node(None, vec![("", node(Some("via_double_slash"), vec![]))]), + )], + ), + )], + ); + assert_publish_ids(&root, "//", &["via_double_slash"]); + } + + #[test] + fn topic_leading_double_slash_a() { + let root = node( + None, + vec![( + "", + node( + None, + vec![("", node(None, vec![("a", node(Some("a_id"), vec![]))]))], + ), + )], + ); + assert_publish_ids(&root, "//a", &["a_id"]); + } + + #[test] + fn topic_trailing_double_slash_a() { + let root = node( + None, + vec![( + "a", + node( + None, + vec![("", node(None, vec![("", node(Some("empty_id"), vec![]))]))], + ), + )], + ); + assert_publish_ids(&root, "a//", &["empty_id"]); + } + + #[test] + fn topic_a_double_slash_b_with_plus_wildcards() { + let root = node( + None, + vec![( + "a", + node( + None, + vec![("+", node(None, vec![("b", node(Some("b_id"), vec![]))]))], + ), + )], + ); + assert_publish_ids(&root, "a//b", &["b_id"]); + } + + #[test] + fn topic_a_trailing_slash_with_base_case_hash_on_empty_key_node() { + let root = node( + None, + vec![( + "a", + node( + None, + vec![( + "", + node( + Some("a_empty_key_id"), + vec![("#", node(Some("a_empty_key_hash_id"), vec![]))], + ), + )], + ), + )], + ); + assert_publish_ids(&root, "a/", &["a_empty_key_id", "a_empty_key_hash_id"]); + } + + #[test] + fn sys_topic_only_prefixed() { + assert_publish_ids( + &node( + None, + vec![ + ( + "$SYS", + node( + None, + vec![ + ("foo", node(Some("sys_foo_id"), vec![])), + ("#", node(Some("sys_hash_id"), vec![])), + ("+", node(Some("sys_plus_id"), vec![])), + ], + ), + ), + ("#", node(Some("hash_id"), vec![])), + ( + "+", + node( + None, + vec![ + ("foo", node(Some("plus_food_id"), vec![])), + ("#", node(Some("plus_hash_id"), vec![])), + ("+", node(Some("plus_plus_id"), vec![])), + ], + ), + ), + ], + ), + "$SYS/foo", + &["sys_foo_id", "sys_plus_id", "sys_hash_id"], + ); + } +} -- cgit v1.2.3-70-g09d2