解决Aiida中提交WorkChain时的工作路径问题

Aiida 中有两种启动 WorkChain 的方式,一种是 run,还有一种是 submit。对于 run 提交方式,其运作原理是通过当前运行的终端对应的守护进程来处理提交的 WorkChain,运行时会阻塞终端,必须等到 WorkChain 结束后才能重新获得终端的控制权,在运行过程中,不能关闭终端,否则运行会直接被中断。在这种运行方式中,Python 解释器的搜索路径包含当前文件夹,并且 Python 的工作路径也是当前文件夹,因此不会出现太大问题,即使把运行代码和 WorkChain 的定义代码放在一个文件中,解释器同样可以找到 WorkChain 的定义。

这其中比较复杂的是第二种提交方式,即通过 submit 将 WorkChain 提交给后台守护进程(daemon),其 Python 解释器的搜索路径和工作路径与上一种方式都有所差别,下面将详细介绍。

1. Daemon

Aiida 中提供了一个在后台运行的守护进程,它可以异步处理任何提交的新进程(WorkChain)。启动守护进程时,后台将启动一个系统进程,该进程将无限期运行,直至停止。该守护进程负责启动并监控一个或多个守护进程 Worker 。每个守护进程 Worker 都是另一个系统进程,它连接到 RabbitMQ 以检索已提交的计算和 WorkChain,并将其运行至完成。如果守护进程 Worker 死亡,守护进程将自动恢复。当守护进程被要求停止时,它将向所有 Worker 发送信号以关闭它们,然后再关闭自己。

引自官网:https://aiida.readthedocs.io/projects/aiida-core/zh-cn/latest/topics/daemon.html

2. submit 方式提交时的搜索路径和工作路径

前面已经说到,通过 submit 方式提交 WorkChain 时,实际上是提交给 daemon,然后由 daemon 进行调度以及运行 WorkChain,因此在这种方式下 Python 解释器的搜索路径和和工作路径都变成了 deamon 相应的路径,而不再是当前运行终端下的路径,这一点非常重要。

2.1 向 daemon 的搜索路径中添加自定义路径

daemon 的搜索路径与 Python 的搜索路径一致,因此可以直接将需要添加的路径加入到 PYTHONPATH 变量中,具体参考我的这篇博客:
https://www.jun997.xyz/2024/05/19/36bf29e39e54.html#1-%E5%90%AF%E5%8A%A8%E8%87%AA%E5%AE%9A%E4%B9%89%E7%9A%84workchains

另外,值得一提的是,经过测试发现 daemon 不支持热加载,因此每次修改了 WorkChain 的定义之后,都需要重启一下 daemon,即是运行:

1
verdi daemon restart

2.2 daemon 的工作路径

经过测试发现,daemon 的默认工作路径是在终端运行 pg_ctl -D mylocal_db -l logfile start 命令时,当前终端的工作路径(pwd),而不是提交 WorkChain 时对应的路径。

2.3 将 daemon 的工作路径修改为提交作业时当前终端的工作路径

daemon 的工作路径即是 Python 解释器对应的工作路径,因此要修改 daemon 的工作路径就是修改 Python 解释器的工作路径,Python 中的 os 模块提供了这样的功能,具体用到的是 os.getcwd()os.chdir(path) 方法。

以下为示例代码:

WorkChain 的定义代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from aiida.orm import AbstractCode, Int, Str
from aiida.plugins.factories import CalculationFactory
from aiida.engine import WorkChain
import os

ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add')

class AddTwoIntsWorkChain(WorkChain):
"""WorkChain to add two ints, for testing and demonstration purposes."""
@classmethod
def define(cls, spec):
"""Specify inputs and outputs."""
super().define(spec)
spec.input('x', valid_type=Int)
spec.input('y', valid_type=Int)
spec.input('pwd', valid_type=Str)
spec.input('code', valid_type=AbstractCode)
spec.outline(
cls.init_dir,
cls.add,
cls.validate_result,
cls.result,
cls.write_result,
)
spec.output('result', valid_type=Int)
spec.exit_code(400, 'ERROR_NEGATIVE_NUMBER', message='The result is a negative number.')

def init_dir(self):
self.report('---init: Now at dir---:{}'.format(os.getcwd()))
os.chdir(str(self.inputs.pwd.value)) #### change work path
self.report('---init: Now at dir---:{}'.format(os.getcwd()))

def add(self):
"""Add two numbers using the `ArithmeticAddCalculation` calculation job plugin."""
inputs = {'x': self.inputs.x, 'y': self.inputs.y, 'code': self.inputs.code}
future = self.submit(ArithmeticAddCalculation, **inputs)
self.to_context(addition=future)

def validate_result(self):
"""Make sure the result is not negative."""
result = self.ctx.addition.outputs.sum
if result.value < 0:
return self.exit_codes.ERROR_NEGATIVE_NUMBER

def result(self):
"""Add the result to the outputs."""
self.out('result', self.ctx.addition.outputs.sum)

def write_result(self):
with open('sum_rlt.txt', 'w') as f:
f.write('result: {}\n'.format(int(self.ctx.addition.outputs.sum)))

提交 WorkChain 的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from testworkflow import AddTwoIntsWorkChain
from aiida.orm import Int, load_code, Str
from aiida import load_profile
from aiida.common.extendeddicts import AttributeDict
from aiida.engine import submit
import os


if __name__ == "__main__":
load_profile()

add_code = load_code('add@tutor')

inputs = AttributeDict()
inputs.x = Int(20)
inputs.y = Int(10)
print('Submit from: {}'.format(os.getcwd()))
inputs.pwd = Str(os.getcwd()) ### get submit path
inputs.code = add_code
workchain = AddTwoIntsWorkChain
submit(workchain, **inputs)

从上面代码可以看到,在提交脚本中,首先获取了当前提交的路径,然后将其作为参数传递给 WorkChain,WorkChain 中的 init_dir 方法将工作路径修改为参数中指定路径。