Line data Source code
1 : #include "triggeralgs/ProtoDUNEBSMWindow/treelitemodel.hpp"
2 :
3 : namespace triggeralgs {
4 :
5 : // Implement functions of TreeliteModelBase base class
6 : // that are common to all compiled GBDT models
7 :
8 : const int32_t TreeliteModelBase::num_class[1] = { 1 };
9 :
10 0 : TreeliteModelBase::TreeliteModelBase(int numTargets, int maxNumClass)
11 0 : : N_TARGET(numTargets), MAX_N_CLASS(maxNumClass) {}
12 :
13 0 : int32_t TreeliteModelBase::get_num_target(void) const {
14 0 : return N_TARGET;
15 : }
16 :
17 0 : void TreeliteModelBase::get_num_class(int32_t* out) const {
18 0 : for (int i = 0; i < N_TARGET; ++i) {
19 0 : out[i] = TreeliteModelBase::num_class[i];
20 : }
21 0 : }
22 :
23 0 : const char* TreeliteModelBase::get_threshold_type(void) const {
24 0 : return "float32";
25 : }
26 0 : const char* TreeliteModelBase::get_leaf_output_type(void) const {
27 0 : return "float32";
28 : }
29 :
30 0 : void TreeliteModelBase::postprocess(float* result) const {
31 : // sigmoid
32 0 : const float alpha = (float)1;
33 0 : for (size_t i = 0; i < N_TARGET * MAX_N_CLASS; ++i) {
34 0 : result[i] = (float)(1) / ((float)(1) + expf(-alpha * result[i]));
35 : }
36 0 : }
37 :
38 : } // end namespace triggeralgs
|