-
Notifications
You must be signed in to change notification settings - Fork 168
/
Copy pathadd_backend.py
46 lines (39 loc) · 1.37 KB
/
add_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import argparse
from pathlib import Path
import jinja2
def main():
    """Command-line entry point: read ``--name`` and generate the backend boilerplate.

    The single required option ``-n/--name`` is forwarded unchanged to
    ``generate_backend_code``.
    """
    name_help = (
        "The backend name in CamelCase, e.g. AWS, Runpod, VastAI."
        " It'll be used for naming backend classes, models, etc."
    )
    cli = argparse.ArgumentParser(
        description="This script generates boilerplate code for a new backend"
    )
    cli.add_argument("-n", "--name", required=True, help=name_help)
    options = cli.parse_args()
    generate_backend_code(options.name)
def generate_backend_code(backend_name: str):
    """Render the backend templates into a new backend package directory.

    Templates are loaded from ``src/dstack/_internal/core/backends/template``,
    resolved relative to two directory levels above this script. Each of
    ``backend.py``, ``compute.py``, ``configurator.py`` and ``models.py`` is
    rendered from its ``<filename>.jinja`` template with ``backend_name`` in the
    context. The result is written to
    ``src/dstack/_internal/core/backends/<backend_name.lower()>/<filename>``,
    together with an empty ``__init__.py``.

    Args:
        backend_name: Backend name in CamelCase (e.g. ``AWS``, ``Runpod``,
            ``VastAI``). It is passed to the templates as ``backend_name``, and
            its lowercased form is used as the package directory name.

    Raises:
        ValueError: If ``backend_name`` is not a valid Python identifier, or if
            it would collide with the ``template`` directory.

    Note:
        The target directory may already exist, and files with the same names
        in it are overwritten.
    """
    # The name becomes part of generated class names and a directory name.
    # Reject anything that is not an identifier before touching the
    # filesystem: empty strings, spaces, and path segments such as "../x".
    if not backend_name.isidentifier():
        raise ValueError(
            f"Invalid backend name {backend_name!r}: expected a CamelCase "
            "identifier such as AWS, Runpod or VastAI"
        )
    # "Template" would lowercase to the template directory itself and write
    # generated files into it.
    if backend_name.lower() == "template":
        raise ValueError(
            f"Invalid backend name {backend_name!r}: 'template' is reserved"
        )

    backends_dir_path = Path(__file__).parent.parent.joinpath(
        "src/dstack/_internal/core/backends"
    )
    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(
            searchpath=backends_dir_path.joinpath("template"),
        ),
        keep_trailing_newline=True,
    )
    backend_dir_path = backends_dir_path.joinpath(backend_name.lower())
    backend_dir_path.mkdir(exist_ok=True)

    context = {"backend_name": backend_name}
    for filename in ["backend.py", "compute.py", "configurator.py", "models.py"]:
        template = env.get_template(f"{filename}.jinja")
        # FileSystemLoader decodes templates as UTF-8. Write the output as
        # UTF-8 too, instead of the platform's locale encoding.
        backend_dir_path.joinpath(filename).write_text(
            template.render(context), encoding="utf-8"
        )
    backend_dir_path.joinpath("__init__.py").write_text("")
# Run the generator only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()