
Commit 0725e90

Merge pull request #964 from JuliaAI/dev
For a 1.2.0 release
2 parents 30687fb + b9e6ac1

File tree: 10 files changed, +116 −47 lines

Project.toml

+1 −1

@@ -1,7 +1,7 @@
 name = "MLJBase"
 uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 authors = ["Anthony D. Blaom <[email protected]>"]
-version = "1.1.2"
+version = "1.2.0"

 [deps]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/composition/learning_networks/nodes.jl

+13 −7

@@ -150,13 +150,19 @@ function _apply(y_plus, input...; kwargs...)
     try
         (y.operation)(mach..., raw_args...)
     catch exception
-        @error "Failed "*
-            "to apply the operation `$(y.operation)` to the machine "*
-            "$(y.machine), which receives it's data arguments from one or more "*
-            "nodes in a learning network. Possibly, one of these nodes "*
-            "is delivering data that is incompatible with the machine's model.\n"*
-            diagnostics(y, input...; kwargs...)
-        throw(exception)
+        diagnostics = MLJBase.diagnostics(y, input...; kwargs...) # defined in sources.jl
+        if !isempty(mach)
+            @error "Failed "*
+                "to apply the operation `$(y.operation)` to the machine "*
+                "$(y.machine), which receives it's data arguments from one or more "*
+                "nodes in a learning network. Possibly, one of these nodes "*
+                "is delivering data that is incompatible "*
+                "with the machine's model.\n"*diagnostics
+        else
+            @error "Failed "*
+                "to apply the operation `$(y.operation)`."*diagnostics
+        end
+        rethrow(exception)
     end
 end

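The new catch block computes the diagnostics string once, tailors the message to whether the failing node actually has a machine attached, and switches from `throw` to `rethrow`, which preserves the exception currently being handled together with its backtrace. A minimal standalone sketch of the log-then-rethrow pattern, with an illustrative function that is not MLJBase code:

# Illustrative sketch only: log context for a failing call, then rethrow so
# the original exception and backtrace are preserved.
function apply_or_explain(f, args...)
    try
        f(args...)
    catch exception
        @error "Failed to apply `$f` to arguments of type $(typeof.(args))."
        rethrow()  # rethrows the exception currently being handled
    end
end

apply_or_explain(+, 1, "one")  # logs the @error message, then surfaces the MethodError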
src/composition/models/pipelines.jl

+26 −10

@@ -402,15 +402,13 @@ end

 # ## Methods to extend a pipeline learning network

-# The "front" of a pipeline network, as we grow it, consists of a
-# "predict" and a "transform" node. Once the pipeline is complete
-# (after a series of `extend` operations - see below) the "transform"
-# node is what is used to deliver the output of `transform(pipe)` in
-# the exported model, and the "predict" node is what will be used to
-# deliver the output of `predict(pipe). Both nodes can be changed by
-# `extend` but only the "active" node is propagated. Initially
-# "transform" is active and "predict" only becomes active when a
-# supervised model is encountered; this change is permanent.
+# The "front" of a pipeline network, as we grow it, consists of a "predict" and a
+# "transform" node. Once the pipeline is complete (after a series of `extend` operations -
+# see below) the "transform" node is what is used to deliver the output of
+# `transform(pipe, ...)` in the exported model, and the "predict" node is what will be
+# used to deliver the output of `predict(pipe, ...). Both nodes can be changed by `extend`
+# but only the "active" node is propagated. Initially "transform" is active and "predict"
+# only becomes active when a supervised model is encountered; this change is permanent.
 # https://github.com/JuliaAI/MLJClusteringInterface.jl/issues/10

 abstract type ActiveNodeOperation end

@@ -587,7 +585,10 @@ end
 # component, only its `abstract_type`. See comment at top of page.

 MMI.supports_training_losses(pipe::SupervisedPipeline) =
-    MMI.supports_training_losses(getproperty(pipe, supervised_component_name(pipe)))
+    MMI.supports_training_losses(supervised_component(pipe))
+
+MMI.reports_feature_importances(pipe::SupervisedPipeline) =
+    MMI.reports_feature_importances(supervised_component(pipe))

 # This trait cannot be defined at the level of types (see previous comment):
 function MMI.iteration_parameter(pipe::SupervisedPipeline)

@@ -618,3 +619,18 @@ function MMI.training_losses(pipe::SupervisedPipeline, pipe_report)
     report = getproperty(pipe_report, supervised_name)
     return training_losses(supervised, report)
 end
+
+
+# ## Feature importances
+
+function feature_importances(pipe::SupervisedPipeline, fitresult, report)
+    # locate the machine associated with the supervised component:
+    supervised_name = MLJBase.supervised_component_name(pipe)
+    predict_node = fitresult.interface.predict
+    mach = only(MLJBase.machines_given_model(predict_node)[supervised_name])
+
+    # To extract the feature_importances, we can't do `feature_importances(mach)` because
+    # `mach.model` is just a symbol; instead we do:
+    supervised = MLJBase.supervised_component(pipe)
+    return feature_importances(supervised, mach.fitresult, mach.report[:fit])
+end

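With the trait and `feature_importances` methods above, importances reported by a pipeline's supervised component surface at the pipeline level. A hedged usage sketch — it assumes a tree model that reports importances (e.g. from MLJDecisionTreeInterface), and the reported scores are illustrative:

using MLJ  # assumes MLJ and MLJDecisionTreeInterface are installed
Tree = @load DecisionTreeClassifier pkg=DecisionTree
pipe = Standardizer() |> Tree()
X, y = @load_iris

@assert reports_feature_importances(pipe)  # trait forwarded from the supervised component

mach = machine(pipe, X, y)
fit!(mach)
feature_importances(mach)  # Pair{Symbol,<:Real} per feature, e.g. :petal_width => 0.5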
src/composition/models/transformed_target_model.jl

+20 −5

@@ -237,28 +237,41 @@ const ERR_TT_MISSING_REPORT =
     "Cannot find report for `TransformedTargetModel` atomic model, from which "*
     "to extract training losses. "

-function training_losses(composite::SomeTT, tt_report)
+function MMI.training_losses(composite::SomeTT, tt_report)
     hasproperty(tt_report, :model) || throw(ERR_TT_MISSING_REPORT)
     atomic_report = getproperty(tt_report, :model)
     return training_losses(composite.model, atomic_report)
 end


+# # FEATURE IMPORTANCES
+
+function MMI.feature_importances(composite::SomeTT, fitresult, report)
+    # locate the machine associated with the supervised component:
+    predict_node = fitresult.interface.predict
+    mach = only(MLJBase.machines_given_model(predict_node)[:model])
+
+    # To extract the feature_importances, we can't do `feature_importances(mach)` because
+    # `mach.model` is just a symbol; instead we do:
+    return feature_importances(composite.model, mach.fitresult, mach.report[:fit])
+end
+
+
 ## MODEL TRAITS

 MMI.package_name(::Type{<:SomeTT}) = "MLJBase"
 MMI.package_license(::Type{<:SomeTT}) = "MIT"
 MMI.package_uuid(::Type{<:SomeTT}) = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 MMI.is_wrapper(::Type{<:SomeTT}) = true
-MMI.package_url(::Type{<:SomeTT}) =
-    "https://github.com/JuliaAI/MLJBase.jl"
+MMI.package_url(::Type{<:SomeTT}) = "https://github.com/JuliaAI/MLJBase.jl"

 for New in TT_TYPE_EXS
     quote
         MMI.iteration_parameter(::Type{<:$New{M}}) where M =
             MLJBase.prepend(:model, iteration_parameter(M))
     end |> eval
-    for trait in [:input_scitype,
+    for trait in [
+        :input_scitype,
         :output_scitype,
         :target_scitype,
         :fit_data_scitype,

@@ -270,8 +283,10 @@ for New in TT_TYPE_EXS
         :supports_class_weights,
         :supports_online,
         :supports_training_losses,
+        :reports_feature_importances,
         :is_supervised,
-        :prediction_type]
+        :prediction_type
+        ]
     quote
         MMI.$trait(::Type{<:$New{M}}) where M = MMI.$trait(M)
     end |> eval

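Mirroring the pipeline case, the wrapper now forwards the atomic model's importances, and `:reports_feature_importances` joins the forwarded traits. A usage sketch modeled on the new test further below (identity transforms; assumes a tree model that reports importances is available):

using MLJ  # illustrative; assumes MLJDecisionTreeInterface is installed
Tree = @load DecisionTreeClassifier pkg=DecisionTree
model = TransformedTargetModel(Tree(), transformer=identity, inverse=identity)
X, y = @load_iris

@assert reports_feature_importances(model)  # trait forwarded from the atomic model

mach = machine(model, X, y)
fit!(mach)
feature_importances(mach)  # importances of the wrapped model, keyed on feature names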
src/machines.jl

−8

@@ -808,10 +808,6 @@ julia> fitted_params(mach).logistic_classifier
  intercept = 0.0883301599726305,)
 ```

-Additional keys, `machines` and `fitted_params_given_machine`, give a
-list of *all* machines in the underlying network, and a dictionary of
-fitted parameters keyed on those machines.
-
 See also [`report`](@ref)

 """

@@ -852,10 +848,6 @@ julia> report(mach).linear_binary_classifier

 ```

-Additional keys, `machines` and `report_given_machine`, give a
-list of *all* machines in the underlying network, and a dictionary of
-reports keyed on those machines.
-
 See also [`fitted_params`](@ref)

 """

src/sources.jl

+16 −9

@@ -85,10 +85,21 @@ function diagnostics(X::AbstractNode, input...; kwargs...)
     _sources = sources(X)
     scitypes = scitype.(raw_args)
     mach = X.machine
-    model = mach.model
-    _input = input_scitype(model)
-    _target = target_scitype(model)
-    _output = output_scitype(model)
+
+    table0 = if !isnothing(mach)
+        model = mach.model
+        _input = input_scitype(model)
+        _target = target_scitype(model)
+        _output = output_scitype(model)
+        """
+        Model ($model):
+        input_scitype = $_input
+        target_scitype =$_target
+        output_scitype =$_output
+        """
+    else
+        ""
+    end

     table1 = "Incoming data:\n"*
         "arg of $(X.operation)\tscitype\n"*

@@ -97,11 +108,7 @@ function diagnostics(X::AbstractNode, input...; kwargs...)

     table2 = diagnostic_table_sources(X)
     return """
-    Model ($model):
-    input_scitype = $_input
-    target_scitype =$_target
-    output_scitype =$_output
-
+    $table0
     $table1
     $table2"""
 end

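The rewrite leans on `if` blocks being expressions in Julia: `table0` is bound to the triple-quoted model summary when the node has a machine, and to the empty string otherwise, so the final `return` needs no branching. A standalone illustration of the idiom (names are illustrative, not MLJBase code):

# Illustrative only: an `if` block is an expression, so its value can be bound directly.
mach = nothing  # stand-in for a node with no attached machine
table0 = if !isnothing(mach)
    "Model summary: $(mach)"
else
    ""  # empty string keeps the interpolation below branch-free
end
println("""
$table0
Incoming data: ...
""")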
test/_models/DecisionTree.jl

+17 −3

@@ -2,7 +2,7 @@ export DecisionTreeClassifier, DecisionTreeRegressor

 import MLJBase
 import MLJBase: @mlj_model, metadata_pkg, metadata_model
-
+import MLJBase.Tables
 using ScientificTypes

 using CategoricalArrays

@@ -98,8 +98,11 @@ function MLJBase.fit(model::DecisionTreeClassifier, verbosity::Int, X, y)
 #> empty values):

     cache = nothing
-    report = (classes_seen=classes_seen,
-              print_tree=TreePrinter(tree))
+    report = (
+        classes_seen=classes_seen,
+        print_tree=TreePrinter(tree),
+        features=Tables.columnnames(Tables.columns(X)) |> collect,
+    )

     return fitresult, cache, report
 end

@@ -137,6 +140,17 @@ function MLJBase.predict(model::DecisionTreeClassifier
             for i in 1:size(y_probabilities, 1)]
 end

+MLJBase.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true
+
+function MMI.feature_importances(m::DecisionTreeClassifier, fitresult, report)
+    features = report.features
+    fi = DecisionTree.impurity_importance(first(fitresult), normalize=true)
+    fi_pairs = Pair.(features, fi)
+    # sort descending
+    sort!(fi_pairs, by= x->-x[2])
+
+    return fi_pairs
+end

 ## REGRESSOR

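The test model records the training feature names in its report and returns importances as `feature => score` pairs sorted by descending score; `sort!(fi_pairs, by = x -> -x[2])` orders the same way as `sort!(fi_pairs, by = last, rev = true)`. A toy run of that sort, with made-up scores for illustration only:

# Made-up scores, illustrating the return format of feature_importances above:
fi_pairs = [:sepal_width => 0.02, :petal_length => 0.55, :petal_width => 0.43]
sort!(fi_pairs, by = last, rev = true)  # same ordering as `by = x -> -x[2]`
# fi_pairs == [:petal_length => 0.55, :petal_width => 0.43, :sepal_width => 0.02]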
test/composition/models/network_composite.jl

+4 −4

@@ -48,7 +48,7 @@ end
 MLJBase.reporting_operations(::Type{<:ReportingScaler}) = (:transform, )

 MLJBase.transform(model::ReportingScaler, _, X) = (
-    model.alpha*Tables.matrix(X),
+    Tables.table(model.alpha*Tables.matrix(X)),
     (; nrows = size(MLJBase.matrix(X))[1]),
 )

@@ -143,7 +143,7 @@ composite = WatermelonComposite(
     Set([:scaler, :clusterer, :classifier1, :training_loss, :len])
 @test fitr.scaler == (nrows=10,)
 @test fitr.clusterer == (labels=['A', 'B', 'C'],)
-@test Set(keys(fitr.classifier1)) == Set([:classes_seen, :print_tree])
+@test Set(keys(fitr.classifier1)) == Set([:classes_seen, :print_tree, :features])
 @test fitr.training_loss isa Real
 @test fitr.len == 10

@@ -164,7 +164,7 @@ composite = WatermelonComposite(
     Set([:scaler, :clusterer, :classifier1, :finalizer])
 @test predictr.scaler == (nrows=5,)
 @test predictr.clusterer == (labels=['A', 'B', 'C'],)
-@test Set(keys(predictr.classifier1)) == Set([:classes_seen, :print_tree])
+@test Set(keys(predictr.classifier1)) == Set([:classes_seen, :print_tree, :features])
 @test predictr.finalizer == (nrows=5,)

 o, predictr = predict(composite, f, selectrows(X, 1:2))

@@ -174,7 +174,7 @@ composite = WatermelonComposite(
     Set([:scaler, :clusterer, :classifier1, :finalizer])
 @test predictr.scaler == (nrows=2,) # <----------- different
 @test predictr.clusterer == (labels=['A', 'B', 'C'],)
-@test Set(keys(predictr.classifier1)) == Set([:classes_seen, :print_tree])
+@test Set(keys(predictr.classifier1)) == Set([:classes_seen, :print_tree, :features])
 @test predictr.finalizer == (nrows=2,) # <---------- different

 r = MMI.report(composite, Dict(:fit => fitr, :predict=> predictr))

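The `ReportingScaler` fix wraps the scaled matrix with `Tables.table`, since downstream machines expect tabular data rather than a raw `Matrix`. A small sketch of the difference (the column names shown are the `Tables.table` defaults):

using Tables

X = (x1 = [1.0, 2.0], x2 = [3.0, 4.0])  # a column table
A = 2.0 * Tables.matrix(X)              # raw Matrix: not a Tables.jl table
Xt = Tables.table(A)                    # table with default columns :Column1, :Column2

Tables.istable(A)   # false
Tables.istable(Xt)  # true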
test/composition/models/pipelines.jl

+10

@@ -690,6 +690,16 @@ end
     rm(filename)
 end

+@testset "feature importances" begin
+    # the DecisionTreeClassifier in /test/_models/ supports feature importances.
+    pipe = Standardizer |> DecisionTreeClassifier()
+    @test reports_feature_importances(pipe)
+    X, y = @load_iris
+    fitresult, _, report = MLJBase.fit(pipe, 0, X, y)
+    features = first.(feature_importances(pipe, fitresult, report))
+    @test Set(features) == Set(keys(X))
+end
+
 end # module

 true

test/composition/models/transformed_target_model.jl

+9

@@ -177,5 +177,14 @@ y = rand(5)
     @test training_losses(mach) == ones(5)
 end

+@testset "feature_importances" begin
+    X, y = @load_iris
+    atom = DecisionTreeClassifier()
+    model = TransformedTargetModel(atom, transformer=identity, inverse=identity)
+    @test reports_feature_importances(model)
+    fitresult, _, rpt = MMI.fit(model, 0, X, y)
+    @test Set(first.(feature_importances(model, fitresult, rpt))) == Set(keys(X))
+end
+
 end
 true
