diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/config.rs | 133 | ||||
-rw-r--r-- | src/main.rs | 7 |
2 files changed, 127 insertions, 13 deletions
diff --git a/src/config.rs b/src/config.rs index 07be6db..e032da5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -62,11 +62,67 @@ where Ok(helper.map(|Helper(external)| external)) } -#[derive(Debug)] -pub struct Route { +#[derive(Debug, PartialEq, Clone)] +pub struct Program { // TODO: Figure out a way to allow arbitrary unix paths (arbitrary // non-unicode) without base64 - pub programs: Vec<Vec<String>>, + pub command: Vec<String>, + pub timeout: Option<Duration>, +} + +impl<'de> Deserialize<'de> for Program { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de>, + { + struct VecOrProgram; + + impl<'de> Visitor<'de> for VecOrProgram { + type Value = Program; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("map or seq") + } + + fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error> + where + A: serde::de::SeqAccess<'de>, + { + let vec: Vec<String> = + Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; + Ok(Program { + command: vec, + timeout: None, + }) + } + + fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error> + where + A: serde::de::MapAccess<'de>, + { + #[derive(Deserialize)] + struct Helper { + command: Vec<String>, + #[serde(default, deserialize_with = "deserialize_timeout_opt")] + timeout: Option<Duration>, + } + + let helper: Helper = + Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?; + Ok(Program { + command: helper.command, + timeout: helper.timeout, + }) + } + } + + deserializer.deserialize_any(VecOrProgram) + } +} + +#[derive(Debug)] +pub struct Route { + pub programs: Vec<Program>, pub qos: Option<QoS>, } @@ -88,7 +144,7 @@ impl<'de> Deserialize<'de> for Route { where A: serde::de::SeqAccess<'de>, { - let vec: Vec<Vec<String>> = + let vec: Vec<Program> = Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?; Ok(Route { programs: vec, @@ -102,7 +158,7 @@ impl<'de> Deserialize<'de> for Route { { #[derive(Deserialize)] struct RouteHelper { - programs: Vec<Vec<String>>, + programs: Vec<Program>, #[serde(default, deserialize_with = "deserialize_qos_opt")] qos: Option<QoS>, } @@ -171,6 +227,17 @@ where deserializer.deserialize_any(DurationVisitor) } +pub fn deserialize_timeout_opt<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error> +where + D: Deserializer<'de>, +{ + #[derive(Deserialize)] + struct Helper(#[serde(deserialize_with = "deserialize_timeout")] Duration); + + let helper = Option::deserialize(deserializer)?; + Ok(helper.map(|Helper(external)| external)) +} + #[derive(Deserialize, Debug)] pub struct Config { #[serde(default = "default_host")] @@ -221,6 +288,22 @@ mod tests { use rumqttc::QoS; use std::time::Duration; + impl Program { + fn new(command: Vec<&str>) -> Self { + Program { + command: command.into_iter().map(str::to_string).collect(), + timeout: None, + } + } + + fn new_with_timeout(command: Vec<&str>, timeout: Duration) -> Self { + Program { + command: command.into_iter().map(str::to_string).collect(), + timeout: Some(timeout), + } + } + } + #[test] fn load_full_config() { let toml_str = r#" @@ -235,8 +318,15 @@ mod tests { password = "testpassword" [routes] - "topic/map" = { programs = [["/bin/program1"], ["/bin/program2", "arg"]], qos = "exactly-once" } - "topic/seq" = [["/bin/program3", "arg"]] + "topic/map" = { programs = [ + ["/bin/program1"], + ["/bin/program2", "arg"], + { command = ["/bin/program3", "arg"]}, + ], qos = "exactly-once" } + "topic/seq" = [ + ["/bin/program4", "arg"], + { command = ["/bin/program5"], timeout = 1.2 }, + ] "#; let config: Config = toml::from_str(toml_str).expect("Failed to parse full config"); @@ -256,12 +346,22 @@ mod tests { let route_map = config.routes.get("topic/map").unwrap(); assert_eq!( route_map.programs, - vec![vec!["/bin/program1"], vec!["/bin/program2", "arg"]] + vec![ + Program::new(vec!["/bin/program1"]), + Program::new(vec!["/bin/program2", "arg"]), + Program::new(vec!["/bin/program3", "arg"]), + ] ); assert_eq!(route_map.qos, Some(QoS::ExactlyOnce)); let route_seq = config.routes.get("topic/seq").unwrap(); - assert_eq!(route_seq.programs, vec![vec!["/bin/program3", "arg"]]); + assert_eq!( + route_seq.programs, + vec![ + Program::new(vec!["/bin/program4", "arg"]), + Program::new_with_timeout(vec!["/bin/program5"], Duration::from_secs_f64(1.2)), + ] + ); assert_eq!(route_seq.qos, None); } @@ -290,7 +390,10 @@ mod tests { assert_eq!( route.programs, - vec![vec!["/foo/bar"], vec!["/baz/qux", "arg"]] + vec![ + Program::new(vec!["/foo/bar"]), + Program::new(vec!["/baz/qux", "arg"]) + ] ); assert_eq!(route.qos, None); } @@ -306,11 +409,17 @@ mod tests { let config: Config = toml::from_str(toml_str).unwrap(); let route_with_qos = config.routes.get("topic/with_qos").unwrap(); - assert_eq!(route_with_qos.programs, vec![vec!["/foo/bar", "arg"]]); + assert_eq!( + route_with_qos.programs, + vec![Program::new(vec!["/foo/bar", "arg"])] + ); assert_eq!(route_with_qos.qos, Some(QoS::AtLeastOnce)); let route_without_qos = config.routes.get("topic/without_qos").unwrap(); - assert_eq!(route_without_qos.programs, vec![vec!["/baz/qux"]]); + assert_eq!( + route_without_qos.programs, + vec![Program::new(vec!["/baz/qux"])] + ); assert_eq!(route_without_qos.qos, None); } diff --git a/src/main.rs b/src/main.rs index 9ebd103..68f62dc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -90,7 +90,12 @@ async fn main() -> anyhow::Result<()> { let program = program.clone(); let p = p.clone(); tokio::spawn(async move { - match timeout(conf.timeout, run(&program, &p)).await { + match timeout( + program.timeout.unwrap_or(conf.timeout), + run(&program.command, &p), + ) + .await + { Err(_) => eprintln!( "error: Execution of {program:?} for message {p:?} timed out" ), |