batch/slurm.py
##############################################################################
# Medical Image Registration ToolKit (MIRTK)
#
# Copyright 2017 Imperial College London
# Copyright 2017 Andreas Schuh
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

"""Auxiliary functions for batch execution of MIRTK commands using SLURM."""

import re
import os
import sys
import subprocess
import time
from datetime import datetime


# ----------------------------------------------------------------------------
def submit(name, command=None, args=[], opts={}, script=None, tasks=0, deps=[],
           logdir=None, log=None, queue='long', threads=8, memory=8 * 1024,
           workdir=None, verbose=1):
    """Submit batch job to SLURM."""
    if threads <= 0:
        raise ValueError("Must specify number of threads when executing as SLURM batch job!")
    if script:
        if command:
            raise ValueError("Keyword arguments 'command' and 'script' are mutually exclusive")
        # Wrap the given Python script in a Bash wrapper which feeds it to the
        # Python interpreter via process substitution; for array jobs, the
        # SLURM array task ID is passed as the first script argument.
        shexec = "#!/bin/bash\nexec {0} <(cat <<END_OF_SCRIPT\n".format(sys.executable)
        shexec += script.format(**opts)
        shexec += "\nEND_OF_SCRIPT)"
        if tasks > 0:
            shexec += " $SLURM_ARRAY_TASK_ID"
        shexec += "\n"
        script = shexec
    else:
        # Build a Bash script which runs the named MIRTK command through the
        # mirtk Python package, forwarding positional arguments and options.
        script = "#!/bin/bash\n"
        script += "\"{0}\" -c 'import sys; import socket;".format(sys.executable)
        script += " sys.stdout.write(\"Host: \" + socket.gethostname() + \"\\n\\n\");"
        script += " sys.path.insert(0, \"{0}\");".format(os.path.dirname(os.path.dirname(__file__)))
        script += " import mirtk; mirtk.check_call([\"{0}\"] + sys.argv[1:])".format(command if command else name)
        script += "'"
        for arg in args:
            arg = str(arg)
            if ' ' in arg:
                arg = '"' + arg + '"'
            script += ' ' + arg
        for opt in opts:
            arg = opts[opt]
            if opt[0] != '-':
                opt = '-' + opt
            script += ' ' + opt
            if arg is not None:
                if isinstance(arg, (list, tuple)):
                    arg = ' '.join([str(x) for x in arg])
                else:
                    arg = str(arg)
                script += ' ' + arg
        script += " -threads {0}".format(threads)
        script += "\n"
    argv = [
        'sbatch',
        '-J', name,
        '-n', '1',
        '-c', str(threads),
        '-p', queue,
        '--mem={0}M'.format(memory)
    ]
    if tasks > 0:
        argv.append('--array=0-{}'.format(tasks - 1))
    if logdir or log:
        if not logdir:
            logdir = os.path.dirname(log)
        elif not log:
            log = os.path.join(logdir, name)
            if tasks > 0:
                log += "_%A.%a.log"
            else:
                log += "_%j.log"
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        argv.extend(['-o', log, '-e', log])
    if deps:
        if isinstance(deps, int):
            deps = [deps]
        deps = [str(dep) for dep in deps if dep > 0]
        if deps:
            argv.append('--dependency=afterok:' + ',afterok:'.join(deps))
    if workdir:
        argv.append('--workdir=' + os.path.abspath(workdir))
    proc = subprocess.Popen(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
    (out, err) = proc.communicate(input=script.encode('utf-8'))
    # sbatch prints the job ID to stdout; decode the byte streams for parsing
    # and error reporting.
    out = out.decode('utf-8')
    err = err.decode('utf-8')
    if proc.returncode != 0:
        raise Exception(err)
    match = re.match('Submitted batch job ([0-9]+)', out)
    if not match:
        raise Exception("Failed to determine job ID from sbatch output:\n" + out)
    jobid = int(match.group(1))
    if verbose > 0:
        if tasks > 0:
            print("  Submitted batch {} (JobId={}, Tasks={})".format(name, jobid, tasks))
        else:
            print("  Submitted job {} (JobId={})".format(name, jobid))
    return jobid
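

# Example usage (illustrative sketch, not part of the original module): submit
# a single MIRTK command as a SLURM job. The command name, file names, option,
# queue, and resource settings below are placeholders, and the import path
# assumes this module is importable as mirtk.batch.slurm.
#
#     from mirtk.batch.slurm import submit
#
#     jobid = submit("resample-image",
#                    args=["input.nii.gz", "output.nii.gz"],
#                    opts={"interp": "linear"},
#                    queue="short", threads=4, memory=4 * 1024,
#                    logdir="logs", verbose=1)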


# ----------------------------------------------------------------------------
def wait(jobs, max_time=0, interval=60, verbose=0):
    """Wait for SLURM jobs to finish."""
    if not isinstance(jobs, (list, tuple)):
        jobs = [jobs]
    jobs = [job for job in jobs if job > 0]
    num_wait = len(jobs)
    if num_wait == 0:
        return True
    num_fail = 0
    total_time = 0
    re_state = re.compile("JobState=([A-Z]+)")
    jobs = [str(job) for job in jobs]
    iterations = 0
    while num_wait > 0 and (max_time <= 0 or total_time < max_time):
        time.sleep(interval)
        total_time += interval
        done = []
        fail = []
        batch = []
        # Query job states; sacct output is bytes and must be decoded.
        table = subprocess.check_output(["sacct", "--brief", "--parsable", "--jobs={}".format(','.join(jobs))]).decode('utf-8')
        for line in table.splitlines():
            cols = line.split("|")
            if cols[0] in jobs:
                if cols[1] == "COMPLETED":
                    done.append(cols[0])
                elif cols[1] not in ("PENDING", "SUSPENDED", "RUNNING"):
                    fail.append(cols[0])
            elif cols[0].endswith(".batch"):
                batch.append(cols[0][0:-6])
        num_jobs = len(jobs)
        num_done = 0
        num_fail = 0
        for job in batch:
            try:
                # For jobs with batch steps, count the individual task states
                # reported by scontrol instead of the aggregate sacct entry.
                info = subprocess.check_output(["scontrol", "show", "job", job]).decode('utf-8')
                for line in info.splitlines():
                    match = re_state.search(line)
                    if match:
                        num_jobs += 1
                        if match.group(1) == "COMPLETED":
                            num_done += 1
                        elif match.group(1) not in ("PENDING", "SUSPENDED", "RUNNING"):
                            num_fail += 1
                num_jobs -= 1
                try:
                    done.remove(job)
                except ValueError:
                    pass
                try:
                    fail.remove(job)
                except ValueError:
                    pass
            except subprocess.CalledProcessError:
                pass  # scontrol forgets about no longer queued/running jobs
        num_done += len(done)
        num_fail += len(fail)
        num_wait = num_jobs - num_done - num_fail
        if verbose > 0 and (num_wait <= 0 or iterations % verbose == 0):
            sys.stdout.write("{:%Y-%b-%d %H:%M:%S}".format(datetime.now()))
            sys.stdout.write(" WAIT {} job(s) running/suspended/pending".format(num_wait))
            if num_fail > 0:
                sys.stdout.write(", {} failed".format(num_fail))
            sys.stdout.write("\n")
            sys.stdout.flush()
        iterations += 1
    if num_wait > 0 and max_time > 0 and total_time >= max_time:
        raise Exception("Exceeded maximum time waiting for jobs to complete!")
    if total_time > 0:
        time.sleep(10)  # wait a bit for files to be available from all NFS clients
    return num_fail == 0
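

# Example usage (illustrative sketch, not part of the original module): chain
# two jobs via `deps` and block until both finish. Command names, options, and
# queue settings are placeholders; the import path assumes this module is
# importable as mirtk.batch.slurm.
#
#     from mirtk.batch.slurm import submit, wait
#
#     first = submit("transform-image", args=["in.nii.gz", "out.nii.gz"],
#                    opts={"dofin": "affine.dof"}, queue="short", threads=4,
#                    logdir="logs")
#     second = submit("calculate-element-wise", args=["out.nii.gz"],
#                     opts={"mul": 2, "out": "scaled.nii.gz"},
#                     deps=[first], queue="short", threads=4, logdir="logs")
#
#     if not wait([first, second], max_time=3600, interval=60, verbose=1):
#         raise RuntimeError("One or more SLURM jobs failed")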