ध्यान दें कि यह गणना करते समय, कम से कम ~ n × एन = 173 अरब परिचालन (समरूपता पर विचार नहीं) करने की आवश्यकता होगी तो यह जब तक numpy GPU या कुछ और की पहुंच है धीमी गति से किया जाएगा। एक आधुनिक कंप्यूटर पर ~ 3 गीगाहर्ट्ज सीपीयू के साथ, पूरे गणना में पूरा होने के लिए लगभग 60 सेकंड लगने की उम्मीद है, जिसमें कोई सिम/समांतर गति नहीं है।
#!/usr/bin/env python3
import numpy
import time
numpy.random.seed(0)
n = 120
N = 1000
X = numpy.random.random((N, n))
start_time = time.time()
M3 = numpy.einsum('ij,ik,il->jkl', X, X, X)
end_time = time.time()
print('check:', M3[2,4,6], '= 125.401852515?')
print('check:', M3[4,2,6], '= 125.401852515?')
print('check:', M3[6,4,2], '= 125.401852515?')
print('check:', numpy.sum(M3), '= 218028826.631?')
print('total time =', end_time - start_time)
इस बारे में 8 सेकंड लेता है:
परीक्षण के लिए, हम एन = 1000 के साथ शुरू में हम इसका उपयोग सही होने और प्रदर्शन की जांच करने होंगे। यह आधार रेखा है।
के विकल्प के रूप में 3 नेस्टेड लूप साथ शुरू करते हैं:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
for l in range(n):
M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l])
# ~27 seconds
यह मोटे तौर पर आधे से एक मिनट लगता है, अच्छा नहीं! एक कारण यह है क्योंकि यह वास्तव में चार नेस्टेड लूप है: numpy.sum
को लूप भी माना जा सकता है।
हम ध्यान दें कि योग एक डॉट उत्पाद में बदल सकता है इस 4 पाश दूर करने के लिए:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
for l in range(n):
M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l]
# 14 seconds
बहुत बेहतर अब, लेकिन अभी भी धीमी गति से। लेकिन हम ध्यान दें कि डॉट उत्पाद एक पाश दूर करने के लिए एक आव्यूह गुणन में बदला जा सकता:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
M3[j,k] = X[:,j] * X[:,k] @ X
# ~0.5 seconds
हुह? अब यह einsum
से भी अधिक कुशल है! हम यह भी जांच सकते हैं कि उत्तर वास्तव में सही होना चाहिए।
क्या हम आगे जा सकते हैं? हाँ! हम द्वारा k
पाश को समाप्त कर सकता:
M3 = numpy.zeros((n, n, n))
for j in range(n):
Y = numpy.repeat(X[:,j], n).reshape((N, n))
M3[j] = (Y * X).T @ X
# ~0.3 seconds
हम यह भी (एक्स की प्रत्येक पंक्ति के लिए यानी a * [b,c] == [a*b, a*c]
) प्रसारण इस्तेमाल कर सकते हैं कर से बचने के लिए numpy.repeat
(धन्यवाद @Divakar):
M3 = numpy.zeros((n, n, n))
for j in range(n):
Y = X[:,j].reshape((N, 1))
## or, equivalently:
# Y = X[:, numpy.newaxis, j]
M3[j] = (Y * X).T @ X
# ~0.16 seconds
अगर हम पैमाने एन = 100000 के लिए कार्यक्रम में 16 सेकंड लगने की उम्मीद है, जो सैद्धांतिक सीमा के भीतर है, इसलिए j
को समाप्त करने से बहुत मदद नहीं मिल सकती है (लेकिन यह कोड को समझने में वाकई मुश्किल हो सकता है)। हम इसे अंतिम समाधान के रूप में स्वीकार कर सकते हैं।
नोट: यदि आप अजगर 2, उपयोग कर रहे हैं a @ b
a.dot(b)
के बराबर है।
महान उत्तर, धन्यवाद! –
वास्तव में महान विचार। अगर मैं यहां कुछ प्रसारण कर सकता हूं, तो हम 'वाई' बनाने से बच सकते हैं और सीधे पुनरावृत्ति आउटपुट प्राप्त कर सकते हैं: '(एक्स [:, कोई नहीं, जे] * एक्स)। @ एक्स'। इससे हमें कुछ और प्रदर्शन बढ़ावा देना चाहिए। – Divakar
@ दिवाकर: धन्यवाद! अपडेट किया गया। – kennytm