Source: eclairjs/mllib/tree/model/GradientBoostedTreesModel.js

  1. /**
  2. * Created by billreed on 4/11/16.
  3. *//*
  4. * Copyright 2016 IBM Corp.
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. (function () {
  19. var JavaWrapper = require(EclairJS_Globals.NAMESPACE + '/JavaWrapper');
  20. var Logger = require(EclairJS_Globals.NAMESPACE + '/Logger');
  21. var Utils = require(EclairJS_Globals.NAMESPACE + '/Utils');
  /**
   * Represents a gradient boosted trees model.
   *
   * The constructor accepts either:
   *  - an existing JVM org.apache.spark.mllib.tree.model.GradientBoostedTreesModel as the first
   *    argument, which is wrapped as-is (this is how GradientBoostedTreesModel.load builds instances), or
   *  - the model's component parts (algo, trees, treeWeights), from which a new JVM model is built.
   *
   * @classdesc Gradient boosted trees model for either classification or regression.
   * @param {Algo} algo algorithm for the ensemble model, either Classification or Regression
   *     (or an existing JVM GradientBoostedTreesModel to wrap)
   * @param {module:eclairjs/mllib/tree/model.DecisionTreeModel[]} trees tree ensembles
   * @param {number[]} treeWeights tree ensemble weights
   * @class
   * @memberof module:eclairjs/mllib/tree/model
   */
  var GradientBoostedTreesModel = function (algo, trees, treeWeights) {
    this.logger = Logger.getLogger("GradientBoostedTreesModel_js");
    var jvmObject;
    if (algo instanceof org.apache.spark.mllib.tree.model.GradientBoostedTreesModel) {
      // Already a JVM model (e.g. produced by load): wrap it directly.
      jvmObject = algo;
    } else {
      // Build a new JVM model from JS-side parts; wrapper objects are unwrapped to their JVM peers.
      // treeWeights is handed to the JVM constructor unchanged.
      jvmObject = new org.apache.spark.mllib.tree.model.GradientBoostedTreesModel(
        Utils.unwrapObject(algo),
        Utils.unwrapObject(trees),
        treeWeights
      );
    }
    JavaWrapper.call(this, jvmObject);
  };
  // Inherit JavaWrapper behaviour (e.g. getJavaObject) through the prototype chain.
  GradientBoostedTreesModel.prototype = Object.create(JavaWrapper.prototype);
  GradientBoostedTreesModel.prototype.constructor = GradientBoostedTreesModel;
  /**
   * Writes this model out so it can later be restored with GradientBoostedTreesModel.load.
   *
   * @param {module:eclairjs.SparkContext} sc Spark context used to save model data.
   * @param {string} path Path specifying the directory in which to save this model.
   *     If the directory already exists, this method throws an exception.
   */
  GradientBoostedTreesModel.prototype.save = function (sc, path) {
    var unwrappedContext = Utils.unwrapObject(sc);
    var javaModel = this.getJavaObject();
    // NOTE(review): presumably the unwrapped context is a JavaSparkContext and sc() yields the
    // Scala SparkContext that the JVM save API expects — confirm against Utils.unwrapObject.
    javaModel.save(unwrappedContext.sc(), path);
  };
  /**
   * Method to compute error or loss for every iteration of gradient boosting.
   *
   * Not implemented by EclairJS yet: calling it always throws.
   *
   * @param {module:eclairjs.RDD} data RDD of {@link LabeledPoint}
   * @param {module:eclairjs/mllib/tree/loss.Loss} loss evaluation metric.
   * @returns {number[]} an array with index i having the losses or errors for the ensemble
   *     containing the first i+1 trees
   * @throws {Error} always, because this method is not implemented by EclairJS.
   */
  GradientBoostedTreesModel.prototype.evaluateEachIteration = function (data, loss) {
    // TODO: delegate to this.getJavaObject().evaluateEachIteration(...) with data and loss
    // unwrapped via Utils.unwrapObject, once that path has been verified to work.
    throw new Error("not implemented by EclairJS");
  };
  /**
   * Runs the model's prediction on the given features.
   *
   * @param {object} features input to predict on; it is unwrapped with Utils.unwrapObject before
   *     being passed to the JVM model. NOTE(review): presumably a
   *     module:eclairjs/mllib/linalg.Vector — confirm which input types this wrapper supports.
   * @returns {*} whatever the JVM model's predict returns, passed back without re-wrapping
   *     (a number for a single feature vector — TODO confirm for other inputs).
   */
  GradientBoostedTreesModel.prototype.predict = function (features) {
    var unwrappedFeatures = Utils.unwrapObject(features);
    var prediction = this.getJavaObject().predict(unwrappedFeatures);
    return prediction;
  };
  /**
   * Renders the complete model (every tree) as text.
   *
   * @returns {string} the JVM model's debug string
   */
  GradientBoostedTreesModel.prototype.toDebugString = function () {
    var javaModel = this.getJavaObject();
    return javaModel.toDebugString();
  };
  /**
   * Renders a short summary of the model as text.
   *
   * @returns {string} the JVM model's toString() output
   */
  GradientBoostedTreesModel.prototype.toString = function () {
    var javaModel = this.getJavaObject();
    return javaModel.toString();
  };
  //
  // static methods
  //
  /**
   * Restores a model from a directory previously written by save.
   *
   * @param {module:eclairjs.SparkContext} sc Spark context used for loading model files.
   * @param {string} path Path specifying the directory to which the model was saved.
   * @returns {module:eclairjs/mllib/tree/model.GradientBoostedTreesModel} Model instance
   */
  GradientBoostedTreesModel.load = function (sc, path) {
    var scalaContext = Utils.unwrapObject(sc).sc();
    var loadedJavaModel = org.apache.spark.mllib.tree.model.GradientBoostedTreesModel.load(scalaContext, path);
    // The constructor recognises a JVM model instance and wraps it directly.
    return new GradientBoostedTreesModel(loadedJavaModel);
  };
  module.exports = GradientBoostedTreesModel;
  107. })();