Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

from __future__ import print_function, division 

 

from sympy.core.compatibility import reduce 

from operator import add 

 

from sympy.core import Add, Basic, sympify 

from sympy.functions import adjoint 

from sympy.matrices.matrices import MatrixBase 

from sympy.matrices.expressions.transpose import transpose 

from sympy.strategies import (rm_id, unpack, flatten, sort, condition, 

        exhaust, do_one, glom) 

from sympy.matrices.expressions.matexpr import MatrixExpr, ShapeError, ZeroMatrix 

from sympy.utilities import default_sort_key, sift 

 

 

class MatAdd(MatrixExpr):
    """A sum of matrix expressions.

    MatAdd inherits from and operates like SymPy Add. The summands are
    stored in ``args``; unless ``check=False`` is passed to the
    constructor, they are checked with ``validate``.

    >>> from sympy import MatAdd, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 5)
    >>> B = MatrixSymbol('B', 5, 5)
    >>> C = MatrixSymbol('C', 5, 5)
    >>> MatAdd(A, B, C)
    A + B + C
    """
    is_MatAdd = True

    def __new__(cls, *args, **kwargs):
        # Every operand is converted to a SymPy object before the node is built.
        summands = [sympify(term) for term in args]
        obj = Basic.__new__(cls, *summands)
        # Passing check=False skips the matrix/shape validation.
        if kwargs.get('check', True):
            validate(*summands)
        return obj

    @property
    def shape(self):
        """Shape of the sum, read from the first summand."""
        leading = self.args[0]
        return leading.shape

    def _entry(self, i, j):
        # Entry (i, j) of the sum is the scalar sum of every summand's (i, j) entry.
        return Add(*(term._entry(i, j) for term in self.args))

    def _eval_transpose(self):
        # Transpose distributes over the sum: (A + B)^T = A^T + B^T.
        return MatAdd(*map(transpose, self.args)).doit()

    def _eval_adjoint(self):
        # Adjoint distributes over the sum as well.
        return MatAdd(*map(adjoint, self.args)).doit()

    def _eval_trace(self):
        # Trace is linear: tr(A + B) = tr(A) + tr(B).
        from .trace import trace
        return Add(*map(trace, self.args)).doit()

    def doit(self, **kwargs):
        """Evaluate the summands (skipped when ``deep=False``) and canonicalize the sum."""
        operands = ([term.doit(**kwargs) for term in self.args]
                    if kwargs.get('deep', True) else self.args)
        return canonicalize(MatAdd(*operands))

 

def validate(*args):
    """Check that the summands of a MatAdd are matrices of one shape.

    Raises TypeError if any argument is not a matrix (``is_Matrix`` false).
    Raises ShapeError naming the first argument and the first argument
    whose shape differs from it.
    """
    if any(not arg.is_Matrix for arg in args):
        raise TypeError("Mix of Matrix and Scalar symbols")

    head = args[0]
    for other in args[1:]:
        if head.shape != other.shape:
            raise ShapeError("Matrices %s and %s are not aligned" % (head, other))

 

def factor_of(arg):
    """Return the coefficient part of a term.

    This is element ``[0]`` of ``arg.as_coeff_mmul()``. The ``glom`` rule
    uses it as the count when collecting like terms.
    """
    return arg.as_coeff_mmul()[0]


def matrix_of(arg):
    """Return the matrix part of a term.

    This is element ``[1]`` of ``arg.as_coeff_mmul()``, passed through
    ``unpack``. The ``glom`` rule uses it as the grouping key for like terms.
    """
    return unpack(arg.as_coeff_mmul()[1])


def combine(cnt, mat):
    """Rebuild a term from its count and matrix part.

    A count equal to 1 returns ``mat`` untouched. Any other count returns
    the product ``cnt * mat``.
    """
    return mat if cnt == 1 else cnt * mat

 

 

def merge_explicit(matadd):
    """Add together all explicit ``MatrixBase`` summands of ``matadd``.

    If fewer than two summands are explicit matrices, ``matadd`` is
    returned unchanged. Otherwise a new MatAdd is returned. It holds the
    non-explicit summands in their original order, followed by one matrix
    equal to the sum of the explicit ones.

    >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint
    >>> from sympy.matrices.expressions.matadd import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = eye(2)
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatAdd(A, B, C)
    >>> pprint(X)
    A + [1  0] + [1  2]
        [    ]   [    ]
        [0  1]   [3  4]
    >>> pprint(merge_explicit(X))
    A + [2  2]
        [    ]
        [3  5]
    """
    split = sift(matadd.args, lambda arg: isinstance(arg, MatrixBase))
    explicit = split[True]
    if len(explicit) <= 1:
        # Nothing to merge.
        return matadd
    symbolic = split[False]
    merged = reduce(add, explicit)
    return MatAdd(*(symbolic + [merged]))

 

 

def _is_zero_summand(x):
    """True for summands that add nothing: a literal 0 or a ZeroMatrix."""
    return x == 0 or isinstance(x, ZeroMatrix)


def _is_matadd(x):
    """Guard used by ``canonicalize``: only MatAdd nodes are rewritten."""
    return isinstance(x, MatAdd)


# Simplification rules for MatAdd, passed to ``do_one`` in this order.
rules = (
    rm_id(_is_zero_summand),                # drop zero summands
    unpack,
    flatten,
    glom(matrix_of, factor_of, combine),    # collect like terms, e.g. A + 2*A
    merge_explicit,                         # sum explicit matrices together
    sort(default_sort_key),                 # deterministic argument order
)

# Apply the rules exhaustively, but only to MatAdd expressions.
canonicalize = exhaust(condition(_is_matadd, do_one(*rules)))