sum/product optimizations

This commit is contained in:
Adrian Mariano 2024-06-25 19:24:48 -04:00
parent c6e58b926b
commit 054052144e
2 changed files with 68 additions and 52 deletions

108
math.scad
View file

@ -721,16 +721,17 @@ function deltas(v, wrap=false) =
// cumsum([1,2,3]); // returns [1,3,6]
// cumsum([[1,2,3], [3,4,5], [5,6,7]]); // returns [[1,2,3], [4,6,8], [9,12,15]]
function cumsum(v) =
v==[] ? [] :
assert(is_consistent(v), "The input is not consistent." )
len(v)<=1 ? v :
_cumsum(v,_i=1,_acc=[v[0]]);
function _cumsum(v,_i=0,_acc=[]) =
_i>=len(v) ? _acc :
_cumsum( v, _i+1, [ each _acc, _acc[len(_acc)-1] + v[_i] ] );
[for (a = v[0],
i = 1
;
i <= len(v)
;
a = i<len(v) ? a+v[i] : a,
i = i+1)
a];
// Function: product()
// Synopsis: Returns the multiplicative product of a list of values.
// Topics: Math, Statistics
@ -739,24 +740,35 @@ function _cumsum(v,_i=0,_acc=[]) =
// x = product(v);
// Description:
// Returns the product of all entries in the given list.
// If passed a list of vectors of same dimension, returns a vector of products of each part.
// If passed a list of square matrices, returns the resulting product matrix.
// If passed a list of vectors of same length, returns a vector of the component-wise products of the input.
// If passed a list of square matrices, returns the resulting product matrix. Matrices are multiplied in the order they appear in the list.
// Arguments:
// v = The list to get the product of.
// Example:
// product([2,3,4]); // returns 24.
// product([[1,2,3], [3,4,5], [5,6,7]]); // returns [15, 48, 105]
function product(v) =
assert( is_vector(v) || is_matrix(v) || ( is_matrix(v[0],square=true) && is_consistent(v)),
"Invalid input.")
_product(v, 1, v[0]);
function _product(v, i=0, _tot) =
i>=len(v) ? _tot :
_product( v,
i+1,
( is_vector(v[i])? v_mul(_tot,v[i]) : _tot*v[i] ) );
function product(list,right=true) =
list==[] ? [] :
is_matrix(list) ?
[for (a = list[0],
i = 1
;
i <= len(list)
;
a = i<len(list) ? v_mul(a,list[i]) : 0,
i = i+1)
if (i==len(list)) a][0]
:
assert(is_vector(list) || (is_matrix(list[0],square=true) && is_consistent(list)),
"Input must be a vector, a list of vectors, or a list of matrices.")
[for (a = list[0],
i = 1
;
i <= len(list)
;
a = i<len(list) ? a*list[i] : 0,
i = i+1)
if (i==len(list)) a][0];
// Function: cumprod()
@ -777,37 +789,29 @@ function _product(v, i=0, _tot) =
// cumprod([1,3,5]); // returns [1,3,15]
// cumprod([2,2,2]); // returns [2,4,8]
// cumprod([[1,2,3], [3,4,5], [5,6,7]])); // returns [[1, 2, 3], [3, 8, 15], [15, 48, 105]]
function cumprod(list,right=false) =
is_vector(list) ? _cumprod(list) :
assert(is_consistent(list), "Input must be a consistent list of scalars, vectors or square matrices")
assert(is_bool(right))
is_matrix(list[0]) ? assert(len(list[0])==len(list[0][0]), "Matrices must be square") _cumprod(list,right)
: _cumprod_vec(list);
function _cumprod(v,right,_i=0,_acc=[]) =
_i==len(v) ? _acc :
_cumprod(
v, right, _i+1,
concat(
_acc,
[
_i==0 ? v[_i]
: right? _acc[len(_acc)-1]*v[_i]
: v[_i]*_acc[len(_acc)-1]
]
)
);
function _cumprod_vec(v,_i=0,_acc=[]) =
_i==len(v) ? _acc :
_cumprod_vec(
v, _i+1,
concat(
_acc,
[_i==0 ? v[_i] : v_mul(_acc[len(_acc)-1],v[_i])]
)
);
list==[] ? [] :
is_matrix(list) ?
[for (a = list[0],
i = 1
;
i <= len(list)
;
a = i<len(list) ? v_mul(a,list[i]) : 0,
i = i+1)
a]
:
assert(is_vector(list) || (is_matrix(list[0],square=true) && is_consistent(list)),
"Input must be a listector, a list of listectors, or a list of matrices.")
[for (a = list[0],
i = 1
;
i <= len(list)
;
a = i<len(list) ? (right ? a*list[i] : list[i]*a) : 0,
i = i+1)
a];
// Function: convolve()

View file

@ -371,6 +371,7 @@ test_deltas();
module test_product() {
assert_equal(product([]),[]);
assert_equal(product([2,3,4]), 24);
assert_equal(product([[1,2,3], [3,4,5], [5,6,7]]), [15, 48, 105]);
m1 = [[2,3,4],[4,5,6],[6,7,8]];
@ -613,6 +614,7 @@ module test_cumprod(){
assert_equal(cumprod([]),[]);
assert_equal(cumprod([[2,3],[4,5],[6,7]]), [[2,3],[8,15],[48,105]]);
assert_equal(cumprod([[5,6,7]]),[[5,6,7]]);
assert_equal(cumprod([up(5),down(5)]), [up(5),IDENT]);
assert_equal(cumprod([
[[1,2],[3,4]],
[[-4,5],[6,4]],
@ -623,6 +625,16 @@ module test_cumprod(){
[[11,12],[18,28]],
[[45,24],[98,132]]
]);
assert_equal(cumprod([
[[1,2],[3,4]],
[[-4,5],[6,4]],
[[9,-3],[4,3]]
],right=true),
[
[[1,2],[3,4]],
[[8, 13],[12,31]],
[[124, 15],[232,57]]
]);
assert_equal(cumprod([[[1,2],[3,4]]]), [[[1,2],[3,4]]]);
}
test_cumprod();