LCOV - code coverage report
Current view: top level - triggeralgs/src/ProtoDUNEBSMWindow - CompiledModelInterface.cpp (source / functions) Coverage Total Hit
Test: code.result Lines: 0.0 % 19 0
Test Date: 2026-03-29 15:29:34 Functions: 0.0 % 6 0

            Line data    Source code
#include "triggeralgs/ProtoDUNEBSMWindow/CompiledModelInterface.hpp"

#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>
       4              : 
       5              : namespace triggeralgs {
       6              : 
       7            0 : CompiledModelInterface::CompiledModelInterface(int nbatch) : num_batch(nbatch) {
       8            0 :   model_ptr = std::make_unique<TreelitePDHDModel>();
       9            0 : }
      10              : 
      11            0 : CompiledModelInterface::~CompiledModelInterface() {}
      12              :     
      13            0 : int CompiledModelInterface::GetNumFeatures() {
      14            0 :   return model_ptr->get_num_feature();
      15              : }
      16              : 
      17            0 : void CompiledModelInterface::ModelWarmUp(Entry *input) {
      18              :   // Warm the BDT up here
      19            0 :   float result[num_batch];
      20            0 :   for (int rid = 0; rid < num_batch; ++rid) {
      21            0 :     for (int i = 0; i < 100; i++) {
      22            0 :       model_ptr->predict(input, 0, result);
      23              :     }
      24              :   }
      25            0 : }
      26              : 
      27            0 : void CompiledModelInterface::Predict(Entry *input, float *result) {
      28            0 :   for (int rid = 0; rid < num_batch; ++rid) {
      29            0 :     model_ptr->predict(input, 0, result);
      30              :   }
      31            0 : }
      32              : 
      33            0 : bool CompiledModelInterface::Classify(const float *result, float &bdt_threshold) {
      34            0 :   for (uint64_t rid = 0; rid < num_batch; rid++) {
      35            0 :     if (result[rid] > bdt_threshold) {
      36              :       return true;
      37              :     }
      38              :   }
      39              :   return false;
      40              : }
      41              : 
      42              : 
      43              : } // namespace triggeralgs
        

Generated by: LCOV version 2.0-1