Extracts metrics from a fitted table

ml_metrics_regression

Description

The function works best when passed a tbl_spark created by ml_predict(). A tbl_spark produced that way already has the variable types and format that the corresponding Spark model “evaluator” expects.

Usage

ml_metrics_regression(
  x,
  truth,
  estimate = prediction,
  metrics = c("rmse", "rsq", "mae"),
  ...
)

Arguments

Arguments Description
x A tbl_spark containing the estimate (prediction) and the truth (value of what actually happened)
truth The name of the column from x that contains the value of what actually happened
estimate The name of the column from x that contains the prediction. Defaults to prediction, since that is the column name ml_predict() uses by default.
metrics A character vector with the metrics to calculate. For regression models the possible values are: rmse (Root mean squared error), mse (Mean squared error), rsq (R squared), mae (Mean absolute error), and var (Explained variance). Defaults to: rmse, rsq, mae
... Optional arguments; currently unused.
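
As a rough illustration of what three of the metric names above mean, the sketch below computes mse, rmse, and mae by hand in plain R using their standard textbook definitions. It is not Spark's implementation; the vectors truth_vals and pred_vals are hypothetical, and rsq and var are left out because their exact definitions are determined by Spark's evaluator.

```r
# Hypothetical observed values and predictions (not taken from a Spark table)
truth_vals <- c(3, 5, 7, 9)
pred_vals  <- c(2, 5, 8, 11)

# Per-row prediction errors
errors <- truth_vals - pred_vals

mse_value  <- mean(errors^2)     # mse: mean squared error
rmse_value <- sqrt(mse_value)    # rmse: root mean squared error
mae_value  <- mean(abs(errors))  # mae: mean absolute error

print(c(mse = mse_value, rmse = rmse_value, mae = mae_value))
```

In ml_metrics_regression() the same quantities are requested by name, for example metrics = c("mse", "rmse", "mae"), and Spark performs the calculation.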

Details

The ml_metrics family of functions implements Spark’s evaluation in a way that is closer to how the yardstick package works. Each function expects a table containing the truth and estimate columns, and returns a tibble with the results. The tibble has the same format and variable names as the output of the yardstick functions.
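
To show the output format being mirrored, here is a minimal local sketch that uses yardstick itself rather than Spark. It assumes the yardstick and tibble packages are installed; the toy data and the names toy, reg_metrics, and local_metrics are hypothetical and only serve the illustration.

```r
library(yardstick)
library(tibble)

# Hypothetical toy data: observed values and model predictions
toy <- tibble(
  truth = c(5.1, 4.9, 6.3, 5.8),
  prediction = c(5.0, 5.1, 6.0, 5.9)
)

# metric_set() bundles several yardstick metrics into a single function
reg_metrics <- metric_set(rmse, rsq, mae)

# The result is a tibble with one row per metric and the columns
# .metric, .estimator, and .estimate
local_metrics <- reg_metrics(toy, truth = truth, estimate = prediction)
print(local_metrics)
```

According to the description above, ml_metrics_regression() returns a tibble with these same column names, so code written against yardstick output should be able to consume it with little change.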

Examples



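library(sparklyr)
library(dplyr)  # for copy_to() and %>%

# Connect to a local Spark instance and copy iris into it
sc <- spark_connect("local")
tbl_iris <- copy_to(sc, iris)

# Split 50/50 and fit a regression model on the training half
iris_split <- sdf_random_split(tbl_iris, training = 0.5, test = 0.5)
training <- iris_split$training
reg_formula <- "Sepal_Length ~ Sepal_Width + Petal_Length + Petal_Width"
model <- ml_generalized_linear_regression(training, reg_formula)

# Score the test half, then compute the default metrics (rmse, rsq, mae)
tbl_predictions <- ml_predict(model, iris_split$test)
tbl_predictions %>%
  ml_metrics_regression(Sepal_Length)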