Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

from __future__ import print_function, division 

 

from sympy import Basic, Expr, sympify 

from sympy.matrices.matrices import MatrixBase 

from .matexpr import ShapeError 

 

 

class Trace(Expr): 

    """Matrix Trace 

 

    Represents the trace of a matrix expression. 

 

    >>> from sympy import MatrixSymbol, Trace, eye 

    >>> A = MatrixSymbol('A', 3, 3) 

    >>> Trace(A) 

    Trace(A) 

 

    See Also: 

        trace 

    """ 

    is_Trace = True 

 

    def __new__(cls, mat): 

        mat = sympify(mat) 

 

        if not mat.is_Matrix: 

            raise TypeError("input to Trace, %s, is not a matrix" % str(mat)) 

 

        if not mat.is_square: 

            raise ShapeError("Trace of a non-square matrix") 

 

        return Basic.__new__(cls, mat) 

 

    def _eval_transpose(self): 

        return self 

 

    @property 

    def arg(self): 

        return self.args[0] 

 

    def doit(self, **kwargs): 

        if kwargs.get('deep', True): 

            arg = self.arg.doit(**kwargs) 

            try: 

                return arg._eval_trace() 

            except (AttributeError, NotImplementedError): 

                return Trace(arg) 

        else: 

            # _eval_trace would go too deep here 

            if isinstance(self.arg, MatrixBase): 

                return trace(self.arg) 

            else: 

                return Trace(self.arg) 

 

 

    def _eval_rewrite_as_Sum(self): 

        from sympy import Sum, Dummy 

        i = Dummy('i') 

        return Sum(self.arg[i, i], (i, 0, self.arg.rows-1)).doit() 

 

 

def trace(expr): 

    """ Trace of a Matrix.  Sum of the diagonal elements 

 

    >>> from sympy import trace, Symbol, MatrixSymbol, pprint, eye 

    >>> n = Symbol('n') 

    >>> X = MatrixSymbol('X', n, n)  # A square matrix 

    >>> trace(2*X) 

    2*Trace(X) 

 

    >>> trace(eye(3)) 

    3 

 

    See Also: 

        Trace 

    """ 

    return Trace(expr).doit()