manimgeo.math.base 源代码
from ..utils.config import GeoConfig
from typing import Union
from logging import getLogger
import functools
import numpy as np
type Number = Union[int, float]
cfg = GeoConfig()
logger = getLogger(__name__)
[文档]
def close(a: Union[np.ndarray, Number], b: Union[np.ndarray, Number]) -> bool:
"""
判断两个数值是否相近
"""
if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
return np.allclose(a, b, atol=cfg.atol, rtol=cfg.rtol)
elif isinstance(a, (int, float)) and isinstance(b, (int, float)):
if np.isnan(a) or np.isnan(b):
return False # NaN 永远不等于任何值,包括自身
if np.isinf(a) and np.isinf(b):
return (a == b) # 只有符号相同才相等 (inf == inf, -inf == -inf)
if np.isinf(a) or np.isinf(b):
return False # 一个是无穷大,另一个是有限数,则不相等
return abs(a - b) <= cfg.atol + cfg.rtol * abs(b)
else:
raise TypeError("不允许比较类型不同的两个数据是否一致: {} and {}".format(type(a), type(b)))
[文档]
def array2float(func):
"""
将参数中所有 np.ndarray 类型的参数自动转换为 float64
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
processed_args = []
# 处理位置参数
for arg in args:
if isinstance(arg, np.ndarray) and not np.issubdtype(arg.dtype, np.floating):
processed_args.append(arg.astype(np.float64))
if len(arg) <= 2:
logger.warning(f"参数 {arg} 维度少于 3,可能引发计算错误")
else:
processed_args.append(arg)
processed_kwargs = {}
# 处理关键字参数
for k, v in kwargs.items():
if isinstance(v, np.ndarray) and not np.issubdtype(v.dtype, np.floating):
processed_kwargs[k] = v.astype(np.float64)
if len(v) <= 2:
logger.warning(f"参数 {k}: {v} 维度少于 3,可能引发计算错误")
else:
processed_kwargs[k] = v
return func(*processed_args, **processed_kwargs)
return wrapper