2016-07-15 2 views
6

Numpy को जब मैं एक मैट्रिक्स N पंक्तियों और n कॉलम के साथ X के तीसरे क्रम क्षणों की गणना, मैं आमतौर पर einsum का उपयोग करें:वैकल्पिक einsum

M3 = sp.einsum('ij,ik,il->jkl',X,X,X) /N 

यह आमतौर पर ठीक काम करता है, लेकिन अब मैं बड़ा मूल्यों के साथ काम कर रहा हूँ, अर्थात् n = 120 और N = 100000, और einsum रिटर्न निम्न त्रुटि:

ValueError: iterator is too large

3 नेस्टेड छोरों करने का विकल्प unfeasable है, इसलिए मैं मैं सोच रहा हूं कि कोई विकल्प है या नहीं।

उत्तर

4

ध्यान दें कि यह गणना करते समय, कम से कम ~ n × एन = 173 अरब परिचालन (समरूपता पर विचार नहीं) करने की आवश्यकता होगी तो यह जब तक numpy GPU या कुछ और की पहुंच है धीमी गति से किया जाएगा। एक आधुनिक कंप्यूटर पर ~ 3 गीगाहर्ट्ज सीपीयू के साथ, पूरे गणना में पूरा होने के लिए लगभग 60 सेकंड लगने की उम्मीद है, जिसमें कोई सिम/समांतर गति नहीं है।

#!/usr/bin/env python3 

import numpy 
import time 

numpy.random.seed(0) 

n = 120 
N = 1000 
X = numpy.random.random((N, n)) 

start_time = time.time() 

M3 = numpy.einsum('ij,ik,il->jkl', X, X, X) 

end_time = time.time() 

print('check:', M3[2,4,6], '= 125.401852515?') 
print('check:', M3[4,2,6], '= 125.401852515?') 
print('check:', M3[6,4,2], '= 125.401852515?') 
print('check:', numpy.sum(M3), '= 218028826.631?') 
print('total time =', end_time - start_time) 

इस बारे में 8 सेकंड लेता है:


परीक्षण के लिए, हम एन = 1000 के साथ शुरू में हम इसका उपयोग सही होने और प्रदर्शन की जांच करने होंगे। यह आधार रेखा है।

के विकल्प के रूप में 3 नेस्टेड लूप साथ शुरू करते हैं:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     for l in range(n): 
      M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l]) 
# ~27 seconds 

यह मोटे तौर पर आधे से एक मिनट लगता है, अच्छा नहीं! एक कारण यह है क्योंकि यह वास्तव में चार नेस्टेड लूप है: numpy.sum को लूप भी माना जा सकता है।

हम ध्यान दें कि योग एक डॉट उत्पाद में बदल सकता है इस 4 पाश दूर करने के लिए:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     for l in range(n): 
      M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l] 
# 14 seconds 

बहुत बेहतर अब, लेकिन अभी भी धीमी गति से। लेकिन हम ध्यान दें कि डॉट उत्पाद एक पाश दूर करने के लिए एक आव्यूह गुणन में बदला जा सकता:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     M3[j,k] = X[:,j] * X[:,k] @ X 
# ~0.5 seconds 

हुह? अब यह einsum से भी अधिक कुशल है! हम यह भी जांच सकते हैं कि उत्तर वास्तव में सही होना चाहिए।

क्या हम आगे जा सकते हैं? हाँ! हम द्वारा k पाश को समाप्त कर सकता:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    Y = numpy.repeat(X[:,j], n).reshape((N, n)) 
    M3[j] = (Y * X).T @ X 
# ~0.3 seconds 

हम यह भी (एक्स की प्रत्येक पंक्ति के लिए यानी a * [b,c] == [a*b, a*c]) प्रसारण इस्तेमाल कर सकते हैं कर से बचने के लिए numpy.repeat (धन्यवाद @Divakar):

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    Y = X[:,j].reshape((N, 1)) 
    ## or, equivalently: 
    # Y = X[:, numpy.newaxis, j] 
    M3[j] = (Y * X).T @ X 
# ~0.16 seconds 

अगर हम पैमाने एन = 100000 के लिए कार्यक्रम में 16 सेकंड लगने की उम्मीद है, जो सैद्धांतिक सीमा के भीतर है, इसलिए j को समाप्त करने से बहुत मदद नहीं मिल सकती है (लेकिन यह कोड को समझने में वाकई मुश्किल हो सकता है)। हम इसे अंतिम समाधान के रूप में स्वीकार कर सकते हैं।


नोट: यदि आप अजगर 2, उपयोग कर रहे हैं a @ ba.dot(b) के बराबर है।

+0

महान उत्तर, धन्यवाद! –

+0

वास्तव में महान विचार। अगर मैं यहां कुछ प्रसारण कर सकता हूं, तो हम 'वाई' बनाने से बच सकते हैं और सीधे पुनरावृत्ति आउटपुट प्राप्त कर सकते हैं: '(एक्स [:, कोई नहीं, जे] * एक्स)। @ एक्स'। इससे हमें कुछ और प्रदर्शन बढ़ावा देना चाहिए। – Divakar

+0

@ दिवाकर: धन्यवाद! अपडेट किया गया। – kennytm

संबंधित मुद्दे