torch_em.util.submit_slurm

@private

  1"""@private
  2"""
  3# This functionality is currently hard-coded to the EMBL slurm cluster.
  4# To enable it for other slurm clusters, several configurations, like accounting and partition names,
  5# would have to be exposed.
  6
  7import os
  8import sys
  9import inspect
 10import subprocess
 11from datetime import datetime
 12
# Two days in minutes, used as the default wall-clock time limit for submitted jobs.
TWO_DAYS = 2 * 24 * 60
 15
 16
 17def write_slurm_template(script, out_path, env_name,
 18                         n_threads, gpu_type, n_gpus,
 19                         mem_limit, time_limit, qos,
 20                         mail_address, exclude_nodes):
 21    slurm_template = ("#!/bin/bash\n"
 22                      "#SBATCH -A kreshuk\n"
 23                      "#SBATCH -N 1\n"
 24                      f"#SBATCH -c {n_threads}\n"
 25                      f"#SBATCH --mem {mem_limit}\n"
 26                      f"#SBATCH -t {time_limit}\n"
 27                      f"#SBATCH --qos={qos}\n"
 28                      "#SBATCH -p gpu\n"
 29                      f"#SBATCH -C gpu={gpu_type}\n"
 30                      f"#SBATCH --gres=gpu:{n_gpus}\n")
 31    if mail_address is not None:
 32        slurm_template += ("#SBATCH --mail-type=FAIL,BEGIN,END\n"
 33                           f"#SBATCH --mail-user={mail_address}\n")
 34    if exclude_nodes is not None:
 35        slurm_template += "#SBATCH --exclude=%s\n" % ",".join(exclude_nodes)
 36
 37    slurm_template += ("\n"
 38                       "module purge \n"
 39                       "module load GCC \n"
 40                       "source activate {env_name}\n"
 41                       "\n"
 42                       "export TRAIN_ON_CLUSTER=1\n"  # we set this env variable, so that script know we"re on slurm
 43                       f"python {script} $@ \n")
 44    with open(out_path, "w") as f:
 45        f.write(slurm_template)
 46
 47
 48def submit_slurm(script, input_, n_threads=7, n_gpus=1,
 49                 gpu_type="2080Ti", mem_limit="64G",
 50                 time_limit=TWO_DAYS, qos="normal",
 51                 env_name=None, mail_address=None,
 52                 exclude_nodes=None):
 53    """ Submit python script that needs gpus with given inputs on a slurm node.
 54    """
 55
 56    tmp_folder = os.path.expanduser("~/.torch_em/submission")
 57    os.makedirs(tmp_folder, exist_ok=True)
 58
 59    print("Submitting training script %s to cluster" % script)
 60    print("with arguments %s" % " ".join(input_))
 61
 62    script_name = os.path.split(script)[1]
 63    dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
 64    tmp_name = os.path.splitext(script_name)[0] + dt
 65    batch_script = os.path.join(tmp_folder, "%s.sh" % tmp_name)
 66    log = os.path.join(tmp_folder, "%s.log" % tmp_name)
 67    err = os.path.join(tmp_folder, "%s.err" % tmp_name)
 68
 69    if env_name is None:
 70        env_name = os.environ.get("CONDA_DEFAULT_ENV", None)
 71        if env_name is None:
 72            raise RuntimeError("Could not find conda")
 73
 74    print("Batch script saved at", batch_script)
 75    print("Log will be written to %s, error log to %s" % (log, err))
 76    write_slurm_template(script, batch_script, env_name,
 77                         int(n_threads), gpu_type, int(n_gpus),
 78                         mem_limit, int(time_limit), qos, mail_address,
 79                         exclude_nodes=exclude_nodes)
 80
 81    cmd = ["sbatch", "-o", log, "-e", err, "-J", script_name, batch_script]
 82    cmd.extend(input_)
 83    subprocess.run(cmd)
 84
 85
 86def scrape_kwargs(input_):
 87    params = inspect.signature(submit_slurm).parameters
 88    kwarg_names = [name for name in params if params[name].default != inspect._empty]
 89    kwarg_names.extend([f"-{name}" for name in kwarg_names])
 90    kwarg_names.extend([f"--{name}" for name in kwarg_names])
 91    kwarg_positions = [i for i, inp in enumerate(input_) if inp in kwarg_names]
 92    kwargs = {input_[i].lstrip("-"): input_[i + 1] for i in kwarg_positions}
 93    kwarg_positions += [i + 1 for i in kwarg_positions]
 94    input_ = [inp for i, inp in enumerate(input_) if i not in kwarg_positions]
 95    return input_, kwargs
 96
 97
 98def main():
 99    script = os.path.realpath(os.path.abspath(sys.argv[1]))
100    input_ = sys.argv[2:]
101    # scrape the additional arguments (n_threads, mem_limit, etc. from the input)
102    input_, kwargs = scrape_kwargs(input_)
103    submit_slurm(script, input_, **kwargs)
# Two days in minutes (2 * 24 * 60), the default wall-clock limit for jobs.
TWO_DAYS = 2880
def write_slurm_template(script, out_path, env_name,
                         n_threads, gpu_type, n_gpus,
                         mem_limit, time_limit, qos,
                         mail_address, exclude_nodes):
    """Write a slurm batch script that runs `script` on a GPU node.

    Args:
        script: Path to the python script that will be run.
        out_path: Where the generated batch script is written.
        env_name: Name of the conda environment to activate before running.
        n_threads: Number of CPU cores to request.
        gpu_type: GPU constraint passed to `#SBATCH -C`, e.g. "2080Ti".
        n_gpus: Number of GPUs to request via `--gres`.
        mem_limit: Memory limit string, e.g. "64G".
        time_limit: Time limit in minutes.
        qos: Quality-of-service level.
        mail_address: Optional email address for BEGIN/END/FAIL notifications.
        exclude_nodes: Optional list of node names to exclude.
    """
    slurm_template = ("#!/bin/bash\n"
                      "#SBATCH -A kreshuk\n"
                      "#SBATCH -N 1\n"
                      f"#SBATCH -c {n_threads}\n"
                      f"#SBATCH --mem {mem_limit}\n"
                      f"#SBATCH -t {time_limit}\n"
                      f"#SBATCH --qos={qos}\n"
                      "#SBATCH -p gpu\n"
                      f"#SBATCH -C gpu={gpu_type}\n"
                      f"#SBATCH --gres=gpu:{n_gpus}\n")
    if mail_address is not None:
        slurm_template += ("#SBATCH --mail-type=FAIL,BEGIN,END\n"
                           f"#SBATCH --mail-user={mail_address}\n")
    if exclude_nodes is not None:
        slurm_template += "#SBATCH --exclude=%s\n" % ",".join(exclude_nodes)

    slurm_template += ("\n"
                       "module purge \n"
                       "module load GCC \n"
                       # BUGFIX: this line was a plain string before, so the
                       # literal text "{env_name}" was written to the batch
                       # script instead of the actual environment name.
                       f"source activate {env_name}\n"
                       "\n"
                       # Set this env variable so the script knows it runs on slurm.
                       "export TRAIN_ON_CLUSTER=1\n"
                       f"python {script} $@ \n")
    with open(out_path, "w") as f:
        f.write(slurm_template)
def submit_slurm(script, input_, n_threads=7, n_gpus=1,
                 gpu_type="2080Ti", mem_limit="64G",
                 time_limit=TWO_DAYS, qos="normal",
                 env_name=None, mail_address=None,
                 exclude_nodes=None):
    """ Submit python script that needs gpus with given inputs on a slurm node.

    Writes a batch script (via `write_slurm_template`) into
    ~/.torch_em/submission and submits it with `sbatch`, forwarding `input_`
    as arguments to the python script.
    """
    # Generated batch scripts and log files are collected in a hidden folder
    # in the user's home directory.
    tmp_folder = os.path.expanduser("~/.torch_em/submission")
    os.makedirs(tmp_folder, exist_ok=True)

    print("Submitting training script %s to cluster" % script)
    print("with arguments %s" % " ".join(input_))

    # Build unique file names from the script name plus a timestamp,
    # so repeated submissions don't overwrite each other.
    script_name = os.path.split(script)[1]
    dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
    tmp_name = os.path.splitext(script_name)[0] + dt
    batch_script = os.path.join(tmp_folder, "%s.sh" % tmp_name)
    log = os.path.join(tmp_folder, "%s.log" % tmp_name)
    err = os.path.join(tmp_folder, "%s.err" % tmp_name)

    # Default to the conda environment that is currently activated.
    if env_name is None:
        env_name = os.environ.get("CONDA_DEFAULT_ENV", None)
        if env_name is None:
            raise RuntimeError("Could not find conda")

    print("Batch script saved at", batch_script)
    print("Log will be written to %s, error log to %s" % (log, err))
    write_slurm_template(script, batch_script, env_name,
                         int(n_threads), gpu_type, int(n_gpus),
                         mem_limit, int(time_limit), qos, mail_address,
                         exclude_nodes=exclude_nodes)

    # -J sets the slurm job name; the script's own arguments are appended
    # and forwarded by the batch script via $@.
    cmd = ["sbatch", "-o", log, "-e", err, "-J", script_name, batch_script]
    cmd.extend(input_)
    subprocess.run(cmd)

Submit a Python script that needs GPUs with the given inputs on a slurm node.

def scrape_kwargs(input_):
    """Split keyword arguments for `submit_slurm` from the script arguments.

    Recognizes the optional parameters of `submit_slurm` passed as
    `name value`, `-name value` or `--name value` pairs and removes them
    (together with their values) from the argument list.

    Args:
        input_: The full list of command line arguments after the script path.

    Returns:
        A tuple of the remaining arguments for the script and a dict with
        the scraped keyword arguments for `submit_slurm`.
    """
    params = inspect.signature(submit_slurm).parameters
    # Only parameters with a default value can be passed on the command line.
    # Use the public `inspect.Parameter.empty` instead of the private
    # `inspect._empty` sentinel.
    names = [name for name in params if params[name].default is not inspect.Parameter.empty]
    # Accept the bare name as well as single- and double-dash flag variants.
    # (The previous double-extend also generated accidental `---name` aliases.)
    kwarg_names = set(names)
    kwarg_names.update(f"-{name}" for name in names)
    kwarg_names.update(f"--{name}" for name in names)
    kwarg_positions = [i for i, inp in enumerate(input_) if inp in kwarg_names]
    kwargs = {input_[i].lstrip("-"): input_[i + 1] for i in kwarg_positions}
    # Drop both the flags and their following values from the remaining input.
    consumed = set(kwarg_positions) | {i + 1 for i in kwarg_positions}
    input_ = [inp for i, inp in enumerate(input_) if i not in consumed]
    return input_, kwargs
def main():
    """Command line entry point: submit sys.argv[1] with the remaining arguments."""
    # Resolve to an absolute path so slurm can find the script from any cwd.
    script = os.path.realpath(os.path.abspath(sys.argv[1]))
    input_ = sys.argv[2:]
    # scrape the additional arguments (n_threads, mem_limit, etc. from the input)
    input_, kwargs = scrape_kwargs(input_)
    submit_slurm(script, input_, **kwargs)