aboutsummaryrefslogtreecommitdiffstats
path: root/src/main.rs
blob: e319955373cbd9f3e82ac744b73cdcee66f1cc16 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// SPDX-FileCopyrightText: 2025 Tomasz Kramkowski <tomasz@kramkow.ski>
// SPDX-License-Identifier: GPL-3.0-or-later

use std::{
    ffi::OsString,
    os::unix::process::ExitStatusExt,
    path::PathBuf,
    process::{ExitStatus, Stdio},
    rc::Rc,
    time::Duration,
};

use anyhow::Context;
use log::{debug, error, trace, warn};
use moro_local::Scope;
use rumqttc::{Event::Incoming, Packet, Publish, QoS};
use tokio::{io::AsyncWriteExt, process::Command, time::timeout};

use crate::config::Program;

mod config;
mod mqtt;

/// Program name; `main` uses it to build the configuration file name (`mqttr.toml`).
const PROGRAM: &str = "mqttr";

async fn run(program: &[OsString], message: &Publish) -> anyhow::Result<ExitStatus> {
    debug!("Starting program {program:?} for message {message:?}");
    let mut command = Command::new(&program[0]);
    command
        .args(&program[1..])
        .arg(&message.topic)
        .arg(format!("{}", message.dup as u8))
        .arg(format!("{}", message.qos as u8))
        .arg(format!("{}", message.retain as u8));
    if message.qos == QoS::AtLeastOnce || message.qos == QoS::ExactlyOnce {
        command.arg(format!("{}", message.pkid));
    }
    let mut proc = command.stdin(Stdio::piped()).spawn()?;
    trace!(
        "Started program {program:?} with PID {}",
        proc.id().expect("missing PID")
    );
    let mut stdin = proc.stdin.take().context("No stdin")?;
    stdin.write_all(&message.payload).await?;
    drop(stdin);
    let result = proc.wait().await?;
    Ok(result)
}

fn run_route_programs<'a, R>(
    scope: &'a Scope<'a, '_, R>,
    programs: &'a [Program],
    message: &Publish,
    default_timeout: Duration,
) {
    for program in programs.iter() {
        let p = message.clone();
        scope.spawn(async move {
            // TODO: BUG: This won't guarentee the process gets
            // killed. kill_on_drop itself also has problems.
            // Need to handle this properly manually.
            // TODO: Also should use cancellation tokens.
            match timeout(
                program.timeout.unwrap_or(default_timeout),
                run(&program.command, &p),
            )
            .await
            {
                Err(_) => error!("Execution of {program:?} for message {p:?} timed out"),
                Ok(Err(e)) => error!("error: Failed to run {program:?}: {e:?}"),
                Ok(Ok(c)) if !c.success() => {
                    if let Some(code) = c.code() {
                        if code != 0 {
                            warn!("Program exited with non-zero exit code: {code}")
                        } else {
                            debug!("Program exited successfully.");
                        }
                    } else if let Some(signal) = c.signal() {
                        let core_dumped = if c.core_dumped() {
                            " (core dumped)"
                        } else {
                            ""
                        };
                        warn!("Program received signal: {signal}{core_dumped}");
                    }
                }
                Ok(Ok(_)) => (),
            }
        });
    }
}

/// Entry point.
///
/// Loads the configuration and sets up logging. It then subscribes to every route's topic and
/// dispatches each incoming PUBLISH to the programs of every matching route.
///
/// The configuration is read from `<SYSCONFDIR>/mqttr.toml`. `SYSCONFDIR` is taken from the
/// build-time environment and defaults to `/usr/local/etc`.
///
/// # Errors
///
/// Returns an error if the configuration cannot be loaded or logging cannot be initialised.
/// It also returns an error if the MQTT event loop reports one. The process then exits rather
/// than reconnecting.
#[tokio::main(flavor = "current_thread")]
async fn main() -> anyhow::Result<()> {
    let mut conf_path: PathBuf = option_env!("SYSCONFDIR").unwrap_or("/usr/local/etc").into();
    conf_path.push(format!("{PROGRAM}.toml"));
    let conf = config::load(&conf_path)
        .with_context(|| format!("Failed to load config: {:?}", &conf_path))?;
    stderrlog::new()
        .color(stderrlog::ColorChoice::Never)
        .module(module_path!())
        .verbosity(conf.log.level)
        .timestamp(if conf.log.timestamps {
            stderrlog::Timestamp::Millisecond
        } else {
            stderrlog::Timestamp::Off
        })
        .init()
        .context("Failed to initialise logging")?;
    // TODO: This will print creds
    trace!("Configuration: {conf:?}");
    let (client, mut event_loop) = conf.mqtt_client();
    // NOTE(review): `event_loop` is not polled until the scope below starts, so these subscribe
    // requests can only queue up. If the client's request queue is smaller than the number of
    // routes, this loop may never finish. Confirm the capacity chosen in `mqtt_client`.
    for (topic, route) in conf.routes.iter() {
        if let Err(e) = client.subscribe(topic, route.qos.unwrap_or(conf.qos)).await {
            warn!("Failed to subscribe to '{topic}': {e:?}");
        } else {
            debug!("Subscribed to: '{topic}'");
        }
    }
    moro_local::async_scope!(|scope| -> anyhow::Result<()> {
        loop {
            // Any connection error ends the scope and therefore the program.
            let notification = event_loop
                .poll()
                .await
                .context("MQTT connection error")?;
            // Only incoming PUBLISH packets trigger programs; other events are ignored.
            if let Incoming(Packet::Publish(p)) = notification {
                debug!("Received message: {p:?}");
                let p = Rc::new(p);
                // Every matching route runs its programs, not just the first match.
                for (topic, route) in conf.routes.iter() {
                    if !mqtt::topic_match(topic, &p.topic) {
                        continue;
                    }
                    debug!("Message {p:?} matched topic {topic}");
                    run_route_programs(scope, &route.programs, &p, conf.timeout);
                }
            }
        }
    })
    .await
}