aboutsummaryrefslogtreecommitdiffstats
path: root/klippy/mathutil.py
diff options
context:
space:
mode:
authorKevin O'Connor <kevin@koconnor.net>2018-03-02 18:18:35 -0500
committerKevin O'Connor <kevin@koconnor.net>2018-03-04 09:56:50 -0500
commitfa07be93463d9ce0262ae383259e68eea625bc4d (patch)
treeec028add97cd891fe9ebd85b140c1219e062eaaa /klippy/mathutil.py
parent7290ed5f73f4c649099fec5b13cd7c3f7a29b2d2 (diff)
downloadkutter-fa07be93463d9ce0262ae383259e68eea625bc4d.tar.gz
kutter-fa07be93463d9ce0262ae383259e68eea625bc4d.tar.xz
kutter-fa07be93463d9ce0262ae383259e68eea625bc4d.zip
mathutil: Move coordinate_descent() to new file
Add a new python file (mathutil.py) and move the coordinate_descent() code to it. Signed-off-by: Kevin O'Connor <kevin@koconnor.net>
Diffstat (limited to 'klippy/mathutil.py')
-rw-r--r--klippy/mathutil.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/klippy/mathutil.py b/klippy/mathutil.py
new file mode 100644
index 00000000..d8df3539
--- /dev/null
+++ b/klippy/mathutil.py
@@ -0,0 +1,40 @@
+# Simple math helper functions
+#
+# Copyright (C) 2018 Kevin O'Connor <kevin@koconnor.net>
+#
+# This file may be distributed under the terms of the GNU GPLv3 license.
+import logging
+
+# Helper code that implements coordinate descent
+def coordinate_descent(adj_params, params, error_func):
+ # Define potential changes
+ params = dict(params)
+ dp = {param_name: 1. for param_name in adj_params}
+ # Calculate the error
+ best_err = error_func(params)
+
+ threshold = 0.00001
+ rounds = 0
+
+ while sum(dp.values()) > threshold and rounds < 10000:
+ rounds += 1
+ for param_name in adj_params:
+ orig = params[param_name]
+ params[param_name] = orig + dp[param_name]
+ err = error_func(params)
+ if err < best_err:
+ # There was some improvement
+ best_err = err
+ dp[param_name] *= 1.1
+ continue
+ params[param_name] = orig - dp[param_name]
+ err = error_func(params)
+ if err < best_err:
+ # There was some improvement
+ best_err = err
+ dp[param_name] *= 1.1
+ continue
+ params[param_name] = orig
+ dp[param_name] *= 0.9
+ logging.info("Coordinate descent best_err: %s rounds: %d", best_err, rounds)
+ return params