Merge pull request #306 from adrianVmariano/master

Add cumprod function and regression test
This commit is contained in:
Revar Desmera 2020-10-23 16:13:03 -07:00 committed by GitHub
commit 0e823111dd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 0 deletions

View file

@ -617,6 +617,44 @@ function _product(v, i=0, _tot) =
// Function: cumprod()
// Description:
// Returns a list where each item is the cumulative product of all items up to and including the corresponding entry in the input list.
// If passed an array of vectors, returns a list of elementwise vector products. If passed a list of square matrices returns matrix
// products multiplying in the order items appear in the list.
// Arguments:
// list = The list to get the product of.
// Example:
// cumprod([1,3,5]); // returns [1,3,15]
// cumprod([2,2,2]); // returns [2,4,8]
// cumprod([[1,2,3], [3,4,5], [5,6,7]])); // returns [[1, 2, 3], [3, 8, 15], [15, 48, 105]]
function cumprod(list) =
is_vector(list) ? _cumprod(list) :
assert(is_consistent(list), "Input must be a consistent list of scalars, vectors or square matrices")
is_matrix(list[0]) ? assert(len(list[0])==len(list[0][0]), "Matrices must be square") _cumprod(list)
: _cumprod_vec(list);
function _cumprod(v,_i=0,_acc=[]) =
_i==len(v) ? _acc :
_cumprod(
v, _i+1,
concat(
_acc,
[_i==0 ? v[_i] : _acc[len(_acc)-1]*v[_i]]
)
);
function _cumprod_vec(v,_i=0,_acc=[]) =
_i==len(v) ? _acc :
_cumprod_vec(
v, _i+1,
concat(
_acc,
[_i==0 ? v[_i] : vmul(_acc[len(_acc)-1],v[_i])]
)
);
// Function: outer_product()
// Usage:
// x = outer_product(u,v);

View file

@ -836,6 +836,28 @@ module test_linear_solve(){
test_linear_solve();
module test_cumprod(){
assert_equal(cumprod([1,2,3,4]), [1,2,6,24]);
assert_equal(cumprod([4]), [4]);
assert_equal(cumprod([]),[]);
assert_equal(cumprod([[2,3],[4,5],[6,7]]), [[2,3],[8,15],[48,105]]);
assert_equal(cumprod([[5,6,7]]),[[5,6,7]]);
assert_equal(cumprod([
[[1,2],[3,4]],
[[-4,5],[6,4]],
[[9,-3],[4,3]]
]),
[
[[1,2],[3,4]],
[[8,13],[12,31]],
[[124,15],[232,57]]
]);
assert_equal(cumprod([[[1,2],[3,4]]]), [[[1,2],[3,4]]]);
}
test_cumprod();
module test_outer_product(){
assert_equal(outer_product([1,2,3],[4,5,6]), [[4,5,6],[8,10,12],[12,15,18]]);
assert_equal(outer_product([1,2],[4,5,6]), [[4,5,6],[8,10,12]]);