2016-05-24 4 views
7

n लंबाई n, arr1 और arr2 पर दो ndarrays पर विचार करें। मैं उत्पादों के निम्नलिखित राशि की गणना कर रहा हूँ, और यह बेंचमार्क num_runs बार कर रही है:उत्पाद का कुशल डबल योग

import numpy as np 
import time 

num_runs = 1000 
n = 100 

arr1 = np.random.rand(n) 
arr2 = np.random.rand(n) 

start_comp = time.clock() 
for r in xrange(num_runs): 
    sum_prods = np.sum([arr1[i]*arr2[j] for i in xrange(n) 
         for j in xrange(i+1, n)]) 

print "total time for comprehension = ", time.clock() - start_comp 

start_loop = time.clock() 
for r in xrange(num_runs): 
    sum_prod = 0.0 
    for i in xrange(n): 
     for j in xrange(i+1, n): 
      sum_prod += arr1[i]*arr2[j] 

print "total time for loop = ", time.clock() - start_loop 

उत्पादन

total time for comprehension = 3.23097066953 
total time for comprehension = 3.9045544426 

तो सूची समझ का उपयोग करते हुए तेजी से प्रकट होता है।

क्या इस तरह के उत्पादों की गणना करने के लिए शायद नम्पी रूटीन का उपयोग करके, अधिक कुशल कार्यान्वयन है?

a = np.sum(np.triu(arr1[:,None]*arr2[None,:],1)) 
b = np.sum([arr1[i]*arr2[j] for i in xrange(n) for j in xrange(i+1, n)]) 
print a == b # True 

असल में, मैं arr1 और arr2 में जोड़ो में सभी तत्वों के उत्पाद की गणना numpy प्रसारण/vectorization की गति का लाभ लेने के के मूल्य का भुगतान कर रहा हूँ:

+0

क्या यह उपयोगी हो सकता है? https://stackoverflow.com/questions/9068478/how-to-parallelize-a-sum-calculation-in-python-numpy –

+0

बहुत प्रासंगिक लगता है: ['इट्रेटर निर्भरता के साथ मैट्रिक्स गुणा - NumPy'] (http: // stackoverflow.com/questions/36045510/matrix-multiplication-with-iterator-dependency-numpy)। – Divakar

उत्तर

12

हे के बजाय एक हे (एन) क्रम एल्गोरिथ्म में आपरेशन को पुनर्व्यवस्थित करें (एन^2), और उत्पादों और रकम के लिए NumPy का लाभ लें:

# arr1_weights[i] is the sum of all terms arr1[i] gets multiplied by in the 
# original version 
arr1_weights = arr2[::-1].cumsum()[::-1] - arr2 

sum_prods = arr1.dot(arr1_weights) 

समय से पता चलता है इस के बारे में होने n == 100 के लिए सूची समझ से 200 गुना तेज।

In [21]: %%timeit 
    ....: np.sum([arr1[i] * arr2[j] for i in range(n) for j in range(i+1, n)]) 
    ....: 
100 loops, best of 3: 5.13 ms per loop 

In [22]: %%timeit 
    ....: arr1_weights = arr2[::-1].cumsum()[::-1] - arr2 
    ....: sum_prods = arr1.dot(arr1_weights) 
    ....: 
10000 loops, best of 3: 22.8 µs per loop 
+1

बधाई। –

+1

शब्दों को व्यवस्थित करना, यह भी है: 'arr1 [: - 1]। Cumsum()। Dot (arr2 [1:]) '। –

3

आप निम्न प्रसारण चाल का उपयोग कर सकते निम्न स्तर के कोड में बहुत तेजी से किया जा रहा है।

और समय:

%timeit np.sum(np.triu(arr1[:,None]*arr2[None,:],1)) 
10000 loops, best of 3: 55.9 µs per loop 

%timeit np.sum([arr1[i]*arr2[j] for i in xrange(n) for j in xrange(i+1, n)]) 
1000 loops, best of 3: 1.45 ms per loop 
+0

या स्मृति पर सहेजने के लिए, हालांकि थोड़ा धीमा हो सकता है: 'आर, सी = एनपी.triu_indices (एन, 1); आउटपुट = एनपी डॉट (एआर 1 [आर], एआर 2 [सी]) '। – Divakar

8

एक vectorized रास्ता: np.sum(np.triu(np.multiply.outer(arr1,arr2),1))। एक 10x नए कारक के लिए

from numba import jit 
@jit 
def t(arr1,arr2): 
    s=0 
    for i in range(n): 
     for j in range(i+1,n): 
      s+= arr1[i]*arr2[j] 
    return s 

:

In [9]: %timeit np.sum(np.triu(np.multiply.outer(arr1,arr2),1)) 
1000 loops, best of 3: 272 µs per loop 

In [10]: %timeit np.sum([arr1[i]*arr2[j] for i in range(n) 
         for j in range(i+1, n)] 
100 loops, best of 3: 7.9 ms per loop 

In [11]: allclose(np.sum(np.triu(np.multiply.outer(arr1,arr2),1)), 
np.sum(np.triu(np.multiply.outer(arr1,arr2),1))) 
Out[11]: True 

एक और तेजी से approch Numba उपयोग करने के लिए है:

एक 30x सुधार के लिए

In [12]: %timeit t(arr1,arr2) 
10000 loops, best of 3: 21.1 µs per loop 

और @ user2357112 न्यूनतम जवाब का उपयोग कर ,

@jit 
def t2357112(arr1,arr2): 
    s=0 
    c=0 
    for i in range(n-2,-1,-1): 
     c += arr2[i+1] 
     s += arr1[i]*c 
    return s 

In [13]: %timeit t2357112(arr1,arr2) 
100000 loops, best of 3: 2.33 µs per loop 

के लिए

, बस आवश्यक संचालन कर रही है।

+0

numba समाधान अच्छा है क्योंकि इसमें किसी भी इंटरमीडिएट सरणी निर्माण – JoshAdel

+0

शामिल नहीं है, मुझे लगता है कि आपके कोड को numba में अनुवाद करते समय आपको सीमाएं गलत हो सकती हैं। ऐसा लगता है कि आप 'arr2' के अंतिम तत्व से पहले 'arr2 [n] 'तक पहुंचने का प्रयास कर रहे हैं। – user2357112

+0

आप सही हैं। चूंकि numba सीमाओं की जांच नहीं करता है, यह चुपचाप मेरे प्रयास पर सही परिणाम दिया ...... संपादित। –

संबंधित मुद्दे