r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

"""A Module to safely parse/evaluate Mathematical Expressions.

Expressions are parsed with :mod:`ast` and evaluated node by node against a
whitelist of operators and functions; every intermediate value is bounded
by ``max_value``.
"""
import ast
import operator as op
import math

from numpy import int64

# Upper bound on the magnitude of any (intermediate) value. It prevents
# denial-of-service (DoS) attacks through numbers that are too large to
# compute or store.
max_value = 1e17

# Redefine mathematical operations to prevent DoS (denial-of-service)
# attacks through numbers that are too large to compute or store
def add(a, b):
    """Return ``a + b`` after checking both operands' magnitudes.

    Raises:
        ValueError: with ``(a, b)`` if either operand exceeds ``max_value``.
    """
    if any(abs(n) > max_value for n in [a, b]):
        raise ValueError((a, b))
    return op.add(a, b)

def sub(a, b):
    """Return ``a - b``, but only for operands within ``max_value``.

    Raises:
        ValueError: with ``(a, b)`` as soon as an operand (checked in the
            order ``a``, then ``b``) has a magnitude above ``max_value``.
    """
    for operand in (a, b):
        if abs(operand) > max_value:
            raise ValueError((a, b))
    return op.sub(a, b)

def mul(a, b):
    """Return ``a * b``, refusing products whose magnitude exceeds ``max_value``.

    The size test adds base-10 logarithms of the factors, so an oversized
    product is rejected before it is ever computed. A zero factor skips the
    test, since ``log10(0)`` is undefined and the product is 0 anyway.

    Raises:
        ValueError: with ``(a, b)`` when ``|a| * |b|`` would exceed
            ``max_value``.
    """
    has_zero_factor = a == 0.0 or b == 0.0
    if not has_zero_factor:
        product_log = math.log10(abs(a)) + math.log10(abs(b))
        if product_log > math.log10(max_value):
            raise ValueError((a, b))
    return op.mul(a, b)

def div(a, b):
    """Return the true quotient ``a / b``, guarding against huge results.

    A zero divisor is rejected outright. For a non-zero dividend, the size
    of the quotient is estimated from base-10 logarithms before dividing.

    Raises:
        ValueError: with ``(a, b)`` when ``b`` is zero or when ``|a| / |b|``
            would exceed ``max_value``.
    """
    if b == 0.0:
        raise ValueError((a, b))
    if a != 0.0:
        quotient_log = math.log10(abs(a)) - math.log10(abs(b))
        if quotient_log > math.log10(max_value):
            raise ValueError((a, b))
    return op.truediv(a, b)

def power(a, b):
    """Return ``a ** b``, refusing results whose magnitude would reach max_value.

    The size test compares ``b`` with ``log_|a|(max_value)``, so an oversized
    power is rejected before it is computed.

    Bases ``0`` and ``|a| == 1`` are handled first. The logarithm in the size
    test is undefined for them (``log(x, 0)`` is a domain error and
    ``log(x, 1)`` divides by zero), and their powers are bounded anyway.

    Raises:
        ValueError: with ``(a, b)`` when the result would be too large, or
            when ``0`` is raised to a negative power. The latter mirrors
            ``div``, which rejects division by zero the same way.
    """
    if a == 0.0:
        # 0 ** 0 == 1 and 0 ** b == 0 for b > 0; a negative power is a
        # division by zero.
        if b.real < 0:
            raise ValueError((a, b))
        return op.pow(a, b)
    if abs(a) == 1:
        # |a ** b| == 1 for real b; nothing to bound here
        return op.pow(a, b)
    if b / math.log(max_value, abs(a)) >= 1:
        raise ValueError((a, b))
    return op.pow(a, b)

def exp(a):
    """Return ``e ** a`` unless the result would exceed ``max_value``.

    Raises:
        ValueError: with ``a`` when ``a`` is above ``ln(max_value)``.
    """
    largest_exponent = math.log(max_value)
    if a > largest_exponent:
        raise ValueError(a)
    return math.exp(a)

def mod(a, b):
    """Return ``a % b``, rejecting a zero divisor like ``div`` does.

    Integer results of ``_eval`` are converted to ``numpy.int64``. For that
    type, ``% 0`` only emits a RuntimeWarning and returns 0 instead of
    raising, so the zero check has to be explicit.

    Raises:
        ValueError: with ``(a, b)`` when ``b`` is zero.
    """
    if b == 0:
        raise ValueError((a, b))
    return op.mod(a, b)


def floordiv(a, b):
    """Return ``a // b``, rejecting a zero divisor like ``div`` does.

    As with ``mod``, ``numpy.int64`` operands would otherwise give 0 (with
    only a warning) for a zero divisor.

    Raises:
        ValueError: with ``(a, b)`` when ``b`` is zero.
    """
    if b == 0:
        raise ValueError((a, b))
    return op.floordiv(a, b)


# The allowed AST operator node types, mapped to the functions that evaluate
# them. Any operator missing here raises KeyError in _eval.
operators = {
    ast.Add: add,
    ast.Sub: sub,
    ast.Mult: mul,
    ast.Div: div,
    ast.Pow: power,
    ast.USub: op.neg,
    ast.UAdd: op.pos,
    ast.Mod: mod,
    ast.FloorDiv: floordiv,
}

# Whitelist of function names accepted in expression calls. This is a
# curated subset of the math module plus the built-in round and this
# module's size-guarded exp; it is not every function in math.
allowed_math_fxn = {
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "asin": math.asin,
    "acos": math.acos,
    "atan": math.atan,
    "atan2": math.atan2,
    "hypot": math.hypot,
    "sinh": math.sinh,
    "cosh": math.cosh,
    "tanh": math.tanh,
    "asinh": math.asinh,
    "acosh": math.acosh,
    "atanh": math.atanh,
    "radians": math.radians,
    "degrees": math.degrees,
    "sqrt": math.sqrt,
    "log": math.log,
    "log10": math.log10,
    "log2": math.log2,
    "fmod": math.fmod,
    # math.fabs always returns a float and rejects complex input
    "abs": math.fabs,
    "ceil": math.ceil,
    "floor": math.floor,
    "round": round,
    "exp": exp,
}

def get_function(node):
    """Return the name of the function invoked by an ``ast.Call`` node.

    A bare call such as ``sin(x)`` gives ``"sin"``. For an attribute call
    such as ``math.sin(x)`` only the final attribute name (``"sin"``) is
    returned; the object before the dot is ignored.

    Raises:
        TypeError: when the callee is neither a name nor an attribute.
    """
    callee = node.func
    if isinstance(callee, ast.Name):
        return callee.id
    if isinstance(callee, ast.Attribute):
        return callee.attr
    raise TypeError("node.func is of the wrong type")

def limit(max_=None):
    """Return a decorator that bounds the magnitude of a function's result.

    After the decorated function returns, ``abs(result)`` is compared with
    ``max_``, and a ``ValueError`` carrying the result is raised if it is
    larger. Results that do not support ``abs()`` are returned unchecked.
    Plain ``int`` results are converted to ``numpy.int64``.

    Args:
        max_: Largest allowed magnitude. ``None`` (the default) disables the
            check; comparing against ``None`` would otherwise make every
            numeric call fail with ``TypeError``.

    Returns:
        A decorator whose wrapper preserves the wrapped function's metadata
        via ``functools.wraps``.

    Raises (from the wrapper):
        ValueError: with the result, when its magnitude exceeds ``max_``.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ret = func(*args, **kwargs)
            if max_ is not None:
                try:
                    mag = abs(ret)
                except TypeError:
                    pass  # abs() not applicable (e.g. a string): no check
                else:
                    if mag > max_:
                        raise ValueError(ret)
            # NOTE(review): presumably converted so later integer arithmetic
            # stays fixed-width; numpy raises OverflowError for ints outside
            # the int64 range -- confirm the intent.
            if isinstance(ret, int):
                ret = int64(ret)
            return ret

        return wrapper

    return decorator

@limit(max_=max_value)
def _eval(node):
    """Recursively evaluate an expression node produced by ``ast.parse``.

    Only a whitelist of node types is accepted:

    * numeric constants (int, float, complex; not bool);
    * binary and unary operators listed in ``operators``;
    * calls to functions listed in ``allowed_math_fxn``, with positional
      and keyword arguments;
    * the names ``pi``, ``e`` and ``tau`` (case-insensitive).

    Every (intermediate) result passes through the ``limit`` decorator,
    which rejects magnitudes above ``max_value``.

    Raises:
        TypeError: for unsupported node types, non-numeric constants,
            unknown names and ``**mapping`` unpacking in calls.
        KeyError: for operators or functions outside the whitelists.
        ValueError: when a value is too large.
    """
    # <number>. ast.Num / node.n are deprecated since Python 3.8 and removed
    # in 3.14; this check accepts exactly what ast.Num used to match.
    if isinstance(node, ast.Constant):
        value = node.value
        if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
            return value
        raise TypeError(node)
    if isinstance(node, ast.BinOp):  # <left> <operator> <right>
        return operators[type(node.op)](_eval(node.left), _eval(node.right))
    if isinstance(node, ast.UnaryOp):  # <operator> <operand> e.g., -1
        return operators[type(node.op)](_eval(node.operand))
    if isinstance(node, ast.Call):  # func(...) or module.func(...)
        func = get_function(node)
        evaled_args = [_eval(arg) for arg in node.args]
        # Keyword arguments are evaluated and passed on; they used to be
        # dropped silently, giving wrong results (e.g. round(x, ndigits=2)).
        evaled_kwargs = {}
        for keyword in node.keywords:
            if keyword.arg is None:  # ``**mapping`` unpacking
                raise TypeError(node)
            evaled_kwargs[keyword.arg] = _eval(keyword.value)
        return allowed_math_fxn[func](*evaled_args, **evaled_kwargs)
    if isinstance(node, ast.Name):
        name = node.id.lower()
        if name == "pi":
            return math.pi
        if name == "e":
            return math.e
        if name == "tau":
            return math.pi * 2.0
        raise TypeError(
            "Found a str in the expression, either param_dct/the "
            "expression has a mistake in the parameter names or "
            "attempting to parse non-mathematical code")
    raise TypeError(node)

def eval_expression(expression, param_dct=None):
    """Parse a mathematical expression and evaluate it safely.

    Every name in ``param_dct`` is replaced by its value before parsing.

    * Only whole names are substituted. A key ``"a"`` leaves ``atan`` and
      ``a_1`` alone, and a key ``"e"`` leaves the exponent in ``1e5`` alone.
    * Each value is inserted in parentheses so operator precedence is kept:
      ``x**2`` with ``x = -2`` gives 4, not ``-2**2 == -4``.
    * Substitution follows the dict's iteration order, so a value may itself
      contain names that later keys replace.

    Args:
        expression (str): The expression, at most 10 000 characters.
        param_dct (dict, optional): Mapping of parameter names to values.
            ``None`` (the default) means no substitution.

    Returns:
        The numeric result computed by ``_eval``.

    Raises:
        TypeError: If ``expression`` is not a string, or the expression
            contains an unsupported construct.
        ValueError: If the expression is too long, contains ``"()"``, or a
            value grows too large.
        KeyError: For operators or functions outside the whitelist.
        SyntaxError: If the substituted expression is not valid Python.
    """
    import re

    if param_dct is None:
        param_dct = {}
    if not isinstance(expression, str):
        raise TypeError("The expression must be a string")
    if len(expression) > 1e4:
        raise ValueError("The expression is too long.")

    expression_rep = expression.strip()

    # Rejects empty calls such as "foo()" (and empty tuples)
    if "()" in expression_rep:
        raise ValueError("Invalid operation in expression")

    for key, val in param_dct.items():
        replacement = "(" + str(val) + ")"
        # Match the key only as a whole name: not inside a longer
        # identifier, not after a dot, not inside a numeric literal.
        pattern = r"(?<![\w.])" + re.escape(key) + r"(?!\w)"
        # A callable replacement inserts the text verbatim, so backslashes
        # in the value are not treated as group references.
        expression_rep = re.sub(
            pattern, lambda _match, text=replacement: text, expression_rep)

    return _eval(ast.parse(expression_rep, mode="eval").body)