diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/config.rs | 105 | ||||
-rw-r--r-- | src/main.rs | 9 |
2 files changed, 97 insertions, 17 deletions
diff --git a/src/config.rs b/src/config.rs index 00790bd..e70893b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2,18 +2,16 @@ // SPDX-License-Identifier: GPL-3.0-or-later use std::{ - collections::HashMap, - fs::File, - io::Read, - os::unix::fs::PermissionsExt, - path::Path, - process, - time::Duration, + collections::HashMap, fmt, fs::File, io::Read, os::unix::fs::PermissionsExt, path::Path, + process, time::Duration, }; use anyhow::bail; -use rumqttc::{AsyncClient, EventLoop, MqttOptions}; -use serde::Deserialize; +use rumqttc::{AsyncClient, EventLoop, MqttOptions, QoS}; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, +}; use crate::PROGRAM; @@ -31,22 +29,105 @@ fn default_port() -> u16 { 1883 } +fn default_qos() -> QoS { + QoS::ExactlyOnce +} + fn default_id() -> String { PROGRAM.to_string() } +#[allow(clippy::enum_variant_names)] +#[derive(Deserialize, Debug)] +#[serde(remote = "QoS", rename_all = "kebab-case")] +#[repr(u8)] +pub enum QoSDef { + AtMostOnce = 0, + AtLeastOnce = 1, + ExactlyOnce = 2, +} + +pub fn deserialize_qos_opt<'de, D>(deserializer: D) -> Result<Option<QoS>, D::Error> +where + D: Deserializer<'de>, +{ + #[derive(Deserialize)] + struct Helper(#[serde(with = "QoSDef")] QoS); + + let helper = Option::deserialize(deserializer)?; + Ok(helper.map(|Helper(external)| external)) +} + +#[derive(Debug)] +pub struct Route { + // TODO: Figure out a way to allow arbitrary unix paths (arbitrary + // non-unicode) without base64 + pub programs: Vec<Vec<String>>, + pub qos: Option<QoS>, +} + +impl<'de> Deserialize<'de> for Route { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de>, + { + struct VecOrRoute; + + impl<'de> Visitor<'de> for VecOrRoute { + type Value = Route; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("map or seq") + } + + fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error> + where + A: serde::de::SeqAccess<'de>, + { + let vec: Vec<Vec<String>> = + Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; + Ok(Route { + programs: vec, + qos: None, + }) + } + + fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error> + where + A: serde::de::MapAccess<'de>, + { + #[derive(Deserialize)] + struct RouteHelper { + programs: Vec<Vec<String>>, + #[serde(default, deserialize_with = "deserialize_qos_opt")] + qos: Option<QoS>, + } + + let helper: RouteHelper = + Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; + Ok(Route { + programs: helper.programs, + qos: helper.qos, + }) + } + } + + deserializer.deserialize_any(VecOrRoute) + } +} + #[derive(Deserialize, Debug)] pub struct Config { #[serde(default = "default_host")] pub host: String, #[serde(default = "default_port")] pub port: u16, + #[serde(with = "QoSDef", default = "default_qos")] + pub qos: QoS, pub credentials: Option<Credentials>, #[serde(default = "default_id")] pub id: String, - // TODO: Figure out a way to allow arbitrary unix paths (arbitrary - // non-unicode) without base64 - pub routes: HashMap<String, Vec<Vec<String>>>, + pub routes: HashMap<String, Route>, } impl Config { diff --git a/src/main.rs b/src/main.rs index 0ba13f4..84ee406 100644 --- a/src/main.rs +++ b/src/main.rs @@ -73,9 +73,8 @@ async fn main() -> anyhow::Result<()> { let conf = config::load(&conf_path) .with_context(|| format!("Failed to load config: {:?}", &conf_path))?; let (client, mut event_loop) = conf.mqtt_client(); - for topic in conf.routes.keys() { - // TODO: Configurable subscription QoS - if let Err(e) = client.subscribe(topic, QoS::ExactlyOnce).await { + for (topic, route) in conf.routes.iter() { + if let Err(e) = client.subscribe(topic, route.qos.unwrap_or(conf.qos)).await { eprintln!("warning: Failed to subscribe to '{topic}': {e:?}"); } } @@ -83,11 +82,11 @@ async fn main() -> anyhow::Result<()> { let notification = event_loop.poll().await; match notification? { Incoming(Packet::Publish(p)) => { - for (topic, programs) in conf.routes.iter() { + for (topic, route) in conf.routes.iter() { if !topic_match(&topic, &p.topic) { continue; } - for program in programs { + for program in &route.programs { // TODO: Switch to moro_local to avoid this ewwyness let program = program.clone(); let p = p.clone(); |