20 """Auxiliary functions for batch execution of MIRTK commands using SLURM."""    30 def submit(name, command=None, args=[], opts={}, script=None, tasks=0, deps=[],
    31            logdir=None, log=None, queue='long', threads=8, memory=8 * 1024,
    32            workdir=None, verbose=1):
    33     """Submit batch job to SLURM."""    35         raise ValueError(
"Must specify number of threads when executing as SLURM batch job!")
    38             raise ValueError(
"Keyword arguments 'command' and 'script' are mutually exclusive")
    39         shexec = 
"#!/bin/bash\nexec {0} <(cat <<END_OF_SCRIPT\n".format(sys.executable)
    40         shexec += script.format(**opts)
    41         shexec += 
"\nEND_OF_SCRIPT)"    43             shexec += 
" $SLURM_ARRAY_TASK_ID"    47         script = 
"#!/bin/bash\n"    48         script += 
"\"{0}\" -c 'import sys; import socket;".format(sys.executable)
    49         script += 
" sys.stdout.write(\"Host: \" + socket.gethostname() + \"\\n\\n\");"    50         script += 
" sys.path.insert(0, \"{0}\");".format(os.path.dirname(os.path.dirname(__file__)))
    51         script += 
" import mirtk; mirtk.check_call([\"{0}\"] + sys.argv[1:])".format(command 
if command 
else name)
    64                 if isinstance(arg, (list, tuple)):
    65                     arg = 
' '.join([str(x) 
for x 
in arg])
    69         script += 
" -threads {0}".format(threads)
    77         '--mem={0}M'.format(memory)
    80         argv.append(
'--array=0-{}'.format(tasks - 1))
    83             logdir = os.path.dirname(log)
    85             log = os.path.join(logdir, name)
    89         if not os.path.exists(logdir):
    91         argv.extend([
'-o', log, 
'-e', log])
    93         if isinstance(deps, int):
    95         deps = [str(dep) 
for dep 
in deps 
if dep > 0]
    97             argv.append(
'--dependency=afterok:' + 
',afterok:'.join(deps))
    99         argv.append(
'--workdir=' + os.path.abspath(workdir))
   100     proc = subprocess.Popen(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
   101     (out, err) = proc.communicate(input=script.encode(
'utf-8'))
   102     if proc.returncode != 0:
   104     match = re.match(
'Submitted batch job ([0-9]+)', out)
   106         raise Exception(
"Failed to determine job ID from sbatch output:\n" + out)
   107     jobid = int(match.group(1))
   110             print(
"  Submitted batch {} (JobId={}, Tasks={})".format(name, jobid, tasks))
   112             print(
"  Submitted job {} (JobId={})".format(name, jobid))
   117 def wait(jobs, max_time=0, interval=60, verbose=0):
   118     """Wait for SLURM jobs to finish."""   119     if not isinstance(jobs, (list, tuple)):
   121     jobs = [job 
for job 
in jobs 
if job > 0]
   127     re_state = re.compile(
"JobState=([A-Z]+)")
   128     jobs = [str(job) 
for job 
in jobs]
   130     while num_wait > 0 
and (max_time <= 0 
or total_time < max_time):
   132         total_time += interval
   136         table = subprocess.check_output([
"sacct", 
"--brief", 
"--parsable", 
"--jobs={}".format(
','.join(jobs))])
   137         for line 
in table.splitlines():
   138             cols = line.split(
"|")
   140                 if cols[1] == 
"COMPLETED":
   142                 elif cols[1] 
not in (
"PENDING", 
"SUSPENDED", 
"RUNNING"):
   144             elif cols[0].endswith(
".batch"):
   145                 batch.append(cols[0][0:-6])
   151                 info = subprocess.check_output([
"scontrol", 
"show", 
"job", job])
   152                 for line 
in info.splitlines():
   153                     match = re_state.search(line)
   156                         if match.group(1) == 
"COMPLETED":
   158                         elif match.group(1) 
not in (
"PENDING", 
"SUSPENDED", 
"RUNNING"):
   169             except subprocess.CalledProcessError:
   171         num_done += len(done)
   172         num_fail += len(fail)
   173         num_wait = num_jobs - num_done - num_fail
   174         if verbose > 0 
and (num_wait <= 0 
or iterations % verbose == 0):
   175             sys.stdout.write(
"{:%Y-%b-%d %H:%M:%S}".format(datetime.now()))
   176             sys.stdout.write(
" WAIT {} job(s) running/suspended/pending".format(num_wait))
   178                 sys.stdout.write(
", {} failed".format(num_fail))
   179             sys.stdout.write(
"\n")
   182     if num_wait > 0 
and max_time > 0 
and total_time >= max_time:
   183         raise Exception(
"Exceeded maximum time waiting for jobs to complete!")