Line data Source code
1 : #include "triggeralgs/ProtoDUNEBSMWindow/CompiledModelInterface.hpp"
2 :
3 : #include <iostream>
4 :
5 : namespace triggeralgs {
6 :
7 0 : CompiledModelInterface::CompiledModelInterface(int nbatch) : num_batch(nbatch) {
8 0 : model_ptr = std::make_unique<TreelitePDHDModel>();
9 0 : }
10 :
11 0 : CompiledModelInterface::~CompiledModelInterface() {}
12 :
13 0 : int CompiledModelInterface::GetNumFeatures() {
14 0 : return model_ptr->get_num_feature();
15 : }
16 :
17 0 : void CompiledModelInterface::ModelWarmUp(Entry *input) {
18 : // Warm the BDT up here
19 0 : float result[num_batch];
20 0 : for (int rid = 0; rid < num_batch; ++rid) {
21 0 : for (int i = 0; i < 100; i++) {
22 0 : model_ptr->predict(input, 0, result);
23 : }
24 : }
25 0 : }
26 :
27 0 : void CompiledModelInterface::Predict(Entry *input, float *result) {
28 0 : for (int rid = 0; rid < num_batch; ++rid) {
29 0 : model_ptr->predict(input, 0, result);
30 : }
31 0 : }
32 :
33 0 : bool CompiledModelInterface::Classify(const float *result, float &bdt_threshold) {
34 0 : for (uint64_t rid = 0; rid < num_batch; rid++) {
35 0 : if (result[rid] > bdt_threshold) {
36 : return true;
37 : }
38 : }
39 : return false;
40 : }
41 :
42 :
43 : } // namespace triggeralgs
|