TensorRT 7.2.1.6
NVIDIA TensorRT
train Namespace Reference

Functions

def parse_args (parser)
 
def reduce_tensor (tensor, num_gpus)
 
def init_distributed (args, world_size, rank, group_name)
 
def save_checkpoint (model, optimizer, epoch, config, amp_run, output_dir, model_name, local_rank, world_size)
 
def get_last_checkpoint_filename (output_dir, model_name)
 
def load_checkpoint (model, optimizer, epoch, config, amp_run, filepath, local_rank)
 
def evaluating (model)
 
def validate (model, criterion, valset, epoch, batch_iter, batch_size, world_size, collate_fn, distributed_run, rank, batch_to_gpu)
 
def adjust_learning_rate (iteration, epoch, optimizer, learning_rate, anneal_steps, anneal_factor, rank)
 
def main ()
 

Function Documentation

◆ parse_args()

def train.parse_args (parser)
Parse command-line arguments.
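
The signature suggests that the caller builds an `argparse.ArgumentParser` and passes it in to have options registered on it. A minimal sketch under that assumption follows; the flag names and defaults are hypothetical and are not taken from this module.

```python
import argparse


def parse_args(parser):
    """Register training options on an existing parser and return it."""
    # Hypothetical flags, chosen only to illustrate the pattern.
    parser.add_argument("-o", "--output", type=str, default="./output",
                        help="Directory for checkpoints")
    parser.add_argument("--epochs", type=int, default=10,
                        help="Number of training epochs")
    parser.add_argument("--learning-rate", type=float, default=1e-3,
                        help="Initial learning rate")
    return parser


parser = argparse.ArgumentParser(description="Training sketch")
parser = parse_args(parser)
# An explicit argument list keeps the example independent of sys.argv.
args = parser.parse_args(["--epochs", "3"])
print(args.epochs, args.learning_rate, args.output)
```

Passing the parser in, rather than creating it inside the function, lets several modules contribute their own options to one command line.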

◆ reduce_tensor()

def train.reduce_tensor (tensor, num_gpus)

◆ init_distributed()

def train.init_distributed (args, world_size, rank, group_name)

◆ save_checkpoint()

def train.save_checkpoint (model, optimizer, epoch, config, amp_run, output_dir, model_name, local_rank, world_size)

◆ get_last_checkpoint_filename()

def train.get_last_checkpoint_filename (output_dir, model_name)

◆ load_checkpoint()

def train.load_checkpoint (model, optimizer, epoch, config, amp_run, filepath, local_rank)

◆ evaluating()

def train.evaluating (model)
Temporarily switch to evaluation mode.
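
"Temporarily" implies that the previous mode is restored afterwards, which is the usual job of a context manager. The sketch below assumes the model exposes the `training` flag and the `train()` / `eval()` methods of `torch.nn.Module`; `ToyModel` is a hypothetical stand-in so the example runs without PyTorch, and the body is an illustration rather than this module's actual code.

```python
from contextlib import contextmanager


@contextmanager
def evaluating(model):
    """Switch the model to evaluation mode for the duration of a with-block."""
    was_training = model.training
    try:
        model.eval()
        yield model
    finally:
        # Restore training mode only if the model was in it before.
        if was_training:
            model.train()


class ToyModel:
    """Hypothetical stand-in with the mode-switching interface of torch.nn.Module."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)


model = ToyModel()
with evaluating(model):
    inside = model.training   # evaluation mode inside the block
after = model.training        # previous mode restored on exit
print(inside, after)
```

The `try` / `finally` pair ensures the mode is restored even if validation raises an exception.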

◆ validate()

def train.validate (model, criterion, valset, epoch, batch_iter, batch_size, world_size, collate_fn, distributed_run, rank, batch_to_gpu)
Handles all validation scoring and printing.

◆ adjust_learning_rate()

def train.adjust_learning_rate (iteration, epoch, optimizer, learning_rate, anneal_steps, anneal_factor, rank)
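
This page gives no description for this function. Judging only from the parameter names, one plausible reading is step-wise annealing: each time `epoch` passes an entry of `anneal_steps`, the base `learning_rate` is multiplied by `anneal_factor`. The sketch below implements that reading as an assumption. It ignores `iteration` and `rank`, returns the new rate for convenience, and uses a hypothetical `ToyOptimizer` that only mimics the `param_groups` attribute of a PyTorch optimizer.

```python
def adjust_learning_rate(iteration, epoch, optimizer, learning_rate,
                         anneal_steps, anneal_factor, rank):
    """Assumed behaviour: scale the rate once per anneal step already reached."""
    # iteration and rank are unused here; kept only to mirror the signature.
    passed = 0
    if anneal_steps is not None:
        passed = sum(1 for step in anneal_steps if epoch >= int(step))
    lr = learning_rate * (anneal_factor ** passed)
    # Apply the same rate to every parameter group.
    for group in optimizer.param_groups:
        group["lr"] = lr
    return lr


class ToyOptimizer:
    """Hypothetical stand-in exposing param_groups like torch.optim optimizers."""

    def __init__(self):
        self.param_groups = [{"lr": 0.0}]


opt = ToyOptimizer()
# Epoch 25 has passed the steps at 10 and 20, so the factor applies twice.
lr = adjust_learning_rate(0, 25, opt, 1.0, [10, 20, 30], 0.5, rank=0)
print(lr)
```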

◆ main()

def train.main ()