Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259

260

261

262

263

264

265

266

from __future__ import print_function, division 

 

from sympy import Number 

from sympy.core import Mul, Basic, sympify, Add 

from sympy.core.compatibility import range 

from sympy.functions import adjoint 

from sympy.matrices.expressions.transpose import transpose 

from sympy.strategies import (rm_id, unpack, typed, flatten, exhaust, 

        do_one, new) 

from sympy.matrices.expressions.matexpr import (MatrixExpr, ShapeError, 

        Identity, ZeroMatrix) 

from sympy.matrices.matrices import MatrixBase 

 

 

class MatMul(MatrixExpr): 

    """ 

    A product of matrix expressions 

 

    Examples 

    ======== 

 

    >>> from sympy import MatMul, MatrixSymbol 

    >>> A = MatrixSymbol('A', 5, 4) 

    >>> B = MatrixSymbol('B', 4, 3) 

    >>> C = MatrixSymbol('C', 3, 6) 

    >>> MatMul(A, B, C) 

    A*B*C 

    """ 

    is_MatMul = True 

 

    def __new__(cls, *args, **kwargs): 

        check = kwargs.get('check', True) 

 

        args = list(map(sympify, args)) 

        obj = Basic.__new__(cls, *args) 

        factor, matrices = obj.as_coeff_matrices() 

        if check: 

            validate(*matrices) 

        return obj 

 

    @property 

    def shape(self): 

        matrices = [arg for arg in self.args if arg.is_Matrix] 

        return (matrices[0].rows, matrices[-1].cols) 

 

    def _entry(self, i, j, expand=True): 

        coeff, matrices = self.as_coeff_matrices() 

 

        if len(matrices) == 1:  # situation like 2*X, matmul is just X 

            return coeff * matrices[0][i, j] 

 

        head, tail = matrices[0], matrices[1:] 

        if len(tail) == 0: 

            raise ValueError("lenth of tail cannot be 0") 

        X = head 

        Y = MatMul(*tail) 

 

        from sympy.core.symbol import Dummy 

        from sympy.concrete.summations import Sum 

        from sympy.matrices import ImmutableMatrix 

        k = Dummy('k', integer=True) 

        if X.has(ImmutableMatrix) or Y.has(ImmutableMatrix): 

            return coeff*Add(*[X[i, k]*Y[k, j] for k in range(X.cols)]) 

        result = Sum(coeff*X[i, k]*Y[k, j], (k, 0, X.cols - 1)) 

        return result.doit() if expand else result 

 

    def as_coeff_matrices(self): 

        scalars = [x for x in self.args if not x.is_Matrix] 

        matrices = [x for x in self.args if x.is_Matrix] 

        coeff = Mul(*scalars) 

 

        return coeff, matrices 

 

    def as_coeff_mmul(self): 

        coeff, matrices = self.as_coeff_matrices() 

        return coeff, MatMul(*matrices) 

 

    def _eval_transpose(self): 

        return MatMul(*[transpose(arg) for arg in self.args[::-1]]).doit() 

 

    def _eval_adjoint(self): 

        return MatMul(*[adjoint(arg) for arg in self.args[::-1]]).doit() 

 

    def _eval_trace(self): 

        factor, mmul = self.as_coeff_mmul() 

        if factor != 1: 

            from .trace import trace 

            return factor * trace(mmul.doit()) 

        else: 

            raise NotImplementedError("Can't simplify any further") 

 

    def _eval_determinant(self): 

        from sympy.matrices.expressions.determinant import Determinant 

        factor, matrices = self.as_coeff_matrices() 

        square_matrices = only_squares(*matrices) 

        return factor**self.rows * Mul(*list(map(Determinant, square_matrices))) 

 

    def _eval_inverse(self): 

        try: 

            return MatMul(*[ 

                arg.inverse() if isinstance(arg, MatrixExpr) else arg**-1 

                    for arg in self.args[::-1]]).doit() 

        except ShapeError: 

            from sympy.matrices.expressions.inverse import Inverse 

            return Inverse(self) 

 

    def doit(self, **kwargs): 

        deep = kwargs.get('deep', True) 

        if deep: 

            args = [arg.doit(**kwargs) for arg in self.args] 

        else: 

            args = self.args 

        return canonicalize(MatMul(*args)) 

 

def validate(*matrices): 

    """ Checks for valid shapes for args of MatMul """ 

    for i in range(len(matrices)-1): 

        A, B = matrices[i:i+2] 

        if A.cols != B.rows: 

            raise ShapeError("Matrices %s and %s are not aligned"%(A, B)) 

 

# Rules 

 

 

def newmul(*args): 

    if args[0] == 1: 

        args = args[1:] 

    return new(MatMul, *args) 

 

def any_zeros(mul): 

    if any([arg.is_zero or (arg.is_Matrix and arg.is_ZeroMatrix) 

                       for arg in mul.args]): 

        matrices = [arg for arg in mul.args if arg.is_Matrix] 

        return ZeroMatrix(matrices[0].rows, matrices[-1].cols) 

    return mul 

 

def merge_explicit(matmul): 

    """ Merge explicit MatrixBase arguments 

 

    >>> from sympy import MatrixSymbol, eye, Matrix, MatMul, pprint 

    >>> from sympy.matrices.expressions.matmul import merge_explicit 

    >>> A = MatrixSymbol('A', 2, 2) 

    >>> B = Matrix([[1, 1], [1, 1]]) 

    >>> C = Matrix([[1, 2], [3, 4]]) 

    >>> X = MatMul(A, B, C) 

    >>> pprint(X) 

    A*[1  1]*[1  2] 

      [    ] [    ] 

      [1  1] [3  4] 

    >>> pprint(merge_explicit(X)) 

    A*[4  6] 

      [    ] 

      [4  6] 

 

    >>> X = MatMul(B, A, C) 

    >>> pprint(X) 

    [1  1]*A*[1  2] 

    [    ]   [    ] 

    [1  1]   [3  4] 

    >>> pprint(merge_explicit(X)) 

    [1  1]*A*[1  2] 

    [    ]   [    ] 

    [1  1]   [3  4] 

    """ 

    if not any(isinstance(arg, MatrixBase) for arg in matmul.args): 

        return matmul 

    newargs = [] 

    last = matmul.args[0] 

    for arg in matmul.args[1:]: 

        if isinstance(arg, (MatrixBase, Number)) and isinstance(last, (MatrixBase, Number)): 

            last = last * arg 

        else: 

            newargs.append(last) 

            last = arg 

    newargs.append(last) 

 

    return MatMul(*newargs) 

 

def xxinv(mul): 

    """ Y * X * X.I -> Y """ 

    factor, matrices = mul.as_coeff_matrices() 

    for i, (X, Y) in enumerate(zip(matrices[:-1], matrices[1:])): 

        try: 

            if X.is_square and Y.is_square and X == Y.inverse(): 

                I = Identity(X.rows) 

                return newmul(factor, *(matrices[:i] + [I] + matrices[i+2:])) 

        except ValueError:  # Y might not be invertible 

            pass 

 

    return mul 

 

def remove_ids(mul): 

    """ Remove Identities from a MatMul 

 

    This is a modified version of sympy.strategies.rm_id. 

    This is necesssary because MatMul may contain both MatrixExprs and Exprs 

    as args. 

 

    See Also 

    -------- 

        sympy.strategies.rm_id 

    """ 

    # Separate Exprs from MatrixExprs in args 

    factor, mmul = mul.as_coeff_mmul() 

    # Apply standard rm_id for MatMuls 

    result = rm_id(lambda x: x.is_Identity is True)(mmul) 

    if result != mmul: 

        return newmul(factor, *result.args)  # Recombine and return 

    else: 

        return mul 

 

def factor_in_front(mul): 

    factor, matrices = mul.as_coeff_matrices() 

    if factor != 1: 

        return newmul(factor, *matrices) 

    return mul 

 

rules = (any_zeros, remove_ids, xxinv, unpack, rm_id(lambda x: x == 1), 

         merge_explicit, factor_in_front, flatten) 

 

canonicalize = exhaust(typed({MatMul: do_one(*rules)})) 

 

def only_squares(*matrices): 

    """ factor matrices only if they are square """ 

    if matrices[0].rows != matrices[-1].cols: 

        raise RuntimeError("Invalid matrices being multiplied") 

    out = [] 

    start = 0 

    for i, M in enumerate(matrices): 

        if M.cols == matrices[start].rows: 

            out.append(MatMul(*matrices[start:i+1]).doit()) 

            start = i+1 

    return out 

 

 

from sympy.assumptions.ask import ask, Q 

from sympy.assumptions.refine import handlers_dict 

 

 

def refine_MatMul(expr, assumptions): 

    """ 

    >>> from sympy import MatrixSymbol, Q, assuming, refine 

    >>> X = MatrixSymbol('X', 2, 2) 

    >>> expr = X * X.T 

    >>> print(expr) 

    X*X' 

    >>> with assuming(Q.orthogonal(X)): 

    ...     print(refine(expr)) 

    I 

    """ 

    newargs = [] 

    last = expr.args[0] 

    for arg in expr.args[1:]: 

        if arg == last.T and ask(Q.orthogonal(arg), assumptions): 

            last = Identity(arg.shape[0]) 

        elif arg == last.conjugate() and ask(Q.unitary(arg), assumptions): 

            last = Identity(arg.shape[0]) 

        else: 

            newargs.append(last) 

            last = arg 

    newargs.append(last) 

 

    return MatMul(*newargs) 

 

 

handlers_dict['MatMul'] = refine_MatMul