// SPDX-FileCopyrightText: 2025 Tomasz Kramkowski // SPDX-License-Identifier: GPL-3.0-or-later use std::{ collections::HashMap, fmt, fs::File, io::Read, os::unix::fs::PermissionsExt, path::Path, process, time::Duration, }; use anyhow::bail; use rumqttc::{AsyncClient, EventLoop, MqttOptions, QoS}; use serde::{ de::{self, Visitor}, Deserialize, Deserializer, }; use crate::PROGRAM; #[derive(Deserialize, Debug)] pub struct Credentials { pub username: String, pub password: String, } fn default_host() -> String { "localhost".to_string() } fn default_port() -> u16 { 1883 } fn default_qos() -> QoS { QoS::ExactlyOnce } fn default_id() -> String { PROGRAM.to_string() } fn default_timeout() -> Duration { Duration::from_secs(60) } #[allow(clippy::enum_variant_names)] #[derive(Deserialize, Debug)] #[serde(remote = "QoS", rename_all = "kebab-case")] #[repr(u8)] pub enum QoSDef { AtMostOnce = 0, AtLeastOnce = 1, ExactlyOnce = 2, } pub fn deserialize_qos_opt<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { #[derive(Deserialize)] struct Helper(#[serde(with = "QoSDef")] QoS); let helper = Option::deserialize(deserializer)?; Ok(helper.map(|Helper(external)| external)) } #[derive(Debug)] pub struct Route { // TODO: Figure out a way to allow arbitrary unix paths (arbitrary // non-unicode) without base64 pub programs: Vec>, pub qos: Option, } impl<'de> Deserialize<'de> for Route { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { struct VecOrRoute; impl<'de> Visitor<'de> for VecOrRoute { type Value = Route; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("map or seq") } fn visit_seq(self, seq: A) -> Result where A: serde::de::SeqAccess<'de>, { let vec: Vec> = Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; Ok(Route { programs: vec, qos: None, }) } fn visit_map(self, map: A) -> Result where A: serde::de::MapAccess<'de>, { #[derive(Deserialize)] struct RouteHelper { programs: Vec>, #[serde(default, deserialize_with = "deserialize_qos_opt")] qos: Option, } let helper: RouteHelper = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; Ok(Route { programs: helper.programs, qos: helper.qos, }) } } deserializer.deserialize_any(VecOrRoute) } } pub fn deserialize_timeout<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, { struct DurationVisitor; impl<'de> de::Visitor<'de> for DurationVisitor { type Value = Duration; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("a positive number") } fn visit_i64(self, v: i64) -> Result where E: de::Error, { if v < 0 { return Err(de::Error::invalid_value( de::Unexpected::Signed(v), &"a non-negative number", )); } if v == 0 { Ok(Duration::MAX) } else { Ok(Duration::from_secs(v as u64)) } } fn visit_f64(self, v: f64) -> Result where E: de::Error, { if v < 0.0 { return Err(de::Error::invalid_value( de::Unexpected::Float(v), &"a non-negative number", )); } if v == 0.0 { Ok(Duration::MAX) } else { Ok(Duration::from_secs_f64(v)) } } } deserializer.deserialize_any(DurationVisitor) } #[derive(Deserialize, Debug)] pub struct Config { #[serde(default = "default_host")] pub host: String, #[serde(default = "default_port")] pub port: u16, #[serde(with = "QoSDef", default = "default_qos")] pub qos: QoS, #[serde(default = "default_timeout", deserialize_with = "deserialize_timeout")] pub timeout: Duration, pub credentials: Option, #[serde(default = "default_id")] pub id: String, pub routes: HashMap, } impl Config { pub fn mqtt_client(&self) -> (AsyncClient, EventLoop) { let client_id = format!("{}_{}", self.id, process::id()); let mut options = MqttOptions::new(client_id, &self.host, self.port); if let Some(credentials) = &self.credentials { options.set_credentials(&credentials.username, &credentials.password); } // TODO: Make configurable options.set_keep_alive(Duration::from_secs(5)); options.set_max_packet_size(10 * 1024 * 1024, 10 * 1024 * 1024); AsyncClient::new(options, 10) } } pub fn load>(path: P) -> anyhow::Result { let mut f = File::open(path)?; let mut config = String::new(); f.read_to_string(&mut config)?; let config: Config = toml::from_str(&config)?; if config.credentials.is_some() { let mode = f.metadata()?.permissions().mode(); if mode & 0o044 != 0o000 { bail!("Config file contains credentials while being group or world readable."); } } Ok(config) }