From 14519559686525121352de3e3c9b9cacc4242038 Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Fri, 27 Jun 2025 00:46:44 +0100 Subject: Initial commit --- src/config.rs | 59 ++++++++++++++++++++ src/main.rs | 174 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 233 insertions(+) create mode 100644 src/config.rs create mode 100644 src/main.rs (limited to 'src') diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..1190b9a --- /dev/null +++ b/src/config.rs @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: 2025 Tomasz Kramkowski +// SPDX-License-Identifier: GPL-3.0-or-later + +use std::{collections::HashMap, fs, path::Path, process, time::Duration}; + +use rumqttc::{Client, Connection, MqttOptions}; +use serde::Deserialize; + +use crate::PROGRAM; + +#[derive(Deserialize, Debug)] +pub struct Credentials { + pub username: String, + pub password: String, +} + +fn default_host() -> String { + "localhost".to_string() +} + +fn default_port() -> u16 { + 1883 +} + +fn default_id() -> String { + PROGRAM.to_string() +} + +#[derive(Deserialize, Debug)] +pub struct Config { + #[serde(default = "default_host")] + pub host: String, + #[serde(default = "default_port")] + pub port: u16, + pub credentials: Option, + #[serde(default = "default_id")] + pub id: String, + // TODO: Figure out a way to allow arbitrary unix paths (arbitrary + // non-unicode) without base64 + pub routes: HashMap>>, +} + +impl Config { + pub fn mqtt_client(&self) -> (Client, Connection) { + let client_id = format!("{}_{}", self.id, process::id()); + let mut options = MqttOptions::new(client_id, &self.host, self.port); + if let Some(credentials) = &self.credentials { + options.set_credentials(&credentials.username, &credentials.password); + } + options.set_keep_alive(Duration::from_secs(5)); + options.set_max_packet_size(10 * 1024 * 1024, 10 * 1024 * 1024); + Client::new(options, 10) + } +} + +pub fn load>(path: P) -> anyhow::Result { + let config = fs::read_to_string(&path)?; + Ok(toml::from_str(&config)?) +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..14ccb7a --- /dev/null +++ b/src/main.rs @@ -0,0 +1,174 @@ +// SPDX-FileCopyrightText: 2025 Tomasz Kramkowski +// SPDX-License-Identifier: GPL-3.0-or-later + +// TODO: Log levels + +use std::{ + io::Write, + path::PathBuf, + process::{Command, Stdio}, +}; + +use anyhow::Context; +use rumqttc::{Event::Incoming, Packet, Publish, QoS}; + +mod config; + +const PROGRAM: &str = "mqttr"; + +fn run(program: &[String], message: &Publish) -> anyhow::Result<()> { + // TODO: Async + // TODO: Set environment variables + let mut proc = Command::new(&program[0]) + .args(&program[1..]) + .arg(&message.topic) + .stdin(Stdio::piped()) + .spawn()?; + let stdin = proc.stdin.as_mut().context("No stdin")?; + stdin.write_all(&message.payload)?; + println!("{}", proc.wait()?); + Ok(()) +} + +fn topic_match(filter: &str, topic: &str) -> bool { + // TODO: Should probably just be a panic or prevented using types + if filter.is_empty() || topic.is_empty() { + return false; + } + if topic.starts_with('$') && (filter.starts_with('+') || filter.starts_with('#')) { + return false; + } + + // zip_longest would be nice + let mut topic = topic.split('/'); + let mut filter = filter.split('/'); + loop { + let topic_level = topic.next(); + return match filter.next() { + Some("#") => filter.next().is_none(), + Some("+") => { + if topic_level.is_none() { + false + } else { + continue; + } + } + Some(filter_level) => match topic_level { + Some(topic_level) if topic_level == filter_level => { + continue; + } + _ => false, + }, + None => topic_level.is_none(), + }; + } +} + +fn main() -> anyhow::Result<()> { + let mut conf_path: PathBuf = option_env!("SYSCONFDIR").unwrap_or("/usr/local/etc").into(); + conf_path.push(format!("{PROGRAM}.toml")); + let conf = config::load(&conf_path) + .with_context(|| format!("Failed to load config: {:?}", &conf_path))?; + let (client, mut connection) = conf.mqtt_client(); + for topic in conf.routes.keys() { + // TODO: Configurable subscription QoS + if let Err(e) = client.subscribe(topic, QoS::AtMostOnce) { + eprintln!("warning: Failed to subscribe to '{topic}': {e:?}"); + } + } + for notification in connection.iter() { + match notification? { + Incoming(Packet::Publish(p)) => { + for (topic, programs) in conf.routes.iter() { + if topic_match(&topic, &p.topic) { + for program in programs { + if let Err(e) = run(program, &p) { + eprintln!("error: Failed to run {program:?}: {e:?}"); + } + } + } + } + } + _ => (), + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::topic_match; + + #[test] + fn topic_match_basic() { + assert!(topic_match("foo/bar/baz", "foo/bar/baz")); + assert!(!topic_match("foo/bar/baz", "foo/bar/qux")); + assert!(!topic_match("foo/bar", "foo/bar/baz")); + assert!(!topic_match("foo/bar/baz", "foo/bar")); + } + + #[test] + fn topic_match_wildcard_hash() { + assert!(topic_match("foo/bar/baz/#", "foo/bar/baz")); + assert!(topic_match("foo/bar/baz/#", "foo/bar/baz/qux")); + assert!(topic_match("foo/bar/baz/#", "foo/bar/baz/qux/quux")); + assert!(topic_match("#", "foo/bar/baz")); + assert!(topic_match("#", "foo")); + assert!(topic_match("#", "/")); + assert!(topic_match("#", "/foo")); + assert!(!topic_match("foo/bar/#", "foo/baz/bar")); + assert!(!topic_match("foo/bar/#", "foo")); + } + + #[test] + fn topic_match_wildcard_plus() { + assert!(topic_match("foo/bar/+", "foo/bar/baz")); + assert!(topic_match("foo/bar/+", "foo/bar/qux")); + assert!(!topic_match("foo/bar/+", "foo/bar/baz/qux")); + assert!(topic_match("foo/+", "foo/")); + assert!(!topic_match("foo/+", "foo")); + assert!(topic_match("+", "foo")); + assert!(topic_match("+/bar/#", "foo/bar/baz/qux")); + assert!(topic_match("+/bar/#", "qux/bar")); + assert!(topic_match("foo/+/baz", "foo/bar/baz")); + assert!(topic_match("foo/+/baz", "foo/qux/baz")); + assert!(!topic_match("foo/+/baz", "foo/bar/qux")); + assert!(topic_match("+/+", "/foo")); + assert!(topic_match("/+", "/foo")); + assert!(!topic_match("+", "/foo")); + } + + #[test] + fn topic_match_dollar() { + assert!(!topic_match("#", "$foo/bar")); + assert!(!topic_match("+/bar/baz", "$foo/bar/baz")); + assert!(topic_match("$foo/#", "$foo/bar")); + assert!(topic_match("$foo/#", "$foo/bar/baz")); + assert!(topic_match("$foo/#", "$foo")); + assert!(topic_match("$foo/bar/+", "$foo/bar/baz")); + assert!(!topic_match("$foo/#", "foo/bar")); + } + + #[test] + fn topic_match_edge_cases() { + assert!(!topic_match("foo", "FOO")); + assert!(topic_match("foo bar", "foo bar")); + assert!(!topic_match("foo bar", "foo bar")); + assert!(!topic_match("foo bar", "foo bar")); + assert!(!topic_match("/foo", "foo")); + assert!(!topic_match("foo", "/foo")); + assert!(topic_match("foo//bar", "foo//bar")); + assert!(!topic_match("foo/bar", "foo//bar")); + assert!(!topic_match("foo//bar", "foo/bar")); + assert!(!topic_match("foo//baz", "foo/bar/baz")); + assert!(!topic_match("foo/bar/baz", "foo//baz")); + assert!(topic_match("foo/+/baz", "foo//baz")); + assert!(topic_match("/", "/")); + assert!(!topic_match("/", "foo")); + assert!(!topic_match("foo", "/")); + assert!(!topic_match("", "")); + assert!(!topic_match("+", "")); + assert!(!topic_match("#", "")); + assert!(!topic_match("foo/#/baz", "foo/bar/baz")); + } +} -- cgit v1.2.3-70-g09d2