Limnoria/plugins/Math/evaluator.py

170 lines
5.5 KiB
Python

###
# Copyright (c) 2019, Valentin Lorentz
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions, and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions, and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the author of this software nor the name of
# contributors to this software may be used to endorse or promote products
# derived from this software without specific prior written consent.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
###
import ast
import math
import cmath
import operator
class InvalidNode(Exception):
    """Raised when an expression contains a construct that is not allowed.

    The evaluator in this module raises it for operators and AST node
    types it does not explicitly support.
    """
def filter_module(module, safe_names):
    """Return a dict mapping each of safe_names to its attribute on module.

    Names that the module does not define are silently skipped, so the
    same name list can be used with versions of the module that lack some
    of the attributes.
    """
    exported = {}
    for attr_name in safe_names:
        if hasattr(module, attr_name):
            exported[attr_name] = getattr(module, attr_name)
    return exported
# Unary operators the evaluator accepts, keyed by their AST operator class.
# Unary plus hands its operand back untouched; unary minus negates it.
UNARY_OPS = {
    ast.UAdd: (lambda operand: operand),
    ast.USub: operator.neg,
}
# Binary operators the evaluator accepts, keyed by their AST operator class.
# Each entry maps to the matching function from the operator module; any
# operator that is not listed here is rejected by the visitor.
BIN_OPS = dict((
    # Arithmetic
    (ast.Add, operator.add),
    (ast.Sub, operator.sub),
    (ast.Mult, operator.mul),
    (ast.Div, operator.truediv),
    (ast.Pow, operator.pow),
    # Bitwise
    (ast.BitXor, operator.xor),
    (ast.BitOr, operator.or_),
    (ast.BitAnd, operator.and_),
))
# Names of the constants taken from the math module.
MATH_CONSTANTS = ['e', 'inf', 'nan', 'pi', 'tau']

# Names of the math functions exposed to expressions.
SAFE_MATH_FUNCTIONS = [
    'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'copysign',
    'cos', 'cosh', 'degrees', 'erf', 'erfc', 'exp', 'expm1', 'fabs', 'fmod',
    'frexp', 'fsum', 'gamma', 'hypot', 'ldexp', 'lgamma', 'log', 'log10',
    'log1p', 'log2', 'modf', 'pow', 'radians', 'remainder', 'sin', 'sinh',
    'tan', 'tanh',
]

# Names taken from cmath.  They are applied after the math names, so for
# any name present in both lists the cmath attribute is the one kept.
SAFE_CMATH_FUNCTIONS = [
    'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'cos', 'cosh', 'exp',
    'inf', 'infj', 'log', 'log10', 'nanj', 'phase', 'polar', 'rect', 'sin',
    'sinh', 'tan', 'tanh', 'tau',
]

# filter_module() skips names the running Python's modules do not provide.
SAFE_ENV = filter_module(math, MATH_CONSTANTS + SAFE_MATH_FUNCTIONS)
SAFE_ENV.update(filter_module(cmath, SAFE_CMATH_FUNCTIONS))
def _sqrt(x):
    """Return the square root of x, switching to cmath when needed.

    Complex arguments and negative reals are passed to cmath.sqrt (giving
    a complex result); everything else goes through math.sqrt.
    """
    needs_complex = isinstance(x, complex) or x < 0
    root = cmath.sqrt if needs_complex else math.sqrt
    return root(x)
def _cbrt(x):
    """Return the real cube root of the real number x.

    math.pow() raises ValueError for a negative base with a fractional
    exponent, so negative inputs are handled by taking the root of the
    magnitude and restoring the sign.  Non-negative inputs go through the
    same math.pow() call as before.

    Raises TypeError for complex arguments (they cannot be ordered).
    """
    if x < 0:
        return -math.pow(-x, 1.0/3)
    return math.pow(x, 1.0/3)
def _factorial(x):
    """Return x! as a float, refusing arguments above 10000.

    Float arguments with an integral value are converted to int first,
    because math.factorial() does not accept floats on recent Python
    versions.  A non-integral float raises ValueError.

    Raises Exception('factorial argument too large') when x > 10000 (or
    when the comparison is false, e.g. for NaN); errors from
    math.factorial() and from the float conversion propagate unchanged.
    """
    if x <= 10000:
        if isinstance(x, float):
            if not x.is_integer():
                raise ValueError('factorial() only accepts integral values')
            x = int(x)
        return float(math.factorial(x))
    else:
        raise Exception('factorial argument too large')
# Extra names layered on top of the math/cmath ones.  ceil, floor and
# factorial are wrapped so that they return floats, and round() coerces
# its optional second argument to int before delegating to the builtin.
SAFE_ENV.update(
    i=1j,
    abs=abs,
    max=max,
    min=min,
    round=lambda x, y=0: round(x, int(y)),
    factorial=_factorial,
    sqrt=_sqrt,
    cbrt=_cbrt,
    ceil=lambda x: float(math.ceil(x)),
    floor=lambda x: float(math.floor(x)),
)
# The integer-allowing environment starts as a shallow copy of the safe one;
# the float-returning wrappers are then replaced by the plain math functions,
# which return integers (gcd is only added when math provides it).
UNSAFE_ENV = dict(SAFE_ENV)
UNSAFE_ENV.update(filter_module(math, ['ceil', 'floor', 'factorial', 'gcd']))
# It would be nice if ast.literal_eval used a visitor so we could subclass
# to extend it, but it doesn't, so let's reimplement it entirely.
class SafeEvalVisitor(ast.NodeVisitor):
    """AST visitor that evaluates a restricted arithmetic expression.

    Only numeric literals, names found in the evaluation environment,
    positional function calls, and the operators listed in UNARY_OPS and
    BIN_OPS are evaluated; every other node type raises InvalidNode.
    """

    def __init__(self, allow_ints):
        """Select the evaluation mode.

        If allow_ints is true, literals keep the type they were parsed
        with and names are resolved in UNSAFE_ENV; otherwise literals are
        coerced to float/complex and names are resolved in SAFE_ENV.
        """
        self._allow_ints = allow_ints
        self._env = UNSAFE_ENV if allow_ints else SAFE_ENV

    def _convert_num(self, x):
        """Converts numbers to complex if ints are not allowed."""
        if self._allow_ints:
            return x
        else:
            x = complex(x)
            if x.imag == 0:
                x = x.real
                # Need to use string-formatting here instead of str() because
                # use of str() on large numbers loses information:
                # str(float(33333333333333)) => '3.33333333333e+13'
                # float('3.33333333333e+13') => 33333333333300.0
                return float('%.16f' % x)
            else:
                return x

    def visit_Expression(self, node):
        """Evaluate the body of a tree parsed with mode='eval'."""
        return self.visit(node.body)

    def visit_Num(self, node):
        """Evaluate a numeric literal on parsers that still emit ast.Num."""
        return self._convert_num(node.n)

    def visit_Constant(self, node):
        """Evaluate a literal on parsers that emit ast.Constant.

        Only int, float and complex values are accepted.  The exact type
        is checked so that bool (a subclass of int) is refused, together
        with strings, bytes, None and Ellipsis; those go to generic_visit,
        which raises InvalidNode.
        """
        value = node.value
        if type(value) in (int, float, complex):
            return self._convert_num(value)
        return self.generic_visit(node)

    def visit_Name(self, node):
        """Look a name up in the environment, ignoring case.

        Raises NameError (carrying the name as written) when it is unknown.
        """
        id_ = node.id.lower()
        if id_ in self._env:
            return self._env[id_]
        else:
            raise NameError(node.id)

    def visit_Call(self, node):
        """Evaluate the callee, then the positional arguments, and call it."""
        func = self.visit(node.func)
        args = [self.visit(arg) for arg in node.args]
        # TODO: keywords?  (keyword arguments are currently ignored)
        return func(*args)

    def visit_UnaryOp(self, node):
        """Apply a unary operator from UNARY_OPS, else raise InvalidNode."""
        op = UNARY_OPS.get(node.op.__class__)
        if op:
            return op(self.visit(node.operand))
        else:
            raise InvalidNode('illegal operator %s' % node.op.__class__.__name__)

    def visit_BinOp(self, node):
        """Apply a binary operator from BIN_OPS, else raise InvalidNode.

        The left operand is evaluated before the right one.
        """
        op = BIN_OPS.get(node.op.__class__)
        if op:
            return op(self.visit(node.left), self.visit(node.right))
        else:
            raise InvalidNode('illegal operator %s' % node.op.__class__.__name__)

    def generic_visit(self, node):
        """Refuse every node type that has no dedicated visit_* method."""
        raise InvalidNode('illegal construct %s' % node.__class__.__name__)
def safe_eval(text, allow_ints):
    """Parse text as a single expression and evaluate it.

    The parsed tree is walked by a SafeEvalVisitor created with the given
    allow_ints flag, and the visitor's result is returned.  Exceptions
    from ast.parse() and from the visitor propagate to the caller.
    """
    tree = ast.parse(text, mode='eval')
    evaluator = SafeEvalVisitor(allow_ints)
    return evaluator.visit(tree)