AnnouncementsFunnyVideosMusicAncapsTechnologyEconomicsPrivacyGIFSCringeAnarchyFilmPicsThemesIdeas4MatrixAskMatrixHelpTop Subs
2
var lib = require('./lib2');
// Usage string that mirrors the ctree signature together with its default arguments.
ctree.help = "ctree(input=lib.supermatrix.alloc(1,1),depth=3,label_count=2,branch_count=2,input_width=512,hidden_width=25,batch_size=1)";
// Expose the tree builder to requiring modules (ctree is a hoisted function declaration).
module.exports.ctree = ctree;
/**
 * Builds one node of a classification tree and, recursively, its children.
 *
 * Every node owns one randomly initialised weight matrix per (output kind,
 * input matrix) pair. A leaf (depth 1) has a single output kind of
 * `label_count` columns. An inner node has two: `label_count + branch_count`
 * columns (labels followed by one weight per branch) and a hidden output of
 * `hidden_width` columns that is appended to the input given to each child.
 *
 * @param {Array} input - Supermatrix of input matrices. If slot 0 is empty a
 *   matrix is allocated there with lib.matrix.alloc(input_width, batch_size);
 *   note that this writes into the caller's object.
 * @param {number} depth - Remaining levels, this node included; 1 means leaf.
 * @param {number} label_count - Number of label columns.
 * @param {number} branch_count - Number of children of each inner node.
 * @param {number} input_width - Width used when the input matrix is allocated here.
 * @param {number} hidden_width - Width of an inner node's hidden output.
 * @param {number} batch_size - Number of rows processed per forward pass.
 * @returns {{input, layerOP, layerOut, forward: Function, isLeaf: boolean,
 *   branches: (Array|null), labels, hiddenData}} The node; `labels` is
 *   layerOut[0] and `hiddenData` is layerOut[1] (undefined on a leaf).
 * @throws {Error} When depth is less than 1.
 */
function ctree(input=lib.supermatrix.alloc(1,1),depth=3,label_count=2,branch_count=2,input_width=512,hidden_width=25,batch_size=1) {
 if (depth < 1) throw new Error('We need some depth');
 if (!input[0]) {
  input[0] = lib.matrix.alloc(input_width, batch_size);
 }

 // Loose equality is kept on purpose so a numeric string still selects the leaf case.
 const isLeaf = depth == 1;
 const outWidths = isLeaf ? [label_count] : [label_count + branch_count, hidden_width];

 // Number of input kinds by number of output kinds, filled output-kind-major.
 const layerOP = lib.supermatrix.alloc(input.width, outWidths.length);
 let slot = 0;
 for (const outWidth of outWidths) {
  for (const inputSub of input) {
   layerOP[slot++] = lib.matrix.random(inputSub.width, outWidth);
  }
 }

 const layerOut = lib.supermatrix.alloc(outWidths.length, 1);
 outWidths.forEach((outWidth, position) => {
  layerOut[position] = lib.matrix.alloc(outWidth, batch_size);
 });

 let branches = null;
 if (!isLeaf) {
  // Children see everything this node saw plus this node's hidden output.
  // NOTE(review): Array.prototype.concat flattens plain-Array arguments, so this
  // presumably relies on matrices not being plain Arrays - confirm against lib2.
  const nextInput = input.concat(layerOut[1]);
  nextInput.width = input.width + 1;
  // Array.from replaces a `range` helper that is not defined in this file.
  branches = Array.from(
   {length: branch_count},
   () => ctree(nextInput, depth - 1, label_count, branch_count, input_width, hidden_width, batch_size)
  );
 }

 /**
  * Runs this node and its whole sub-tree on the current contents of `input`
  * (the caller is expected to have filled it), then folds each branch's
  * labels into this node's labels, scaled by that branch's weight.
  */
 function forward() {
  lib.supermatrix.sideways.multiply(input, layerOP, layerOut);
  if (layerOut[1]) lib.matrix.relu.inplace(layerOut[1]); // activation for the hidden output
  lib.matrix.softmax(layerOut[0]); // labels and branch weights
  if (!branches) return;

  for (const branch of branches) {
   branch.forward();
  }

  const own = layerOut[0];
  for (let row = 0; row < batch_size; ++row) {
   // Cells are addressed as row * width + column. Both the parent and the
   // child must be offset by the row, each with its own width; without the
   // offset every batch row would be accumulated into row 0.
   const ownRow = row * own.width;
   for (let j = 0; j < branches.length; ++j) {
    const child = branches[j].layerOut[0];
    const childRow = row * child.width;
    const weight = own[ownRow + label_count + j]; // columns past label_count are branch weights
    for (let k = 0; k < label_count; ++k) {
     own[ownRow + k] += weight * child[childRow + k];
    }
   }
  }
 }

 return {input, layerOP, layerOut, forward, isLeaf, branches, labels: layerOut[0], hiddenData: layerOut[1]};
}
Comment preview