DUNE-DAQ
DUNE Trigger and Data Acquisition software
Loading...
Searching...
No Matches
CompiledModelInterface.cpp
Go to the documentation of this file.
2
3#include <iostream>
4
5namespace triggeralgs {
6
7CompiledModelInterface::CompiledModelInterface(int nbatch) : num_batch(nbatch) {
8 model_ptr = std::make_unique<TreelitePDHDModel>();
9}
10
12
14 return model_ptr->get_num_feature();
15}
16
18 // Warm the BDT up here
19 float result[num_batch];
20 for (int rid = 0; rid < num_batch; ++rid) {
21 for (int i = 0; i < 100; i++) {
22 model_ptr->predict(input, 0, result);
23 }
24 }
25}
26
27void CompiledModelInterface::Predict(Entry *input, float *result) {
28 for (int rid = 0; rid < num_batch; ++rid) {
29 model_ptr->predict(input, 0, result);
30 }
31}
32
33bool CompiledModelInterface::Classify(const float *result, float &bdt_threshold) {
34 for (uint64_t rid = 0; rid < num_batch; rid++) {
35 if (result[rid] > bdt_threshold) {
36 return true;
37 }
38 }
39 return false;
40}
41
42
43} // namespace triggeralgs
bool Classify(const float *result, float &bdt_threshold)
void Predict(Entry *input, float *result)
std::unique_ptr< TreeliteModelBase > model_ptr