Skip to content

Commit 06cd3bf

Browse files
Authored by Sean Friedowitz
Merge pull request #305 from CitrineInformatics/parallel-iteration-only
Store regular Seq and cast to Par during training/transform only
2 parents 0ae69bd + 534e0c9 commit 06cd3bf

File tree

5 files changed

+16
-20
lines changed

5 files changed

+16
-20
lines changed

src/main/scala/io/citrine/lolo/bags/BaggedModel.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ package io.citrine.lolo.bags
33
import breeze.linalg.DenseMatrix
44
import io.citrine.lolo.api.Model
55

6-
import scala.collection.parallel.immutable.ParSeq
6+
import scala.collection.parallel.CollectionConverters._
77

88
/** A model holding a parallel sequence of models and the sample counts used to train them. */
99
trait BaggedModel[+T] extends Model[T] {
1010

1111
/** Models in the ensemble trained on subsets of the training data. */
12-
def ensembleModels: ParSeq[Model[T]]
12+
def ensembleModels: Seq[Model[T]]
1313

1414
override def transform(inputs: Seq[Vector[Any]]): BaggedPrediction[T]
1515

@@ -31,7 +31,7 @@ trait BaggedModel[+T] extends Model[T] {
3131
}
3232

3333
case class BaggedRegressionModel(
34-
ensembleModels: ParSeq[Model[Double]],
34+
ensembleModels: Seq[Model[Double]],
3535
Nib: Vector[Vector[Int]],
3636
rescaleRatio: Double = 1.0,
3737
disableBootstrap: Boolean = false,
@@ -42,7 +42,7 @@ case class BaggedRegressionModel(
4242
assert(inputs.forall(_.size == inputs.head.size))
4343

4444
val bias = biasModel.map(_.transform(inputs).expected)
45-
val ensemblePredictions = ensembleModels.map(model => model.transform(inputs)).seq
45+
val ensemblePredictions = ensembleModels.par.map(model => model.transform(inputs)).seq
4646

4747
if (inputs.size == 1) {
4848
// In the special case of a single prediction on a real value, emit an optimized prediction class
@@ -65,11 +65,11 @@ case class BaggedRegressionModel(
6565
}
6666
}
6767

68-
case class BaggedClassificationModel[T](ensembleModels: ParSeq[Model[T]]) extends BaggedModel[T] {
68+
case class BaggedClassificationModel[T](ensembleModels: Seq[Model[T]]) extends BaggedModel[T] {
6969

7070
override def transform(inputs: Seq[Vector[Any]]): BaggedClassificationPrediction[T] = {
7171
assert(inputs.forall(_.size == inputs.head.size))
72-
val ensemblePredictions = ensembleModels.map(model => model.transform(inputs)).seq
72+
val ensemblePredictions = ensembleModels.par.map(model => model.transform(inputs)).seq
7373
BaggedClassificationPrediction(ensemblePredictions)
7474
}
7575
}

src/main/scala/io/citrine/lolo/bags/BaggedTrainingResult.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,14 @@ package io.citrine.lolo.bags
33
import io.citrine.lolo.api.{Model, TrainingResult, TrainingRow}
44
import io.citrine.lolo.stats.metrics.{ClassificationMetrics, RegressionMetrics}
55

6-
import scala.collection.parallel.immutable.ParSeq
7-
86
/** The result of training a [[Bagger]] to produce a [[BaggedModel]]. */
97
sealed trait BaggedTrainingResult[+T] extends TrainingResult[T] {
108

119
override def model: BaggedModel[T]
1210
}
1311

1412
case class RegressionBaggerTrainingResult(
15-
ensembleModels: ParSeq[Model[Double]],
13+
ensembleModels: Seq[Model[Double]],
1614
Nib: Vector[Vector[Int]],
1715
trainingData: Seq[TrainingRow[Double]],
1816
override val featureImportance: Option[Vector[Double]],
@@ -50,7 +48,7 @@ case class RegressionBaggerTrainingResult(
5048
}
5149

5250
case class ClassificationBaggerTrainingResult[T](
53-
ensembleModels: ParSeq[Model[T]],
51+
ensembleModels: Seq[Model[T]],
5452
Nib: Vector[Vector[Int]],
5553
trainingData: Seq[TrainingRow[T]],
5654
override val featureImportance: Option[Vector[Double]],

src/main/scala/io/citrine/lolo/bags/Bagger.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import io.citrine.lolo.stats.StatsUtils
77
import io.citrine.random.Random
88

99
import scala.collection.parallel.CollectionConverters._
10-
import scala.collection.parallel.immutable.ParVector
1110

1211
sealed trait Bagger[T] extends Learner[T] {
1312

@@ -66,6 +65,7 @@ sealed trait Bagger[T] extends Learner[T] {
6665
val meta = baseLearner.train(weightedTrainingData, thisRng)
6766
(meta.model, meta.featureImportance)
6867
}
68+
.seq
6969
.unzip
7070

7171
// Average the feature importance
@@ -195,7 +195,7 @@ object Bagger {
195195
* @tparam T type of label data for the models
196196
*/
197197
protected[bags] case class BaggedEnsemble[+T](
198-
models: ParVector[Model[T]],
198+
models: Vector[Model[T]],
199199
Nib: Vector[Vector[Int]],
200200
averageImportance: Option[Vector[Double]]
201201
)

src/main/scala/io/citrine/lolo/bags/BaggerHelper.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package io.citrine.lolo.bags
33
import io.citrine.lolo.stats.{MathUtils, StatsUtils}
44
import io.citrine.lolo.api.{Model, TrainingRow}
55

6-
import scala.collection.parallel.immutable.ParSeq
7-
86
/**
97
* Helper class to subsume shared functionality of [[RegressionBagger]] and [[MultiTaskBagger]].
108
*
@@ -15,7 +13,7 @@ import scala.collection.parallel.immutable.ParSeq
1513
* @param uncertaintyCalibration whether to apply empirical uncertainty calibration
1614
*/
1715
protected[bags] case class BaggerHelper(
18-
models: ParSeq[Model[Double]],
16+
models: Seq[Model[Double]],
1917
trainingData: Seq[TrainingRow[Double]],
2018
Nib: Vector[Vector[Int]],
2119
useJackknife: Boolean,

src/main/scala/io/citrine/lolo/bags/MultiTaskBagger.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import io.citrine.lolo.stats.StatsUtils
1616
import io.citrine.random.Random
1717
import io.citrine.lolo.stats.metrics.{ClassificationMetrics, RegressionMetrics}
1818

19-
import scala.collection.parallel.immutable.ParSeq
2019
import scala.collection.parallel.CollectionConverters._
2120

2221
/**
@@ -88,6 +87,7 @@ case class MultiTaskBagger(
8887
val meta = method.train(weightedTrainingData, thisRng)
8988
(meta.model, meta.featureImportance)
9089
}
90+
.seq
9191
.unzip
9292

9393
val averageImportance: Option[Vector[Double]] = importances
@@ -144,7 +144,7 @@ case class MultiTaskBagger(
144144
* @param rescaleRatios sequence of uncertainty calibration ratios for each label
145145
*/
146146
case class MultiTaskBaggedTrainingResult(
147-
ensembleModels: ParSeq[MultiTaskModel],
147+
ensembleModels: Seq[MultiTaskModel],
148148
Nib: Vector[Vector[Int]],
149149
trainingData: Seq[TrainingRow[Vector[Any]]],
150150
override val featureImportance: Option[Vector[Double]],
@@ -217,7 +217,7 @@ case class MultiTaskBaggedTrainingResult(
217217
val thisLabelModels = ensembleModels.map(_.models(i))
218218
if (isReal) {
219219
BaggedRegressionModel(
220-
thisLabelModels.asInstanceOf[ParSeq[Model[Double]]],
220+
thisLabelModels.asInstanceOf[Seq[Model[Double]]],
221221
Nib = Nib,
222222
rescaleRatio = rescaleRatios(i),
223223
biasModel = biasModels(i)
@@ -238,7 +238,7 @@ case class MultiTaskBaggedTrainingResult(
238238
* @param rescaleRatios sequence of uncertainty calibration ratios for each label
239239
*/
240240
case class MultiTaskBaggedModel(
241-
ensembleModels: ParSeq[MultiTaskModel],
241+
ensembleModels: Seq[MultiTaskModel],
242242
Nib: Vector[Vector[Int]],
243243
biasModels: Seq[Option[Model[Double]]],
244244
rescaleRatios: Seq[Double]
@@ -250,7 +250,7 @@ case class MultiTaskBaggedModel(
250250
val thisLabelsModels = ensembleModels.map(_.models(i))
251251
if (realLabels(i)) {
252252
BaggedRegressionModel(
253-
thisLabelsModels.asInstanceOf[ParSeq[Model[Double]]],
253+
thisLabelsModels.asInstanceOf[Seq[Model[Double]]],
254254
Nib = Nib,
255255
rescaleRatio = rescaleRatios(i),
256256
biasModel = biasModels(i)

0 commit comments

Comments (0)