Commit

Crossvalidation type added for k-fold method, bug fixed (#6)
AlexeySKiselev authored Sep 28, 2019
1 parent 4881997 commit ef10157
Showing 6 changed files with 180 additions and 21 deletions.
22 changes: 16 additions & 6 deletions Readme.md
@@ -4,7 +4,7 @@
# Unirand
A JavaScript module for generating seeded random distributions and its statistical analysis.

Implemented in pure JavaScript with no dependencies, designed to work in Node.js and fully asynchronous, tested *with ~600 tests*.
Implemented in pure JavaScript with no dependencies, designed to work in Node.js and fully asynchronous, tested *with 600+ tests*.

[Supported distributions](./core/methods/)

@@ -17,7 +17,7 @@ const unirand = require('unirand')
```

### PRNG
Unirand supports different PRNGs: *default JS generator, tuchei sedded generator*. By default unirand uses **tuchei** generator.
Unirand supports different PRNGs: *default JS generator, tuchei seeded generator*. By default unirand uses **tuchei** generator.
Our seeded generator supports *seed*, *random*, *next* methods.
A name of current using PRNG is stored in:
```javascript
@@ -158,8 +158,8 @@ Sample method is **3 times faster** for arrays and **7 times faster** for string
### k-fold
Splits array into *k* subarrays. Requires at least 2 arguments: array itself and *k*. Also supports *options*.

- *type*: output type, **list** (default) for output like `[<fold>, <fold>, <fold>, ...]`, **set** for output like `{0: <fold>, 1: <fold>, 2: <fold>, ...}`
- *derange*: items will be shuffled as *random permutation* (default) or *random derangement*
- *type*: output type, **list** (default) for output like `[<fold>, <fold>, <fold>, ...]`, **set** for output like `{0: <fold>, 1: <fold>, 2: <fold>, ...}`, **crossvalidation** for output like `[{test: <fold>, data: <remaining folds>}, ...]`
- *derange*: items will be shuffled as *random permutation* (default, `derange: false`) or *random derangement* (`derange: true`)
```javascript
const kfold = unirand.kfold;
kfold([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3); // [ [ 9, 8, 2, 10 ], [ 1, 7, 3 ], [ 4, 5, 6 ] ]
@@ -168,7 +168,17 @@ kfold([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3); // [ [ 9, 8, 2, 10 ], [ 1, 7, 3 ], [
kfold([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3, {
type: 'set',
derange: true
}); // { '0': [ 8, 10, 7, 1 ], '1': [ 6, 4, 9 ], '2': [ 5, 2, 3 ] }
});
// { '0': [ 8, 10, 7, 1 ], '1': [ 6, 4, 9 ], '2': [ 5, 2, 3 ] }

// cross validation
kfold([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3, {
type: 'crossvalidation',
derange: true
})
// [ { id: 0, test: [ 5, 6, 9, 7 ], data: [ 4, 1, 10, 2, 8, 3 ] },
// { id: 1, test: [ 4, 1, 10 ], data: [ 5, 6, 9, 7, 2, 8, 3 ] },
// { id: 2, test: [ 2, 8, 3 ], data: [ 5, 6, 9, 7, 4, 1, 10 ] } ]
```
For permutation unirand uses seeded PRNG. With *seed* k-fold will always return same result.
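
As a rough illustration of how the `crossvalidation` output might be consumed, the sketch below loops over the folds and averages a per-fold score. The `folds` literal copies the sample output from the README example above, so the snippet runs without unirand installed. `scoreFold` and `crossValidate` are hypothetical helpers introduced only for this illustration; they are not part of the library.

```javascript
// Sample output copied from the README example (type: 'crossvalidation').
const folds = [
    { id: 0, test: [5, 6, 9, 7], data: [4, 1, 10, 2, 8, 3] },
    { id: 1, test: [4, 1, 10], data: [5, 6, 9, 7, 2, 8, 3] },
    { id: 2, test: [2, 8, 3], data: [5, 6, 9, 7, 4, 1, 10] }
];

// Hypothetical scoring function: "trains" by taking the mean of `data`
// and reports the mean absolute error of that mean on `test`.
const scoreFold = ({ test, data }) => {
    const mean = data.reduce((sum, x) => sum + x, 0) / data.length;
    return test.reduce((sum, x) => sum + Math.abs(x - mean), 0) / test.length;
};

// Hypothetical helper: average a per-fold score over all folds.
const crossValidate = (items, score) =>
    items.reduce((sum, item) => sum + score(item), 0) / items.length;

const meanError = crossValidate(folds, scoreFold);
console.log(meanError);
```

Each item carries both the held-out fold (`test`) and the remaining items (`data`), so the consumer does not need to reassemble the training set itself.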

@@ -197,7 +207,7 @@ Winsorization is the transformation of statistics by limiting extreme values in
Parameters:
- *input*: array of numbers
- *limits*: single number, represent same value trimming value from left and right (should be 0 < limit < 0.5), or an array \[left trim value, right trim value\] (values should be 0 < left trim value < right trim value < 1)
- *mutate*: <true|false> value (default *true*). If true - mutate ofiginal array, otherwise - no
- *mutate*: <true|false> value (default *true*). If true - mutate original array, otherwise - no

```javascript
const winsorize = unirand.winsorize;
92 changes: 79 additions & 13 deletions core/array_manipulation/kfold.js
@@ -9,7 +9,8 @@
import ArrayManipulation from './base';
import Shuffle from './shuffle';

import type { KFoldOptions, RandomArrayNumberString, RandomArrayStringObject, RandomArrayString } from '../types';
import type { KFoldOptions, KFoldCrossValidation, RandomArrayNumberString,
RandomArrayStringObject, RandomArrayString } from '../types';
import type { IKFold, IShuffle } from '../interfaces';

class KFold extends ArrayManipulation implements IKFold {
@@ -28,7 +29,7 @@ class KFold extends ArrayManipulation implements IKFold {
getKFold(input: RandomArrayNumberString<any>, k: number, options: KFoldOptions = {
type: 'list',
derange: false
}): RandomArrayStringObject<any> {
}): RandomArrayStringObject<any> | KFoldCrossValidation {
this._validateInput(input, false);

if (typeof k !== 'number') {
@@ -39,24 +40,89 @@ class KFold extends ArrayManipulation implements IKFold {
throw new Error('k-fold: Parameter "k" should be greater then 0 and less input.length');
}

let result: RandomArrayStringObject<any>;

if (options.type === 'list') {
result = [];
} else if (options.type === 'set') {
result = {};
} else {
throw new Error('k-fold: Wrong output type, should be "list" or "set"');
}

const folds: Array<number> = this._getFolds(input.length, k);
let permutedInput: RandomArrayString<number | string>;
if (!options.derange) {
permutedInput = this._shuffle.getPermutation(input);
} else {
permutedInput = this._shuffle.getDerangement(input);
}

if (options.type === 'list') {
return this._getListSetKFold(permutedInput, folds, []);
} else if (options.type === 'set') {
return this._getListSetKFold(permutedInput, folds, {});
} else if (options.type === 'crossvalidation') {
return this._getCrossValidationKFold(permutedInput, folds);
}
throw new Error('k-fold: Wrong output type, should be "list", "set" or "crossvalidation"');
}

const folds: Array<number> = this._getFolds(input.length, k);
/**
* Generates kfold output for "crossvalidation" type
* @param {RandomArrayString<number | string>} permutedInput
* @param {Array<number>} folds
* @private
*/
_getCrossValidationKFold(
permutedInput: RandomArrayString<number | string>,
folds: Array<number>
): KFoldCrossValidation {
const result = [];
const listFolds: RandomArrayStringObject<any> = this._getListSetKFold(permutedInput, folds, []);
for (let i = 0; i < listFolds.length; i += 1) {
result.push({
id: i,
test: listFolds[i].slice(),
data: this._generateData(listFolds, i)
});
}

return result;
}

/**
* Generates data for crossvalidation
* Collects all data from all folds except fold[i]
* @param {RandomArrayStringObject<any>} listFolds
* @param {number} i
* @private
*/
_generateData(listFolds: RandomArrayStringObject<any>, i: number): Array<RandomArrayStringObject<any>> {
const result: Array<RandomArrayStringObject<any>> = [];
for (let j = 0; j < i; j += 1) {
this._addSubData(listFolds[j], result);
}
for (let j = i + 1; j < listFolds.length; j += 1) {
this._addSubData(listFolds[j], result);
}

return result;
}

/**
* @param {RandomArrayStringObject<any>} listFolds
* @param {Array<RandomArrayStringObject<any>>} result
* @private
*/
_addSubData(listFolds: RandomArrayStringObject<any>, result: Array<RandomArrayStringObject<any>>): void {
for (let k = 0; k < listFolds.length; k += 1) {
result.push(listFolds[k]);
}
}

/**
* Generates kfold output for "list" and "set" types
* @param {RandomArrayString<number | string>} permutedInput
* @param {Array<number>} folds
* @param {RandomArrayStringObject<any>} result
* @private
*/
_getListSetKFold(
permutedInput: RandomArrayString<number | string>,
folds: Array<number>,
result: RandomArrayStringObject<any>
): RandomArrayStringObject<any> {
let pindex: number = 0;
let subResult: RandomArrayNumberString<any> = [];

11 changes: 11 additions & 0 deletions core/types.js
@@ -73,3 +73,14 @@ export type NumberString = number | string;
* Array<number> or number
*/
export type RandomArrayNumber = RandomArray | number;

/**
* kfold crossvalidation
*/
export type KFoldCrossValidationItem = {
id: number,
test: RandomArrayStringObject<any>,
data: RandomArrayStringObject<any>
};

export type KFoldCrossValidation = Array<KFoldCrossValidationItem>;
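
To make the `KFoldCrossValidation` shape concrete, here is a minimal sketch that builds such a value from a plain list of folds. It follows the same idea as `_getCrossValidationKFold` and `_generateData` in the diff above (copy fold `i` into `test`, collect every other fold into `data`), but `toCrossValidation` is a hypothetical standalone helper written with `map`/`filter`/`reduce`, not the library's implementation.

```javascript
// Hypothetical standalone helper, not the library's code:
// turns a list of folds into [{ id, test, data }, ...].
const toCrossValidation = (listFolds) =>
    listFolds.map((fold, i) => ({
        id: i,
        // copy the fold so callers cannot mutate the source list through `test`
        test: fold.slice(),
        // concatenate every fold except the i-th one, preserving fold order
        data: listFolds
            .filter((_, j) => j !== i)
            .reduce((acc, other) => acc.concat(other), [])
    }));

const example = toCrossValidation([[9, 8, 2, 10], [1, 7, 3], [4, 5, 6]]);
console.log(example);
// [ { id: 0, test: [ 9, 8, 2, 10 ], data: [ 1, 7, 3, 4, 5, 6 ] },
//   { id: 1, test: [ 1, 7, 3 ], data: [ 9, 8, 2, 10, 4, 5, 6 ] },
//   { id: 2, test: [ 4, 5, 6 ], data: [ 9, 8, 2, 10, 1, 7, 3 ] } ]
```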
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "unirand",
"version": "2.5.1",
"version": "2.5.2",
"description": "Random numbers and Distributions generation",
"main": "./lib/index.js",
"scripts": {
69 changes: 69 additions & 0 deletions test/array_manipulation.test.js
@@ -786,5 +786,74 @@ describe('Array manipulation methods', () => {
}
done();
});
it('should have correct output structure for type "crossvalidation"', () => {
const kfold = new KFold();
const randomInput = generateInput();
const res = kfold.getKFold(randomInput, 10, {
type: 'crossvalidation'
});

expect(res.length).to.be.equal(10);
expect(Array.isArray(res)).to.be.equal(true);
expect(res[0].id).to.be.equal(0);
expect(res[0].test).to.be.not.equal(undefined);
expect(res[0].data).to.be.not.equal(undefined);
expect(Object.keys(res[0]).length).to.be.equal(3);
expect(Array.isArray(res[0].test)).to.be.equal(true);
expect(Array.isArray(res[0].data)).to.be.equal(true);
for (let i = 0; i < res.length; i += 1) {
expect(res[i].test.length + res[i].data.length).to.be.equal(randomInput.length);
}
});
it('should generate correct data for type "crossvalidation"', function(done) {
this.timeout(480000);
const kfold = new KFold();
let input = [];
let res;
const checkExistance = (data, test) => {
let ht = {};
let fail = false;
for (let i = 0; i < test.length; i += 1) {
ht[test[i]] = 1;
}
for (let i = 0; i < data.length; i += 1) {
if (ht[data[i]]) {
fail = true;
break;
}
}
expect(fail).to.be.equal(false);
};

const checkUniqueness = (data, test) => {
const ht = {};
for (let i = 0; i < data.length; i += 1) {
ht[data[i]] = 1;
}

for (let i = 0; i < test.length; i += 1) {
ht[test[i]] = 1;
}

expect(Object.keys(ht).length).to.be.equal(data.length + test.length);
};

for (let i = 0; i < 5000; i += 1) {
input[i] = i;
}

for (let i = 0; i < 400; i += 1) {
res = kfold.getKFold(input, 200, {
type: 'crossvalidation'
});

for (let j = 0; j < res.length; j += 1) {
checkUniqueness(res[j].data, res[j].test);
checkExistance(res[j].data, res[j].test);
}
}

done();
});
});
});
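
The two checks in the test above (no item appears in both `test` and `data`, and together they cover the whole input) could also be expressed with a `Set`. The sketch below is one possible formulation under the same assumption the test makes, namely that all input items are distinct; `isValidSplit` is a hypothetical helper, not something that exists in the test suite. A `Set` compares values without coercing them to strings, which plain-object keys do.

```javascript
// Hypothetical helper: true when `test` and `data` share no items and
// together contain exactly `expectedSize` distinct items.
const isValidSplit = (test, data, expectedSize) => {
    const seen = new Set(test);
    if (data.some((item) => seen.has(item))) {
        return false; // an item from the test fold leaked into the data
    }
    data.forEach((item) => seen.add(item));
    return seen.size === expectedSize;
};

console.log(isValidSplit([2, 8, 3], [5, 6, 9, 7, 4, 1, 10], 10)); // true
console.log(isValidSplit([2, 8, 3], [5, 6, 9, 7, 4, 1, 3], 10)); // false
```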
5 changes: 4 additions & 1 deletion test/testBuild.js
@@ -6,4 +6,7 @@
let unirand = require('../lib');

unirand.seed();
console.log(unirand.laplace(1, 2).distributionSync(4));
console.log(unirand.kfold([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3, {
derange: true,
type: 'crossvalidation'
}));
