Try statement in Cython for cimport (for use with

Posted 2019-01-19 07:42

Is there a way to have the equivalent of the Python try statement in Cython for the cimport?

Something like this:

try:
    cimport something
except ImportError:
    pass

I need this to write a Cython extension that can be compiled with or without mpi4py. This is standard practice in compiled languages such as C, where the MPI calls can be wrapped in #ifdef ... #endif preprocessor directives. How can we obtain the same result in Cython?

I tried this but it does not work:

try:
    from mpi4py import MPI
    from mpi4py cimport MPI
    from mpi4py.mpi_c cimport *
except ImportError:
    rank = 0
    nb_proc = 1

# solve an incompatibility between openmpi and mpi4py versions
cdef extern from 'mpi-compat.h': pass

does_it_work = 'Not yet'

Actually, it works well if mpi4py is correctly installed, but if import mpi4py raises an ImportError, the Cython file does not compile and I get this error:

Error compiling Cython file:
------------------------------------------------------------
...

try:
    from mpi4py import MPI
    from mpi4py cimport MPI
   ^
------------------------------------------------------------

mod.pyx:4:4: 'mpi4py.pxd' not found

The file setup.py:

from setuptools import setup, Extension
from Cython.Distutils import build_ext

import os
here = os.path.abspath(os.path.dirname(__file__))

include_dirs = [here]

try:
    import mpi4py
except ImportError:
    pass
else:
    INCLUDE_MPI = '/usr/lib/openmpi/include'
    include_dirs.extend([
        INCLUDE_MPI,
        mpi4py.get_include()])

name = 'mod'
ext = Extension(
    name,
    include_dirs=include_dirs,
    sources=['mod.pyx'])

setup(name=name,
      cmdclass={"build_ext": build_ext},
      ext_modules=[ext])

2 Answers
聊天终结者
Post #2 · 2019-01-19 08:25

You won't be able to use a try/except block this way. cimport is resolved when Cython compiles the module: Cython has to find the corresponding .pxd file to generate the C code, and the extension module you are building is then compiled against the C-level declarations that cimport pulls in. A try/except block, on the other hand, is executed when the module is imported, not when it is compiled.
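The difference is easy to see with plain Python. In the sketch below (module_that_does_not_exist is a made-up placeholder name), compile() happily accepts a try/except around a missing import; the ImportError only appears when the code object is executed. Cython's cimport has no such run-time step: the .pxd lookup happens during compilation, so the except clause never gets a chance to run.

```python
source = (
    "try:\n"
    "    import module_that_does_not_exist  # hypothetical, assumed absent\n"
    "    FOUND = True\n"
    "except ImportError:\n"
    "    FOUND = False\n"
)

# Compiling only parses the code; the missing module is not looked up yet.
code = compile(source, "<demo>", "exec")

# The import, and therefore the ImportError, happens only when the code runs.
namespace = {}
exec(code, namespace)
print(namespace["FOUND"])  # False
```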

On the other hand, in theory, you should be able to get the effect you're looking for with Cython's support for conditional compilation (the DEF and IF/ELIF/ELSE statements). In your setup.py file you can check whether the needed modules are available and then pass compile-time constants to the Cython compiler whose values depend on whether or not those modules are present.
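One way to do the check in setup.py is sketched below; it is an assumption on my part, not taken from the answer. importlib.util.find_spec looks a top-level module up without executing it, and the hypothetical helpers have_module and compile_time_env turn the result into the dictionary that Cython would receive as its compile-time environment. Note that finding the mpi4py package does not by itself prove that its .pxd files or the MPI headers are usable.

```python
import importlib.util


def have_module(name):
    """Return True if the top-level module `name` can be found on sys.path.

    find_spec() locates the module without executing it.
    """
    return importlib.util.find_spec(name) is not None


def compile_time_env():
    # Hypothetical helper: the returned dict would be handed to Cython as the
    # compile-time environment, so that `IF MPI4PY:` in the .pyx can test it.
    return {"MPI4PY": have_module("mpi4py")}


if __name__ == "__main__":
    print(compile_time_env())
```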

There's an example of how to do this in one of Cython's tests. There, a dictionary containing the desired compile-time values is passed to the constructor of Cython's Extension class as the keyword argument pyrex_compile_time_env (since renamed cython_compile_time_env); for Cython.Build.Dependencies.cythonize, the corresponding argument is called compile_time_env.
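As an alternative to Cython's own Extension class, a setup.py built on cythonize and its compile_time_env argument might look like the following sketch. The helper name mpi4py_build_config is hypothetical, the OpenMPI include path is the one from the question and may differ on your system, and the setuptools/Cython imports sit inside main() so the helper can be exercised without Cython installed.

```python
import os


def mpi4py_build_config(base_dir, mpi_include="/usr/lib/openmpi/include"):
    """Return (include_dirs, compile_time_env) for building mod.pyx.

    mpi_include is an assumption about the local OpenMPI installation.
    """
    include_dirs = [base_dir]
    try:
        import mpi4py
    except ImportError:
        # No mpi4py: build the sequential variant.
        return include_dirs, {"MPI4PY": False}
    include_dirs += [mpi_include, mpi4py.get_include()]
    return include_dirs, {"MPI4PY": True}


def main():
    # Imported here so mpi4py_build_config works without Cython installed.
    from setuptools import setup, Extension
    from Cython.Build import cythonize

    here = os.path.abspath(os.path.dirname(__file__))
    include_dirs, env = mpi4py_build_config(here)
    ext = Extension("mod", sources=["mod.pyx"], include_dirs=include_dirs)
    # compile_time_env makes MPI4PY visible to IF/ELSE blocks in mod.pyx.
    setup(name="mod", ext_modules=cythonize([ext], compile_time_env=env))


if __name__ == "__main__":
    main()
```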

欢心
Post #3 · 2019-01-19 08:32

Thank you for your very useful answer, @IanH. Here is an example showing what it looks like in practice.

The file setup.py:

from setuptools import setup
from Cython.Distutils.extension import Extension
from Cython.Distutils import build_ext

import os
here = os.path.abspath(os.path.dirname(__file__))

import numpy as np
include_dirs = [here, np.get_include()]

try:
    import mpi4py
except ImportError:
    MPI4PY = False
else:
    MPI4PY = True
    INCLUDE_MPI = '/usr/lib/openmpi/include'
    include_dirs.extend([
        INCLUDE_MPI,
        mpi4py.get_include()])

name = 'mod'
ext = Extension(
    name,
    include_dirs=include_dirs,
    cython_compile_time_env={'MPI4PY': MPI4PY},
    sources=['mod.pyx'])

setup(name=name,
      cmdclass={"build_ext": build_ext},
      ext_modules=[ext])

if not MPI4PY:
    print('Warning: since importing mpi4py raises an ImportError,\n'
          '         the extensions are compiled without MPI and\n'
          '         will only run sequentially.')

And the file mod.pyx, with a few real MPI calls:

import numpy as np
cimport numpy as np

try:
    from mpi4py import MPI
except ImportError:
    nb_proc = 1
    rank = 0
else:
    comm = MPI.COMM_WORLD
    nb_proc = comm.size
    rank = comm.Get_rank()

IF MPI4PY:
    from mpi4py cimport MPI
    from mpi4py.mpi_c cimport *

    # solve an incompatibility between openmpi and mpi4py versions
    cdef extern from 'mpi-compat.h': pass

    print('mpi4py ok')
ELSE:
    print('no mpi4py')

n = 8
if n % nb_proc != 0:
    raise ValueError('The number of processes is incorrect.')

if rank == 0:
    data_seq = np.ones([n], dtype=np.int32)
    s_seq = data_seq.sum()
else:
    data_seq = np.zeros([n], dtype=np.int32)

if nb_proc > 1:
    # integer division: in Python 3, n/nb_proc would be a float
    data_local = np.zeros([n // nb_proc], dtype=np.int32)
    comm.Scatter(data_seq, data_local, root=0)
else:
    data_local = data_seq

s = data_local.sum()
if nb_proc > 1:
    s = comm.allreduce(s, op=MPI.SUM)

if rank == 0:
    print('s: {}; s_seq: {}'.format(s, s_seq))
    assert s == s_seq
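To see what the Scatter/allreduce pair computes, here is a pure-Python simulation that runs in a single process without mpi4py or NumPy. The function simulate_scatter_sum is hypothetical and only mimics the data flow: Scatter hands each rank an equal slice of data_seq, each rank sums its slice, and allreduce with MPI.SUM combines the partial sums so every rank ends up with the total.

```python
def simulate_scatter_sum(data, nb_proc):
    """Mimic comm.Scatter followed by comm.allreduce(op=MPI.SUM), serially."""
    n = len(data)
    if n % nb_proc != 0:
        # Scatter needs equal-sized chunks, hence the same check as mod.pyx.
        raise ValueError("The number of processes is incorrect.")
    chunk = n // nb_proc
    # Scatter: rank r receives data[r*chunk:(r+1)*chunk].
    local_parts = [data[r * chunk:(r + 1) * chunk] for r in range(nb_proc)]
    # Each rank sums its own slice...
    partial_sums = [sum(part) for part in local_parts]
    # ...and allreduce with SUM gives every rank the grand total.
    return sum(partial_sums)


print(simulate_scatter_sum([1] * 8, 4))  # 8
```

Whatever the number of processes, the result matches the sequential sum, which is exactly what the assert s == s_seq in mod.pyx checks.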

Build with python setup.py build_ext --inplace, then test with python -c "import mod" (sequential) and mpirun -np 4 python -c "import mod" (4 processes). If mpi4py is not installed, the module can still be built and run sequentially.
