DUNE-DAQ
DUNE Trigger and Data Acquisition software
Loading...
Searching...
No Matches
triggeralgs::CompiledModelInterface Class Reference

#include <CompiledModelInterface.hpp>

Public Member Functions

 CompiledModelInterface (int nbatch)
 
 ~CompiledModelInterface ()
 
int GetNumFeatures ()
 
void ModelWarmUp (Entry *input)
 
void Predict (Entry *input, float *result)
 
bool Classify (const float *result, float &bdt_threshold)
 

Protected Attributes

std::unique_ptr< TreeliteModelBase > model_ptr
 
int num_batch
 

Detailed Description

Definition at line 17 of file CompiledModelInterface.hpp.

Constructor & Destructor Documentation

◆ CompiledModelInterface()

triggeralgs::CompiledModelInterface::CompiledModelInterface ( int nbatch)

Definition at line 7 of file CompiledModelInterface.cpp.

7 : num_batch(nbatch) {
8 model_ptr = std::make_unique<TreelitePDHDModel>();
9}
std::unique_ptr< TreeliteModelBase > model_ptr

◆ ~CompiledModelInterface()

triggeralgs::CompiledModelInterface::~CompiledModelInterface ( )

Definition at line 11 of file CompiledModelInterface.cpp.

11{}

Member Function Documentation

◆ Classify()

bool triggeralgs::CompiledModelInterface::Classify ( const float * result,
float & bdt_threshold )

Definition at line 33 of file CompiledModelInterface.cpp.

33 {
34 for (uint64_t rid = 0; rid < num_batch; rid++) {
35 if (result[rid] > bdt_threshold) {
36 return true;
37 }
38 }
39 return false;
40}

◆ GetNumFeatures()

int triggeralgs::CompiledModelInterface::GetNumFeatures ( )

Definition at line 13 of file CompiledModelInterface.cpp.

13 {
14 return model_ptr->get_num_feature();
15}

◆ ModelWarmUp()

void triggeralgs::CompiledModelInterface::ModelWarmUp ( Entry * input)

Definition at line 17 of file CompiledModelInterface.cpp.

17 {
18 // Warm the BDT up here
19 float result[num_batch];
20 for (int rid = 0; rid < num_batch; ++rid) {
21 for (int i = 0; i < 100; i++) {
22 model_ptr->predict(input, 0, result);
23 }
24 }
25}

◆ Predict()

void triggeralgs::CompiledModelInterface::Predict ( Entry * input,
float * result )

Definition at line 27 of file CompiledModelInterface.cpp.

27 {
28 for (int rid = 0; rid < num_batch; ++rid) {
29 model_ptr->predict(input, 0, result);
30 }
31}

Member Data Documentation

◆ model_ptr

std::unique_ptr<TreeliteModelBase> triggeralgs::CompiledModelInterface::model_ptr
protected

Definition at line 37 of file CompiledModelInterface.hpp.

◆ num_batch

int triggeralgs::CompiledModelInterface::num_batch
protected

Definition at line 38 of file CompiledModelInterface.hpp.


The documentation for this class was generated from the following files: