aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorTomasz Kramkowski <tomasz@kramkow.ski>2025-07-04 12:55:32 +0200
committerTomasz Kramkowski <tomasz@kramkow.ski>2025-07-04 12:55:32 +0200
commit68839f01cd982f03d7ff95d3180cfae8534dc3eb (patch)
treee6220aa6035cdbf13905d4f479cdd07d700b3dc1 /src
parentce71a662f977c9dd3790c62620ebd0568276b05f (diff)
downloadmqttr-68839f01cd982f03d7ff95d3180cfae8534dc3eb.tar.gz
mqttr-68839f01cd982f03d7ff95d3180cfae8534dc3eb.tar.xz
mqttr-68839f01cd982f03d7ff95d3180cfae8534dc3eb.zip
Configurable QoS
Diffstat (limited to 'src')
-rw-r--r--src/config.rs105
-rw-r--r--src/main.rs9
2 files changed, 97 insertions, 17 deletions
diff --git a/src/config.rs b/src/config.rs
index 00790bd..e70893b 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -2,18 +2,16 @@
// SPDX-License-Identifier: GPL-3.0-or-later
use std::{
- collections::HashMap,
- fs::File,
- io::Read,
- os::unix::fs::PermissionsExt,
- path::Path,
- process,
- time::Duration,
+ collections::HashMap, fmt, fs::File, io::Read, os::unix::fs::PermissionsExt, path::Path,
+ process, time::Duration,
};
use anyhow::bail;
-use rumqttc::{AsyncClient, EventLoop, MqttOptions};
-use serde::Deserialize;
+use rumqttc::{AsyncClient, EventLoop, MqttOptions, QoS};
+use serde::{
+ de::{self, Visitor},
+ Deserialize, Deserializer,
+};
use crate::PROGRAM;
@@ -31,22 +29,105 @@ fn default_port() -> u16 {
1883
}
+fn default_qos() -> QoS {
+ QoS::ExactlyOnce
+}
+
fn default_id() -> String {
PROGRAM.to_string()
}
+#[allow(clippy::enum_variant_names)]
+#[derive(Deserialize, Debug)]
+#[serde(remote = "QoS", rename_all = "kebab-case")]
+#[repr(u8)]
+pub enum QoSDef {
+ AtMostOnce = 0,
+ AtLeastOnce = 1,
+ ExactlyOnce = 2,
+}
+
+pub fn deserialize_qos_opt<'de, D>(deserializer: D) -> Result<Option<QoS>, D::Error>
+where
+ D: Deserializer<'de>,
+{
+ #[derive(Deserialize)]
+ struct Helper(#[serde(with = "QoSDef")] QoS);
+
+ let helper = Option::deserialize(deserializer)?;
+ Ok(helper.map(|Helper(external)| external))
+}
+
+#[derive(Debug)]
+pub struct Route {
+ // TODO: Figure out a way to allow arbitrary unix paths (arbitrary
+ // non-unicode) without base64
+ pub programs: Vec<Vec<String>>,
+ pub qos: Option<QoS>,
+}
+
+impl<'de> Deserialize<'de> for Route {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ struct VecOrRoute;
+
+ impl<'de> Visitor<'de> for VecOrRoute {
+ type Value = Route;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("map or seq")
+ }
+
+ fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
+ where
+ A: serde::de::SeqAccess<'de>,
+ {
+ let vec: Vec<Vec<String>> =
+ Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
+ Ok(Route {
+ programs: vec,
+ qos: None,
+ })
+ }
+
+ fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
+ where
+ A: serde::de::MapAccess<'de>,
+ {
+ #[derive(Deserialize)]
+ struct RouteHelper {
+ programs: Vec<Vec<String>>,
+ #[serde(default, deserialize_with = "deserialize_qos_opt")]
+ qos: Option<QoS>,
+ }
+
+ let helper: RouteHelper =
+ Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
+ Ok(Route {
+ programs: helper.programs,
+ qos: helper.qos,
+ })
+ }
+ }
+
+ deserializer.deserialize_any(VecOrRoute)
+ }
+}
+
#[derive(Deserialize, Debug)]
pub struct Config {
#[serde(default = "default_host")]
pub host: String,
#[serde(default = "default_port")]
pub port: u16,
+ #[serde(with = "QoSDef", default = "default_qos")]
+ pub qos: QoS,
pub credentials: Option<Credentials>,
#[serde(default = "default_id")]
pub id: String,
- // TODO: Figure out a way to allow arbitrary unix paths (arbitrary
- // non-unicode) without base64
- pub routes: HashMap<String, Vec<Vec<String>>>,
+ pub routes: HashMap<String, Route>,
}
impl Config {
diff --git a/src/main.rs b/src/main.rs
index 0ba13f4..84ee406 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -73,9 +73,8 @@ async fn main() -> anyhow::Result<()> {
let conf = config::load(&conf_path)
.with_context(|| format!("Failed to load config: {:?}", &conf_path))?;
let (client, mut event_loop) = conf.mqtt_client();
- for topic in conf.routes.keys() {
- // TODO: Configurable subscription QoS
- if let Err(e) = client.subscribe(topic, QoS::ExactlyOnce).await {
+ for (topic, route) in conf.routes.iter() {
+ if let Err(e) = client.subscribe(topic, route.qos.unwrap_or(conf.qos)).await {
eprintln!("warning: Failed to subscribe to '{topic}': {e:?}");
}
}
@@ -83,11 +82,11 @@ async fn main() -> anyhow::Result<()> {
let notification = event_loop.poll().await;
match notification? {
Incoming(Packet::Publish(p)) => {
- for (topic, programs) in conf.routes.iter() {
+ for (topic, route) in conf.routes.iter() {
if !topic_match(&topic, &p.topic) {
continue;
}
- for program in programs {
+ for program in &route.programs {
// TODO: Switch to moro_local to avoid this ewwyness
let program = program.clone();
let p = p.clone();