Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions sumpy/expansion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from sumpy.expansion.diff_op import MultiIndex
from sumpy.expansion.local import LocalExpansionBase
from sumpy.expansion.multipole import MultipoleExpansionBase
from sumpy.kernel import Kernel, KernelArgument
from sumpy.kernel import KernelArgument, ScalarKernel


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,7 +99,7 @@ class ExpansionBase(ABC):
.. automethod:: __ne__
"""

kernel: Kernel
kernel: ScalarKernel
order: int
use_rscale: bool = field(kw_only=True, default=True)

Expand Down Expand Up @@ -152,7 +152,7 @@ def get_coefficient_identifiers(self) -> Sequence[MultiIndex]:

@abstractmethod
def coefficients_from_source(self,
kernel: Kernel,
kernel: ScalarKernel,
avec: sym.Matrix,
bvec: sym.Matrix | None,
rscale: sym.Expr,
Expand All @@ -171,7 +171,7 @@ def coefficients_from_source(self,
"""

def coefficients_from_source_vec(self,
kernels: Sequence[Kernel],
kernels: Sequence[ScalarKernel],
avec: sym.Matrix,
bvec: sym.Matrix | None,
rscale: sym.Expr,
Expand All @@ -198,7 +198,7 @@ def coefficients_from_source_vec(self,
return result

def loopy_expansion_formation(self,
kernels: Sequence[Kernel],
kernels: Sequence[ScalarKernel],
strength_usage: Sequence[int],
nstrengths: int
) -> lp.TranslationUnit:
Expand All @@ -212,7 +212,7 @@ def loopy_expansion_formation(self,

@abstractmethod
def evaluate(self,
kernel: Kernel,
kernel: ScalarKernel,
coeffs: Sequence[sym.Expr],
bvec: sym.Matrix,
rscale: sym.Expr,
Expand All @@ -224,7 +224,7 @@ def evaluate(self,
in *coeffs*.
"""

def loopy_evaluator(self, kernels: Sequence[Kernel]) -> lp.TranslationUnit:
def loopy_evaluator(self, kernels: Sequence[ScalarKernel]) -> lp.TranslationUnit:
"""
:returns: a :mod:`loopy` kernel that returns the evaluated
target transforms of the potential given by *kernels*.
Expand All @@ -236,7 +236,7 @@ def loopy_evaluator(self, kernels: Sequence[Kernel]) -> lp.TranslationUnit:

# {{{ copy

def with_kernel(self, kernel: Kernel) -> ExpansionBase:
def with_kernel(self, kernel: ScalarKernel) -> ExpansionBase:
return replace(self, kernel=kernel)

def copy(self, **kwargs: Any) -> Self:
Expand Down Expand Up @@ -566,7 +566,7 @@ class LinearPDEBasedExpansionTermsWrangler(ExpansionTermsWrangler):
.. automethod:: __init__
"""

knl: Kernel
knl: ScalarKernel

@override
def get_coefficient_identifiers(self) -> Sequence[MultiIndex]:
Expand Down Expand Up @@ -973,7 +973,7 @@ class MultipoleExpansionFactory(Protocol):
.. automethod:: __call__
"""
def __call__(self,
kernel: Kernel,
kernel: ScalarKernel,
order: int,
*, use_rscale: bool = True
) -> MultipoleExpansionBase:
Expand All @@ -987,7 +987,7 @@ class LocalExpansionFactory(Protocol):
.. automethod:: __call__
"""
def __call__(self,
kernel: Kernel,
kernel: ScalarKernel,
order: int,
*, use_rscale: bool = True
) -> LocalExpansionBase:
Expand All @@ -1002,15 +1002,15 @@ class ExpansionFactoryBase(ABC):

@abstractmethod
def get_local_expansion_class(self,
base_kernel: Kernel, /
base_kernel: ScalarKernel, /
) -> LocalExpansionFactory:
"""
:returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
"""

@abstractmethod
def get_multipole_expansion_class(self,
base_kernel: Kernel, /
base_kernel: ScalarKernel, /
) -> MultipoleExpansionFactory:
"""
:returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
Expand All @@ -1024,7 +1024,7 @@ class VolumeTaylorExpansionFactory(ExpansionFactoryBase):

@override
def get_local_expansion_class(
self, base_kernel: Kernel, /
self, base_kernel: ScalarKernel, /
) -> type[LocalExpansionBase]:
"""
:returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
Expand All @@ -1034,7 +1034,7 @@ def get_local_expansion_class(

@override
def get_multipole_expansion_class(
self, base_kernel: Kernel, /
self, base_kernel: ScalarKernel, /
) -> type[MultipoleExpansionBase]:
"""
:returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
Expand All @@ -1050,7 +1050,7 @@ class DefaultExpansionFactory(ExpansionFactoryBase):

@override
def get_local_expansion_class(self,
base_kernel: Kernel, /,
base_kernel: ScalarKernel, /,
) -> LocalExpansionFactory:
"""
:returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
Expand All @@ -1067,7 +1067,7 @@ def get_local_expansion_class(self,

@override
def get_multipole_expansion_class(self,
base_kernel: Kernel, /,
base_kernel: ScalarKernel, /,
) -> MultipoleExpansionFactory:
"""
:returns: a subclass of :class:`ExpansionBase` suitable for *base_kernel*.
Expand Down
6 changes: 3 additions & 3 deletions sumpy/expansion/level_to_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from collections.abc import Sequence

import sumpy.symbolic as sym
from sumpy.kernel import Kernel
from sumpy.kernel import ScalarKernel


class TreeLike(Protocol):
Expand Down Expand Up @@ -71,7 +71,7 @@ class FMMLibExpansionOrderFinder:
"""

def __call__(self,
kernel: Kernel,
kernel: ScalarKernel,
kernel_args: dict[str, sym.Expr] | Sequence[tuple[str, sym.Expr]],
tree: TreeLike,
level: int) -> int:
Expand Down Expand Up @@ -163,7 +163,7 @@ class SimpleExpansionOrderFinder:
"""

def __call__(self,
kernel: Kernel,
kernel: ScalarKernel,
kernel_args: dict[str, sym.Expr] | Sequence[tuple[str, sym.Expr]],
tree: TreeLike,
level: int) -> int:
Expand Down
16 changes: 8 additions & 8 deletions sumpy/expansion/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
HankelBased2DMultipoleExpansion,
MultipoleExpansionBase,
)
from sumpy.kernel import Kernel
from sumpy.kernel import ScalarKernel


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -131,7 +131,7 @@ def get_coefficient_identifiers(self) -> Sequence[MultiIndex]:

@override
def coefficients_from_source(self,
kernel: Kernel,
kernel: ScalarKernel,
avec: sym.Matrix,
bvec: sym.Matrix | None,
rscale: sym.Expr,
Expand Down Expand Up @@ -171,7 +171,7 @@ def coefficients_from_source(self,

@override
def evaluate(self,
kernel: Kernel,
kernel: ScalarKernel,
coeffs: Sequence[sym.Expr],
bvec: sym.Matrix,
rscale: sym.Expr,
Expand Down Expand Up @@ -224,7 +224,7 @@ def m2l_translation(self) -> M2LTranslationBase:

@override
def coefficients_from_source_vec(self,
kernels: Sequence[Kernel],
kernels: Sequence[ScalarKernel],
avec: sym.Matrix,
bvec: sym.Matrix | None,
rscale: sym.Expr,
Expand Down Expand Up @@ -272,7 +272,7 @@ def save_temp(x: sym.Expr) -> sym.Expr:

@override
def coefficients_from_source(self,
kernel: Kernel,
kernel: ScalarKernel,
avec: sym.Matrix,
bvec: sym.Matrix | None,
rscale: sym.Expr,
Expand All @@ -283,7 +283,7 @@ def coefficients_from_source(self,

@override
def evaluate(self,
kernel: Kernel,
kernel: ScalarKernel,
coeffs: Sequence[sym.Expr],
bvec: sym.Matrix,
rscale: sym.Expr,
Expand Down Expand Up @@ -556,7 +556,7 @@ def get_coefficient_identifiers(self) -> Sequence[MultiIndex]:

@override
def coefficients_from_source(self,
kernel: Kernel,
kernel: ScalarKernel,
avec: sym.Matrix,
bvec: sym.Matrix | None,
rscale: sym.Expr,
Expand All @@ -580,7 +580,7 @@ def coefficients_from_source(self,

@override
def evaluate(self,
kernel: Kernel,
kernel: ScalarKernel,
coeffs: Sequence[sym.Expr],
bvec: sym.Matrix,
rscale: sym.Expr,
Expand Down
8 changes: 5 additions & 3 deletions sumpy/expansion/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,16 @@
from pymbolic.typing import ArithmeticExpression

from sumpy.expansion import ExpansionBase
from sumpy.kernel import Kernel
from sumpy.kernel import ScalarKernel


logger = logging.getLogger(__name__)


def make_e2p_loopy_kernel(
expansion: ExpansionBase, kernels: Sequence[Kernel]) -> lp.TranslationUnit:
expansion: ExpansionBase,
kernels: Sequence[ScalarKernel],
) -> lp.TranslationUnit:
"""
A helper function that creates a :mod:`loopy` kernel for multipole/local evaluation.

Expand Down Expand Up @@ -152,7 +154,7 @@ def make_e2p_loopy_kernel(

def make_p2e_loopy_kernel(
expansion: ExpansionBase,
kernels: Sequence[Kernel],
kernels: Sequence[ScalarKernel],
strength_usage: Sequence[int],
nstrengths: int) -> lp.TranslationUnit:
"""
Expand Down
10 changes: 5 additions & 5 deletions sumpy/expansion/m2l.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@

from sumpy.assignment_collection import SymbolicAssignmentCollection
from sumpy.expansion.diff_op import MultiIndex
from sumpy.kernel import Kernel
from sumpy.kernel import ScalarKernel


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,7 +88,7 @@ class M2LTranslationClassFactoryBase(ABC):
@abstractmethod
def get_m2l_translation_class(
self,
base_kernel: Kernel,
base_kernel: ScalarKernel,
local_expansion_class: type[LocalExpansionBase]
) -> type[M2LTranslationBase]:
"""
Expand All @@ -105,7 +105,7 @@ class NonFFTM2LTranslationClassFactory(M2LTranslationClassFactoryBase):
@override
def get_m2l_translation_class(
self,
base_kernel: Kernel,
base_kernel: ScalarKernel,
local_expansion_class: type[LocalExpansionBase]
) -> type[M2LTranslationBase]:
from sumpy.expansion.local import (
Expand All @@ -129,7 +129,7 @@ class FFTM2LTranslationClassFactory(M2LTranslationClassFactoryBase):
@override
def get_m2l_translation_class(
self,
base_kernel: Kernel,
base_kernel: ScalarKernel,
local_expansion_class: type[LocalExpansionBase]
) -> type[M2LTranslationBase]:
from sumpy.expansion.local import (
Expand All @@ -153,7 +153,7 @@ class DefaultM2LTranslationClassFactory(M2LTranslationClassFactoryBase):
@override
def get_m2l_translation_class(
self,
base_kernel: Kernel,
base_kernel: ScalarKernel,
local_expansion_class: type[LocalExpansionBase]
) -> type[M2LTranslationBase]:
from sumpy.expansion.local import (
Expand Down
12 changes: 6 additions & 6 deletions sumpy/expansion/multipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

from sumpy.assignment_collection import SymbolicAssignmentCollection
from sumpy.expansion.diff_op import MultiIndex
from sumpy.kernel import Kernel
from sumpy.kernel import ScalarKernel


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,7 +76,7 @@ class VolumeTaylorMultipoleExpansionBase(VolumeTaylorExpansionMixin,

@override
def coefficients_from_source_vec(self,
kernels: Sequence[Kernel],
kernels: Sequence[ScalarKernel],
avec: sym.Matrix,
bvec: sym.Matrix | None,
rscale: sym.Expr,
Expand Down Expand Up @@ -116,7 +116,7 @@ def coefficients_from_source_vec(self,
@override
def coefficients_from_source(
self,
kernel: Kernel,
kernel: ScalarKernel,
avec: sym.Matrix,
bvec: sym.Matrix | None,
rscale: sym.Expr,
Expand All @@ -130,7 +130,7 @@ def coefficients_from_source(

@override
def evaluate(self,
kernel: Kernel,
kernel: ScalarKernel,
coeffs: Sequence[sym.Expr],
bvec: sym.Matrix,
rscale: sym.Expr,
Expand Down Expand Up @@ -440,7 +440,7 @@ def get_coefficient_identifiers(self) -> Sequence[MultiIndex]:
@override
def coefficients_from_source(
self,
kernel: Kernel,
kernel: ScalarKernel,
avec: sym.Matrix,
bvec: sym.Matrix | None,
rscale: sym.Expr,
Expand Down Expand Up @@ -469,7 +469,7 @@ def coefficients_from_source(

@override
def evaluate(self,
kernel: Kernel,
kernel: ScalarKernel,
coeffs: Sequence[sym.Expr],
bvec: sym.Matrix,
rscale: sym.Expr,
Expand Down
Loading
Loading