// SPDX-FileCopyrightText: 2025 Tomasz Kramkowski // SPDX-License-Identifier: GPL-3.0-or-later use std::{ collections::HashMap, fmt, fs::File, io::Read, os::unix::fs::PermissionsExt, path::Path, process, time::Duration, }; use anyhow::bail; use rumqttc::{AsyncClient, EventLoop, MqttOptions, QoS}; use serde::{ de::{self, Visitor}, Deserialize, Deserializer, }; use crate::PROGRAM; #[derive(Deserialize, Debug)] pub struct Credentials { pub username: String, pub password: String, } fn default_host() -> String { "localhost".to_string() } fn default_port() -> u16 { 1883 } fn default_qos() -> QoS { QoS::ExactlyOnce } fn default_id() -> String { PROGRAM.to_string() } fn default_timeout() -> Duration { Duration::from_secs(60) } #[allow(clippy::enum_variant_names)] #[derive(Deserialize, Debug)] #[serde(remote = "QoS", rename_all = "kebab-case")] #[repr(u8)] pub enum QoSDef { AtMostOnce = 0, AtLeastOnce = 1, ExactlyOnce = 2, } pub fn deserialize_qos_opt<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { #[derive(Deserialize)] struct Helper(#[serde(with = "QoSDef")] QoS); let helper = Option::deserialize(deserializer)?; Ok(helper.map(|Helper(external)| external)) } #[derive(Debug, PartialEq, Clone)] pub struct Program { // TODO: Figure out a way to allow arbitrary unix paths (arbitrary // non-unicode) without base64 pub command: Box<[String]>, pub timeout: Option, } impl<'de> Deserialize<'de> for Program { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { struct VecOrProgram; impl<'de> Visitor<'de> for VecOrProgram { type Value = Program; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("map or seq") } fn visit_seq(self, seq: A) -> Result where A: serde::de::SeqAccess<'de>, { let vec: Box<[String]> = Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; Ok(Program { command: vec, timeout: None, }) } fn visit_map(self, map: A) -> Result where A: serde::de::MapAccess<'de>, { #[derive(Deserialize)] struct Helper { command: Box<[String]>, #[serde(default, deserialize_with = "deserialize_timeout_opt")] timeout: Option, } let helper: Helper = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; Ok(Program { command: helper.command, timeout: helper.timeout, }) } } deserializer.deserialize_any(VecOrProgram) } } #[derive(Debug)] pub struct Route { pub programs: Box<[Program]>, pub qos: Option, } impl<'de> Deserialize<'de> for Route { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { struct VecOrRoute; impl<'de> Visitor<'de> for VecOrRoute { type Value = Route; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("map or seq") } fn visit_seq(self, seq: A) -> Result where A: serde::de::SeqAccess<'de>, { let vec: Box<[Program]> = Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; Ok(Route { programs: vec, qos: None, }) } fn visit_map(self, map: A) -> Result where A: serde::de::MapAccess<'de>, { #[derive(Deserialize)] struct Helper { programs: Box<[Program]>, #[serde(default, deserialize_with = "deserialize_qos_opt")] qos: Option, } let helper: Helper = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; Ok(Route { programs: helper.programs, qos: helper.qos, }) } } deserializer.deserialize_any(VecOrRoute) } } pub fn deserialize_timeout<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, { struct DurationVisitor; impl<'de> de::Visitor<'de> for DurationVisitor { type Value = Duration; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("a positive number") } fn visit_i64(self, v: i64) -> Result where E: de::Error, { if v < 0 { return Err(de::Error::invalid_value( de::Unexpected::Signed(v), &"a non-negative number", )); } if v == 0 { Ok(Duration::MAX) } else { Ok(Duration::from_secs(v as u64)) } } fn visit_f64(self, v: f64) -> Result where E: de::Error, { if v < 0.0 { return Err(de::Error::invalid_value( de::Unexpected::Float(v), &"a non-negative number", )); } if v == 0.0 { Ok(Duration::MAX) } else { Ok(Duration::from_secs_f64(v)) } } } deserializer.deserialize_any(DurationVisitor) } pub fn deserialize_timeout_opt<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { #[derive(Deserialize)] struct Helper(#[serde(deserialize_with = "deserialize_timeout")] Duration); let helper = Option::deserialize(deserializer)?; Ok(helper.map(|Helper(external)| external)) } #[derive(Deserialize, Debug)] pub struct Config { #[serde(default = "default_host")] pub host: String, #[serde(default = "default_port")] pub port: u16, #[serde(with = "QoSDef", default = "default_qos")] pub qos: QoS, #[serde(default = "default_timeout", deserialize_with = "deserialize_timeout")] pub timeout: Duration, pub credentials: Option, #[serde(default = "default_id")] pub id: String, pub routes: HashMap, } impl Config { pub fn mqtt_client(&self) -> (AsyncClient, EventLoop) { let client_id = format!("{}_{}", self.id, process::id()); let mut options = MqttOptions::new(client_id, &self.host, self.port); if let Some(credentials) = &self.credentials { options.set_credentials(&credentials.username, &credentials.password); } // TODO: Make configurable options.set_keep_alive(Duration::from_secs(5)); options.set_max_packet_size(10 * 1024 * 1024, 10 * 1024 * 1024); AsyncClient::new(options, 10) } } pub fn load>(path: P) -> anyhow::Result { let mut f = File::open(path)?; let mut config = String::new(); f.read_to_string(&mut config)?; let config: Config = toml::from_str(&config)?; if config.credentials.is_some() { let mode = f.metadata()?.permissions().mode(); if mode & 0o044 != 0o000 { bail!("Config file contains credentials while being group or world readable."); } } Ok(config) } #[cfg(test)] mod tests { use super::*; use rumqttc::QoS; use std::time::Duration; impl Program { fn new(command: Vec<&str>) -> Self { Program { command: command.into_iter().map(str::to_string).collect(), timeout: None, } } fn new_with_timeout(command: Vec<&str>, timeout: Duration) -> Self { Program { command: command.into_iter().map(str::to_string).collect(), timeout: Some(timeout), } } } #[test] fn load_full_config() { let toml_str = r#" host = "foo.bar.baz" port = 1234 qos = "at-most-once" id = "custom-id" timeout = 15.5 [credentials] username = "testuser" password = "testpassword" [routes] "topic/map" = { programs = [ ["/bin/program1"], ["/bin/program2", "arg"], { command = ["/bin/program3", "arg"]}, ], qos = "exactly-once" } "topic/seq" = [ ["/bin/program4", "arg"], { command = ["/bin/program5"], timeout = 1.2 }, ] "#; let config: Config = toml::from_str(toml_str).expect("Failed to parse full config"); assert_eq!(config.host, "foo.bar.baz"); assert_eq!(config.port, 1234); assert_eq!(config.qos, QoS::AtMostOnce); assert_eq!(config.id, "custom-id"); assert_eq!(config.timeout, Duration::from_secs_f64(15.5)); let creds = config.credentials.expect("Credentials should be present"); assert_eq!(creds.username, "testuser"); assert_eq!(creds.password, "testpassword"); assert_eq!(config.routes.len(), 2); let route_map = config.routes.get("topic/map").unwrap(); assert_eq!( route_map.programs, vec![ Program::new(vec!["/bin/program1"]), Program::new(vec!["/bin/program2", "arg"]), Program::new(vec!["/bin/program3", "arg"]), ] ); assert_eq!(route_map.qos, Some(QoS::ExactlyOnce)); let route_seq = config.routes.get("topic/seq").unwrap(); assert_eq!( route_seq.programs, vec![ Program::new(vec!["/bin/program4", "arg"]), Program::new_with_timeout(vec!["/bin/program5"], Duration::from_secs_f64(1.2)), ] ); assert_eq!(route_seq.qos, None); } #[test] fn load_minimal_config() { let config: Config = toml::from_str("[routes]").expect("Failed to parse minimal config"); assert_eq!(config.host, default_host()); assert_eq!(config.port, default_port()); assert_eq!(config.qos, default_qos()); assert_eq!(config.id, default_id()); assert_eq!(config.timeout, default_timeout()); assert!(config.credentials.is_none()); assert!(config.routes.is_empty()); } #[test] fn load_route_seq() { let toml_str = r#" [routes] "some/topic" = [["/foo/bar"], ["/baz/qux", "arg"]] "#; let config: Config = toml::from_str(toml_str).unwrap(); let route = config.routes.get("some/topic").unwrap(); assert_eq!( route.programs, vec![ Program::new(vec!["/foo/bar"]), Program::new(vec!["/baz/qux", "arg"]) ] ); assert_eq!(route.qos, None); } #[test] fn load_route_map() { let toml_str = r#" [routes] "topic/with_qos" = { programs = [["/foo/bar", "arg"]], qos = "at-least-once" } "topic/without_qos" = { programs = [["/baz/qux"]] } "#; let config: Config = toml::from_str(toml_str).unwrap(); let route_with_qos = config.routes.get("topic/with_qos").unwrap(); assert_eq!( route_with_qos.programs, vec![Program::new(vec!["/foo/bar", "arg"])] ); assert_eq!(route_with_qos.qos, Some(QoS::AtLeastOnce)); let route_without_qos = config.routes.get("topic/without_qos").unwrap(); assert_eq!( route_without_qos.programs, vec![Program::new(vec!["/baz/qux"])] ); assert_eq!(route_without_qos.qos, None); } #[test] fn load_timeout() { let config_int: Config = toml::from_str("timeout = 10\n[routes]").unwrap(); assert_eq!(config_int.timeout, Duration::from_secs(10)); let config_float: Config = toml::from_str("timeout = 2.5\n[routes]").unwrap(); assert_eq!(config_float.timeout, Duration::from_secs_f64(2.5)); let config_zero: Config = toml::from_str("timeout = 0\n[routes]").unwrap(); assert_eq!(config_zero.timeout, Duration::MAX); let config_zero: Config = toml::from_str("timeout = 0.0\n[routes]").unwrap(); assert_eq!(config_zero.timeout, Duration::MAX); } #[test] fn load_timeout_negative() { let result = toml::from_str::("timeout = -10\n[routes]"); assert!(result.is_err()); let result = toml::from_str::("timeout = -1.0\n[routes]"); assert!(result.is_err()); } #[test] fn load_qos() { let toml_str = r#" [routes] "at-most-once" = { programs = [], qos = "at-most-once" } "at-least-once" = { programs = [], qos = "at-least-once" } "exactly-once" = { programs = [], qos = "exactly-once" } "#; let config: Config = toml::from_str(toml_str).unwrap(); assert_eq!( config.routes.get("at-most-once").unwrap().qos.unwrap(), QoS::AtMostOnce ); assert_eq!( config.routes.get("at-least-once").unwrap().qos.unwrap(), QoS::AtLeastOnce ); assert_eq!( config.routes.get("exactly-once").unwrap().qos.unwrap(), QoS::ExactlyOnce ); } }