torch_em.util.submit_slurm

 1import os
 2import sys
 3import inspect
 4import subprocess
 5from datetime import datetime
 6
 7# two days in minutes
 8TWO_DAYS = 2 * 24 * 60
 9
10
def write_slurm_template(script, out_path, env_name,
                         n_threads, gpu_type, n_gpus,
                         mem_limit, time_limit, qos,
                         mail_address, exclude_nodes):
    """Write a SLURM batch script that runs ``script`` on a single GPU node.

    Args:
        script: Path of the python script to run. Arguments given to the batch
            script are forwarded to it unchanged.
        out_path: Path the batch script is written to.
        env_name: Name of the conda environment activated before running.
        n_threads: Number of CPUs to request (``#SBATCH -c``).
        gpu_type: GPU constraint (``#SBATCH -C gpu=...``).
        n_gpus: Number of GPUs to request (``#SBATCH --gres=gpu:...``).
        mem_limit: Memory request (``#SBATCH --mem``), e.g. "64G".
        time_limit: Time limit (``#SBATCH -t``); ``submit_slurm`` passes minutes.
        qos: Quality-of-service name (``#SBATCH --qos``).
        mail_address: If not None, mails on job begin / end / failure go here.
        exclude_nodes: If not None, the nodes to exclude. Either an iterable of
            node names or an already comma-separated string (which is what the
            command line entry point delivers).
    """
    slurm_template = ("#!/bin/bash\n"
                      "#SBATCH -A kreshuk\n"
                      "#SBATCH -N 1\n"
                      f"#SBATCH -c {n_threads}\n"
                      f"#SBATCH --mem {mem_limit}\n"
                      f"#SBATCH -t {time_limit}\n"
                      f"#SBATCH --qos={qos}\n"
                      "#SBATCH -p gpu\n"
                      f"#SBATCH -C gpu={gpu_type}\n"
                      f"#SBATCH --gres=gpu:{n_gpus}\n")
    if mail_address is not None:
        slurm_template += ("#SBATCH --mail-type=FAIL,BEGIN,END\n"
                           f"#SBATCH --mail-user={mail_address}\n")
    if exclude_nodes is not None:
        # A string is already a node list; joining it would put a comma
        # between every single character ("n1,n2" -> "n,1,,,n,2").
        if isinstance(exclude_nodes, str):
            node_list = exclude_nodes
        else:
            node_list = ",".join(exclude_nodes)
        slurm_template += f"#SBATCH --exclude={node_list}\n"

    slurm_template += ("\n"
                       "module purge \n"
                       "module load GCC \n"
                       # f-string: the environment name must be substituted here
                       f"source activate {env_name}\n"
                       "\n"
                       # lets the submitted script know that it runs on slurm
                       "export TRAIN_ON_CLUSTER=1\n"
                       # quoted "$@" forwards every argument unchanged (no re-splitting / globbing)
                       f'python {script} "$@" \n')
    with open(out_path, "w") as f:
        f.write(slurm_template)
40
41
def submit_slurm(script, input_, n_threads=7, n_gpus=1,
                 gpu_type="2080Ti", mem_limit="64G",
                 time_limit=TWO_DAYS, qos="normal",
                 env_name=None, mail_address=None,
                 exclude_nodes=None):
    """Submit a python script that needs GPUs to a slurm node, with the given inputs.

    The batch script and the stdout / stderr log paths are placed in
    ``~/.torch_em/submission`` under a name built from the script name and the
    current timestamp; the job is then handed to ``sbatch``.

    Args:
        script: Path of the python script to run on the node.
        input_: Arguments forwarded to the script.
        n_threads: Number of CPUs to request; converted with ``int``.
        n_gpus: Number of GPUs to request; converted with ``int``.
        gpu_type: GPU constraint for the job.
        mem_limit: Memory request, e.g. "64G".
        time_limit: Time limit in minutes; converted with ``int``.
        qos: Quality-of-service name.
        env_name: Conda environment to activate on the node; defaults to the
            currently active one (``CONDA_DEFAULT_ENV``).
        mail_address: Optional address for job notification mails.
        exclude_nodes: Optional nodes to exclude from scheduling.

    Raises:
        RuntimeError: If no ``env_name`` is given and ``CONDA_DEFAULT_ENV`` is unset.
    """
    submission_dir = os.path.expanduser("~/.torch_em/submission")
    os.makedirs(submission_dir, exist_ok=True)

    print("Submitting training script %s to cluster" % script)
    print("with arguments %s" % " ".join(input_))

    job_name = os.path.basename(script)
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
    # stem and timestamp are concatenated directly, e.g. "train2024_01_31_..."
    stem = os.path.splitext(job_name)[0] + timestamp
    batch_script, log, err = (os.path.join(submission_dir, stem + suffix)
                              for suffix in (".sh", ".log", ".err"))

    # fall back to the active conda environment
    if env_name is None:
        env_name = os.environ.get("CONDA_DEFAULT_ENV")
    if env_name is None:
        raise RuntimeError("Could not find conda")

    print("Batch script saved at", batch_script)
    print("Log will be written to %s, error log to %s" % (log, err))
    write_slurm_template(script, batch_script, env_name,
                         int(n_threads), gpu_type, int(n_gpus),
                         mem_limit, int(time_limit), qos, mail_address,
                         exclude_nodes=exclude_nodes)

    subprocess.run(["sbatch", "-o", log, "-e", err, "-J", job_name, batch_script, *input_])
78
79
def scrape_kwargs(input_):
    """Split command line tokens into script arguments and ``submit_slurm`` options.

    A token equal to the name of a keyword parameter of ``submit_slurm``
    (optionally prefixed with ``-`` or ``--``, e.g. ``--n_gpus``) is treated as
    an option, and the token right after it as the option's value. All other
    tokens are kept, in their original order, as arguments for the script.

    Args:
        input_: The command line tokens.

    Returns:
        A tuple ``(input_, kwargs)``: the remaining script arguments, and a dict
        mapping option names (dashes stripped) to their string values.

    Raises:
        ValueError: If an option is the last token and therefore has no value.
    """
    params = inspect.signature(submit_slurm).parameters
    option_names = {name for name, param in params.items()
                    if param.default is not inspect.Parameter.empty}
    flags = {prefix + name for name in option_names for prefix in ("", "-", "--")}

    remaining, kwargs = [], {}
    tokens = iter(input_)
    for token in tokens:
        if token not in flags:
            remaining.append(token)
            continue
        # the value is consumed here, so it is never itself parsed as a flag
        try:
            value = next(tokens)
        except StopIteration:
            raise ValueError(f"Missing value for option {token!r}") from None
        kwargs[token.lstrip("-")] = value
    return remaining, kwargs
90
91
def main():
    """Command line entry point for submitting a script to slurm.

    Usage: ``<script> [script arguments ...] [--<option> <value> ...]``. The
    options (keyword parameters of ``submit_slurm``, e.g. ``--n_gpus 2``) are
    separated from the script arguments by ``scrape_kwargs``.
    """
    if len(sys.argv) < 2:
        # nothing to submit: fail with a usage message instead of an IndexError
        sys.exit("usage: submit_slurm <script> [script arguments ...] [--<option> <value> ...]")
    script = os.path.realpath(os.path.abspath(sys.argv[1]))
    input_ = sys.argv[2:]
    # scrape the additional arguments (n_threads, mem_limit, etc.) from the input
    input_, kwargs = scrape_kwargs(input_)
    submit_slurm(script, input_, **kwargs)
TWO_DAYS = 2 * 24 * 60  # two days, in minutes
def write_slurm_template(script, out_path, env_name,
                         n_threads, gpu_type, n_gpus,
                         mem_limit, time_limit, qos,
                         mail_address, exclude_nodes):
    """Write a SLURM batch script that runs ``script`` on a single GPU node.

    Args:
        script: Path of the python script to run. Arguments given to the batch
            script are forwarded to it unchanged.
        out_path: Path the batch script is written to.
        env_name: Name of the conda environment activated before running.
        n_threads: Number of CPUs to request (``#SBATCH -c``).
        gpu_type: GPU constraint (``#SBATCH -C gpu=...``).
        n_gpus: Number of GPUs to request (``#SBATCH --gres=gpu:...``).
        mem_limit: Memory request (``#SBATCH --mem``), e.g. "64G".
        time_limit: Time limit (``#SBATCH -t``); ``submit_slurm`` passes minutes.
        qos: Quality-of-service name (``#SBATCH --qos``).
        mail_address: If not None, mails on job begin / end / failure go here.
        exclude_nodes: If not None, the nodes to exclude. Either an iterable of
            node names or an already comma-separated string (which is what the
            command line entry point delivers).
    """
    slurm_template = ("#!/bin/bash\n"
                      "#SBATCH -A kreshuk\n"
                      "#SBATCH -N 1\n"
                      f"#SBATCH -c {n_threads}\n"
                      f"#SBATCH --mem {mem_limit}\n"
                      f"#SBATCH -t {time_limit}\n"
                      f"#SBATCH --qos={qos}\n"
                      "#SBATCH -p gpu\n"
                      f"#SBATCH -C gpu={gpu_type}\n"
                      f"#SBATCH --gres=gpu:{n_gpus}\n")
    if mail_address is not None:
        slurm_template += ("#SBATCH --mail-type=FAIL,BEGIN,END\n"
                           f"#SBATCH --mail-user={mail_address}\n")
    if exclude_nodes is not None:
        # A string is already a node list; joining it would put a comma
        # between every single character ("n1,n2" -> "n,1,,,n,2").
        if isinstance(exclude_nodes, str):
            node_list = exclude_nodes
        else:
            node_list = ",".join(exclude_nodes)
        slurm_template += f"#SBATCH --exclude={node_list}\n"

    slurm_template += ("\n"
                       "module purge \n"
                       "module load GCC \n"
                       # f-string: the environment name must be substituted here
                       f"source activate {env_name}\n"
                       "\n"
                       # lets the submitted script know that it runs on slurm
                       "export TRAIN_ON_CLUSTER=1\n"
                       # quoted "$@" forwards every argument unchanged (no re-splitting / globbing)
                       f'python {script} "$@" \n')
    with open(out_path, "w") as f:
        f.write(slurm_template)
def submit_slurm(script, input_, n_threads=7, n_gpus=1,
                 gpu_type="2080Ti", mem_limit="64G",
                 time_limit=TWO_DAYS, qos="normal",
                 env_name=None, mail_address=None,
                 exclude_nodes=None):
    """Submit a python script that needs GPUs to a slurm node, with the given inputs.

    The batch script and the stdout / stderr log paths are placed in
    ``~/.torch_em/submission`` under a name built from the script name and the
    current timestamp; the job is then handed to ``sbatch``.

    Args:
        script: Path of the python script to run on the node.
        input_: Arguments forwarded to the script.
        n_threads: Number of CPUs to request; converted with ``int``.
        n_gpus: Number of GPUs to request; converted with ``int``.
        gpu_type: GPU constraint for the job.
        mem_limit: Memory request, e.g. "64G".
        time_limit: Time limit in minutes; converted with ``int``.
        qos: Quality-of-service name.
        env_name: Conda environment to activate on the node; defaults to the
            currently active one (``CONDA_DEFAULT_ENV``).
        mail_address: Optional address for job notification mails.
        exclude_nodes: Optional nodes to exclude from scheduling.

    Raises:
        RuntimeError: If no ``env_name`` is given and ``CONDA_DEFAULT_ENV`` is unset.
    """
    submission_dir = os.path.expanduser("~/.torch_em/submission")
    os.makedirs(submission_dir, exist_ok=True)

    print("Submitting training script %s to cluster" % script)
    print("with arguments %s" % " ".join(input_))

    job_name = os.path.basename(script)
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
    # stem and timestamp are concatenated directly, e.g. "train2024_01_31_..."
    stem = os.path.splitext(job_name)[0] + timestamp
    batch_script, log, err = (os.path.join(submission_dir, stem + suffix)
                              for suffix in (".sh", ".log", ".err"))

    # fall back to the active conda environment
    if env_name is None:
        env_name = os.environ.get("CONDA_DEFAULT_ENV")
    if env_name is None:
        raise RuntimeError("Could not find conda")

    print("Batch script saved at", batch_script)
    print("Log will be written to %s, error log to %s" % (log, err))
    write_slurm_template(script, batch_script, env_name,
                         int(n_threads), gpu_type, int(n_gpus),
                         mem_limit, int(time_limit), qos, mail_address,
                         exclude_nodes=exclude_nodes)

    subprocess.run(["sbatch", "-o", log, "-e", err, "-J", job_name, batch_script, *input_])

Submit a Python script that needs GPUs to a SLURM node, passing it the given inputs.

def scrape_kwargs(input_):
    """Split command line tokens into script arguments and ``submit_slurm`` options.

    A token equal to the name of a keyword parameter of ``submit_slurm``
    (optionally prefixed with ``-`` or ``--``, e.g. ``--n_gpus``) is treated as
    an option, and the token right after it as the option's value. All other
    tokens are kept, in their original order, as arguments for the script.

    Args:
        input_: The command line tokens.

    Returns:
        A tuple ``(input_, kwargs)``: the remaining script arguments, and a dict
        mapping option names (dashes stripped) to their string values.

    Raises:
        ValueError: If an option is the last token and therefore has no value.
    """
    params = inspect.signature(submit_slurm).parameters
    option_names = {name for name, param in params.items()
                    if param.default is not inspect.Parameter.empty}
    flags = {prefix + name for name in option_names for prefix in ("", "-", "--")}

    remaining, kwargs = [], {}
    tokens = iter(input_)
    for token in tokens:
        if token not in flags:
            remaining.append(token)
            continue
        # the value is consumed here, so it is never itself parsed as a flag
        try:
            value = next(tokens)
        except StopIteration:
            raise ValueError(f"Missing value for option {token!r}") from None
        kwargs[token.lstrip("-")] = value
    return remaining, kwargs
def main():
    """Command line entry point for submitting a script to slurm.

    Usage: ``<script> [script arguments ...] [--<option> <value> ...]``. The
    options (keyword parameters of ``submit_slurm``, e.g. ``--n_gpus 2``) are
    separated from the script arguments by ``scrape_kwargs``.
    """
    if len(sys.argv) < 2:
        # nothing to submit: fail with a usage message instead of an IndexError
        sys.exit("usage: submit_slurm <script> [script arguments ...] [--<option> <value> ...]")
    script = os.path.realpath(os.path.abspath(sys.argv[1]))
    input_ = sys.argv[2:]
    # scrape the additional arguments (n_threads, mem_limit, etc.) from the input
    input_, kwargs = scrape_kwargs(input_)
    submit_slurm(script, input_, **kwargs)