var lib = require('./lib2');
// Human-readable usage string attached to the exported builder; it spells out
// the parameter list and defaults of ctree.
ctree.help = "ctree(input=lib.supermatrix.alloc(1,1),depth=3,label_count=2,branch_count=2,input_width=512,hidden_width=25,batch_size=1)";
// Publish the builder on the CommonJS exports object (function declarations are
// hoisted, so ctree is already defined at this point).
Object.assign(module.exports, { ctree });
/**
 * Build one node of a layer tree (recursively including its children) and
 * return its weights, output buffers and a forward() pass.
 *
 * Each node has up to two outputs, computed from its inputs with
 * lib.supermatrix.sideways.multiply:
 *   - layerOut[0] ("labels"): label_count label scores, followed on non-leaf
 *     nodes by branch_count branch weights; lib.matrix.softmax is applied to it.
 *   - layerOut[1] ("hiddenData", non-leaf only): hidden_width activations passed
 *     through lib.matrix.relu.inplace and appended to every child's inputs.
 * A non-leaf node owns branch_count child trees of depth-1. After the children
 * run, each child's labels are added into this node's labels, scaled by this
 * node's branch weight for that child, row by row (one row per batch item).
 * NOTE(review): the mixed labels are not renormalized afterwards.
 *
 * @param input supermatrix of input matrices; it must expose `.width` (number
 *   of input kinds). If input[0] is empty, an input_width x batch_size matrix is
 *   allocated into it (the caller's supermatrix is mutated). The client is
 *   expected to fill the inputs before calling forward().
 * @param {number} depth number of tree levels; 1 builds a leaf. Must be >= 1.
 * @param {number} label_count number of label outputs per node.
 * @param {number} branch_count number of children of every non-leaf node.
 * @param {number} input_width width of the auto-allocated input[0].
 * @param {number} hidden_width width of the hidden output of non-leaf nodes.
 * @param {number} batch_size number of rows (batch items) in every buffer.
 * @returns {{input, layerOP, layerOut, forward: Function, isLeaf: boolean,
 *   branches: (Array|null), labels, hiddenData}} labels/hiddenData alias
 *   layerOut[0]/layerOut[1].
 * @throws {Error} if depth < 1.
 */
function ctree(input = lib.supermatrix.alloc(1, 1), depth = 3, label_count = 2, branch_count = 2, input_width = 512, hidden_width = 25, batch_size = 1) {
  if (depth < 1) throw new Error('We need some depth');
  if (!input[0]) {
    input[0] = lib.matrix.alloc(input_width, batch_size);
  }
  const isLeaf = depth === 1;
  // Output kinds: [labels(+branch weights), hidden] for inner nodes, [labels] for leaves.
  const outWidths = isLeaf ? [label_count] : [label_count + branch_count, hidden_width];

  // One weight matrix per (output kind, input kind) pair, stored as
  // layerOP[outKind * input.width + inKind].
  const layerOP = lib.supermatrix.alloc(input.width, outWidths.length);
  let index = 0;
  for (const outWidth of outWidths) {
    for (const inputSub of input) {
      layerOP[index++] = lib.matrix.random(inputSub.width, outWidth);
    }
  }

  const layerOut = lib.supermatrix.alloc(outWidths.length, 1);
  outWidths.forEach((outWidth, i) => {
    layerOut[i] = lib.matrix.alloc(outWidth, batch_size);
  });

  let branches = null;
  if (!isLeaf) {
    // Children see all of this node's inputs plus this node's hidden activations.
    const nextInput = input.concat(layerOut[1]);
    nextInput.width = input.width + 1;
    branches = Array.from({ length: branch_count }, () =>
      ctree(nextInput, depth - 1, label_count, branch_count, input_width, hidden_width, batch_size));
  }

  function forward() {
    // We assume the client filled input.
    lib.supermatrix.sideways.multiply(input, layerOP, layerOut);
    if (layerOut[1]) lib.matrix.relu.inplace(layerOut[1]); // activation for the hiddens
    lib.matrix.softmax(layerOut[0]); // softmax over labels and branch weights
    if (!branches) return;
    for (const branch of branches) {
      branch.forward();
    }
    mixBranchLabels();
  }

  // Add every child's labels into ours, weighted by our branch weight for that
  // child. Buffers are row-major (index = row * width + col) with one row per
  // batch item, so both sides are offset by the current row; a child's width
  // can differ from ours (leaves carry no branch weights).
  function mixBranchLabels() {
    const labels = layerOut[0];
    for (let row = 0; row < batch_size; ++row) {
      const rowBase = row * labels.width;
      for (let j = 0; j < branch_count; ++j) {
        // Columns beyond label_count hold the branch weights.
        const weight = labels[rowBase + label_count + j];
        const childLabels = branches[j].layerOut[0];
        const childBase = row * childLabels.width;
        for (let k = 0; k < label_count; ++k) {
          labels[rowBase + k] += weight * childLabels[childBase + k];
        }
      }
    }
  }

  return {input, layerOP, layerOut, forward, isLeaf, branches, labels: layerOut[0], hiddenData: layerOut[1]};
}