แก้โจทย์สัมภาษณ์งาน รวมพลังเซิร์ฟเวอร์ $O(n \log n)$
เนื่องจากเพื่อนเจอโจทย์เขียนโค้ดตอนสัมภาษณ์งานจากบริษัทฝั่งซิลิคอนวัลเลย์ โดยส่วนที่ให้คิดอัลกอริทึมมาแก้ปัญหาเอาจริงๆ แล้วก็น่าจะถือว่ายากเกินไปด้วยซ้ำเมื่อเทียบกับการเขียนโค้ดในโลกการทำงานจริง (และออกจะแปลกใจด้วยที่เลือกโจทย์เชิงแข่งเขียนโปรแกรมมาออก) ตรงนี้ก็เข้าใจว่าทางฝั่งคนสัมภาษณ์นั้นคงตั้งใจเลือกโจทย์ที่ยากมากๆ จนหลายคนไม่มีทางทำเสร็จทันเวลา ซึ่งเค้าคงไม่ได้ดูแค่ผลลัพธ์คำตอบสุดท้ายปลายทางเพียงอย่างเดียว แต่ดูกระบวนการคิดระหว่างทางไปด้วยว่าทำอย่างไรจึงเดินทางมาจนถึงจุดหมายนั้นๆ นอกจากนี้ก็อาจจะทดสอบเชิงจิตวิทยาไปในตัวด้วยหละมั้ง ว่าถ้าเจอโจทย์ที่ยากเกินกำลังแล้วเราจะรับมือมันยังไง (เหมือนกับแบบทดสอบโคบายาชิมารุไงหละ!)
ซึ่งโจทย์เจ้าปัญหานี้ก็สนุกดี และยังแก้ปัญหาได้จากหลายมุมมองอีกด้วย (และก็ดีแล้วที่ไม่เลือก fizz buzz ที่แม้จะวัดอะไรได้หลายอย่าง1 แต่ก็นับว่าค่อนข้างง่ายและเก่าเกินไปแล้ว) โดยข้อนี้ถือว่าสนุกมากเพราะทำได้ตั้งแต่ความเร็ว $O(n^3)$ โดยเขียนแบบง่ายๆ ตรงไปตรงมา ซึ่งสามารถเร่งความเร็วเป็น $O(n^2)$ ได้เมื่อประยุกต์ใช้การค้นหาอย่างชาญฉลาด แต่ก็ยังทำให้ดีขึ้นได้อีกด้วยการแบ่งแยกและเอาชนะที่ได้เวลาเหลือ $O(n \log n)$ แต่ไม่รับประกันว่าจะทำได้กับข้อมูลทุกรูปแบบ จนท้ายที่สุดต้องขุดเทคนิคกำหนดการพลวัตมาใช้เพื่อกดเวลาให้เป็น $O(n \log n)$ ทุกครั้ง … ก็ขอจดวิธีแก้โจทย์ด้วยวิธีอันหลากหลายเหล่านี้เก็บไว้หน่อย เผื่อว่าวันหน้าจะได้ร่วมสนุกไปสัมภาษณ์/เป็นคนสัมภาษณ์เองบ้าง 🤪
โจทย์: รวมพลังเซิร์ฟเวอร์
มองเซิร์ฟเวอร์ $n$ เครื่องที่วางเรียงต่อกันเป็นอาเรย์หนึ่งมิติ $X$ โดยแต่ละช่องในอาเรย์เก็บค่าจำนวนเต็มบวก แทนพลังของเซิร์ฟเวอร์เครื่องนั้นๆ และให้ $p(X[\ell \ldots r])$ เป็นวิธีคำนวณพลังของเซิร์ฟเวอร์จากเครื่องที่ $\ell$ ไปจนถึงเครื่องที่ $r$ เมื่อทำงานพร้อมกัน ซึ่งมีค่าเท่ากับ
\[p(X[\ell \ldots r]) = \min(X[\ell \ldots r]) \cdot \sum X[\ell \ldots r]\]โจทย์ให้หา $P(X)$ ซึ่งเป็นพลังรวมทั้งหมดเมื่อพิจจารณาทุกๆ วิธีเลือกเซิร์ฟเวอร์หลายเครื่องที่อยู่ติดกันมาทำงานด้วยกัน นั่นก็คือให้หา
\[P(X) = \sum_{i=0}^{n-1} \; \sum_{j=i}^{n-1} \; p(X[i \ldots j])\](จำไม่ได้ว่าข้อจำกัดของ $n$ เป็นเท่าไหร่ แต่ยังไงมาทำให้ดีที่สุดเท่าที่นึกออกกันเถอะ)
คิดตรงไปตรงมา $O(n^3)$
ถึงวิธีนี้จะให้เวลาแย่แน่ๆ แต่ข้อดีคือเขียนง่ายตรวจคำตอบได้ไวและมั่นใจว่าถูกต้อง ดังนั้นเขียนทิ้งไว้ก่อนไม่เสียหายหรอกเนอะ
def p(xs):
return min(xs) * sum(xs)
def total_power(xs):
n = len(xs)
return sum(p(xs[i:j+1]) for i in range(n) for j in range(i, n))
ซึ่งมันก็คือเขียนโค้ดตามที่โจทย์บอกเลยนี่หว่า 5555555555555555
อนึ่ง สังเกตว่าใน Python เวลาต้องการสไลซ์ซับอาเรย์ $X[i \ldots j]$ (เก็บตัวสุดท้ายในตำแหน่ง $j$ ด้วย) เราจะต้องเขียนโค้ดเป็น xs[i:j+1]
นั่นก็คือใส่ค่าตัวสุดท้ายบวกหนึ่งลงไป นี่เป็นหลักคิดการออกแบบของ Python ที่มองการสไลซ์อาเรย์ด้วยการกำหนดตัวเลขตำแหน่งลงตรงกลางระหว่างสมาชิกแต่ละตัว2 ซึ่งก็ดูไปด้วยกันได้ดีกับ range(n)
ที่จะนับเลขตั้งแต่ 0
ไปจนถึงแค่ n-1
เท่านั้น
ค้นหาค่าที่ต้องการให้เร็ว $O(n^2)$
ที่โค้ดในแบบก่อนหน้านั้นทำงานได้ช้าถึง $O(n^3)$ นั่นเพราะเราใช้ลูปถึงสามชั้น โดยสองชั้นแรกที่ฟังก์ชัน total_power
นั้นพอจะเห็นชัด แต่ลูปอีกชั้นที่ฟังก์ชัน p
นั้นโดนซ่อนไว้ภายใต้การเขียนอย่างเรียบง่ายของตัวภาษา ซึ่งก็คือฟังก์ชัน min
และ sum
ที่แต่ละตัวทำงานใน $O(n)$ นั่นเอง เราอาจมองว่าสองฟังก์ชันนี้เป็นฟังก์ชันเชิงค้นหาค่าที่ต้องการในช่วงที่กำหนด ซึ่งมันคงจะดีไม่น้อยถ้าเราสามารถย่นระยะเวลาการทำงานของพวกมันลงได้
สำหรับ sum
นั้นไม่ยาก เราสามารถเทคนิคแนวกำหนดการพลวัต โดยสังเกตว่าถ้าเราคำนวณอาเรย์ใหม่ทิ้งไว้ล่วงหน้า โดยแต่ละช่องเก็บผลรวมของอาเรย์ต้นฉบับจากช่องซ้ายสุดถึงช่องนั้นๆ (cumsum) จะทำให้ได้ว่า
เมื่อเราต้องการหาผลรวมในช่วงที่สนใจ เราก็สามารถจับผลรวมที่คำนวณทิ้งไว้ในอาเรย์ใหม่มาลบกันได้เลย หรือก็คือ
\[\sum X[i \ldots j] = Y[j] - Y[i{-}1]\]จะได้ว่าเวลาที่ใช้ในการสร้างอาเรย์ทิ้งไว้ล่วงหน้าคือ $O(n)$ และการค้นคำตอบแต่ละครั้งใช้เวลาเป็น $O(1)$
แต่กับฟังก์ชัน min
ไม่ได้ง่ายเช่นนั้น เพราะเราไม่สามารถใช้เทคนิคสร้างอาเรย์เพียงมิติเดียวแบบ sum
มาประยุกต์ใช้กับทุกช่วงการค้นหาได้อีกต่อไป ทางที่พอจะเขียนได้ง่ายก็คือสร้างอาเรย์สองมิติมาเก็บทุกความเป็นไปได้ของการค้นหาค่า min
แทน ดังนี้
ซึ่งจะได้ว่าเวลาที่ต้องคำนวณอาเรย์สองมิตินี้ทิ้งไว้เป็น $O(n^2)$ แต่ก็ยังพอมีเรื่องให้ใจชื้นคือเราสามารถหาค่า min
ในช่วงที่ต้องการด้วยเวลาเพียง $O(1)$ แล้ว (ยังมีวิธีเขียนการค้นค่าต่ำสุดในช่วงที่สนใจให้เร็วขึ้นกว่านี้ได้อีกหลายแบบ แต่สำหรับข้อนี้ไม่มีประโยชน์แล้ว เพราะเราติดคอขวดที่ $O(n^2)$ จากลูปผลรวมชั้นนอกอยู่ดี)
ดังนั้นโค้ดที่ปรับปรุงให้ทำงานได้เร็วใน $O(n^2)$ คือ
def init_query_sum(xs):
ys = [0 for _ in xs]
for i, x in enumerate(xs):
ys[i] = ys[i-1] + x
return ys
def query_sum(i, j, ys):
if i == 0:
return ys[j]
return ys[j] - ys[i-1]
def init_query_min(xs):
n = len(xs)
zs = [[1e400 for _ in range(n)] for _ in range(n)]
for i in range(n):
for j in range(i, n):
zs[i][j] = min(zs[i][j-1], xs[j])
return zs
def query_min(i, j, zs):
return zs[i][j]
def p(i, j, ys, zs):
return query_min(i, j, zs) * query_sum(i, j, ys)
def total_power(xs):
n = len(xs)
ys = init_query_sum(xs)
zs = init_query_min(xs)
return sum(p(i, j, ys, zs) for i in range(n) for j in range(i, n))
แบ่งแยกและเอาชนะ $O(n \log n)$ โดยเฉลี่ย
สังเกตว่าถ้าเราลองแจกแจงผลรวมที่ต้องการออกมาเพื่อหารูปแบบอะไรซักอย่าง เช่น สมมติสนใจอาเรย์ $X =[a,b,c,d,e,f]$ ที่มี $d$ เป็นค่าต่ำสุด ให้แต่ละแถวเขียนผลรวมของสมาชิกในแต่ละช่วง (ละการเขียนเครื่องหมายบวกเพื่อความสะอาดตา) เมื่อเราไล่เขียนทุกช่วงที่เป็นไปได้จนครบ เราจะพบรูปแบบเช่นนี้
\[\begin{array}{cccccc} a & b & c & d & e & f \\ \hline a \\ a & b \\ a & b & c \\ a & b & c & {\color{red}d} \\ a & b & c & {\color{red}d} & e \\ a & b & c & {\color{red}d} & e & f \\ & b \\ & b & c \\ & b & c & {\color{red}d} \\ & b & c & {\color{red}d} & e \\ & b & c & {\color{red}d} & e & f \\ & & c \\ & & c & {\color{red}d} \\ & & c & {\color{red}d} & e \\ & & c & {\color{red}d} & e & f \\ & & & {\color{red}d} \\ & & & {\color{red}d} & e \\ & & & {\color{red}d} & e & f \\ & & & & e \\ & & & & e & f \\ & & & & & f \end{array} \quad\quad\Rightarrow\quad\quad \begin{array}{cccccc} a & b & c & d & e & f \\ \hline a \\ a & b \\ a & b & c \\ & b \\ & b & c \\ & & c \\ \hdashline a & b & c & {\color{red}d} \\ a & b & c & {\color{red}d} & e \\ a & b & c & {\color{red}d} & e & f \\ & b & c & {\color{red}d} \\ & b & c & {\color{red}d} & e \\ & b & c & {\color{red}d} & e & f \\ & & c & {\color{red}d} \\ & & c & {\color{red}d} & e \\ & & c & {\color{red}d} & e & f \\ & & & {\color{red}d} \\ & & & {\color{red}d} & e \\ & & & {\color{red}d} & e & f \\ \hdashline & & & & e \\ & & & & e & f \\ & & & & & f \end{array}\]จะเห็นว่าเราสามารถแยกแถวที่มี $d$ กำกับอยู่ และจัดเรียงกั้นแถวข้างต้นออกเป็นได้เป็นสามพาทิชัน โดยพาทิชันตรงกลางมีทุกแถวที่มี $d$ กำกับ ส่วนพาทิชันบนจะมีทุกแถวที่ทุกตัวอยู่ฝั่งซ้ายของ $d$ เช่นเดียวกับพาทิชันล่างที่เก็บแถวที่อยู่ฝั่งขวาของ $d$ นี่ทำให้เราสามารถใช้การเรียกตัวเองจากเทคนิคแบ่งแยกและเอาชนะมาช่วยแก้โจทย์ได้ โดยโจทย์ในการเรียกตัวเองครั้งนี้จะเหลือเพียงแก้ปัญหาของพาทิชันตรงกลาง แล้วโยนพาทิชันที่เหลือไปแก้ปัญหาด้วยวิธีเดียวกันในการเรียกตัวเองครั้งยิบย่อยลงไป
ซึ่งปัญหาสำหรับพาทิชันตรงกลางนี้ก็คือเราต้องการหาผลรวมของทุกแถวที่คูณกับค่าต่ำสุดของแต่ละแถว เนื่องจากเราเลือกปัญหาที่เรารู้อยู่แล้วว่าค่าต่ำสุดนั้นเท่ากับ $d$ ดังนั้นงานที่เหลือก็แค่นับว่าสมาชิกแต่ละตัวปรากฏกี่ครั้งเท่านั้น
ให้ $S$ เป็นเซตของแถวทุกแถวที่เราเลือกมาในพาทิชันนี้ สังเกตว่าจำนวนที่สมาชิกแต่ละตัวที่ปรากฏเมื่อนับทุกแถวรวมกันคือ
\[\sum S = 3a + 6b + 9c + 12d + 8e + 4f\]สังเกตว่า $a$ มีตัวคูณ $3$ ติดอยู่ นั่นก็เพราะว่ามันปรากฏสามครั้ง ซึ่งก็คือ ครั้งแรกอยู่ในแถวที่จบด้วย $d$ ครั้งต่อมาจบที่ $e$ และครั้งสุดท้ายที่ $f$ นั่นก็คือเราสามารถนับได้ว่าจาก $d$ ไปจนสุดข้างขวาของอาเรย์มีความยาวเท่าไหร่ แล้วจึงเอามาคูณ $a$ ได้เลย
ส่วน $b$ นั้นมีตัวคูณ $3{\cdot}2$ ติดอยู่ โดยค่า $3$ นั้นอธิบายในทำนองเดียวกันกับ $3a$ ว่าแต่ละชุดมันกินแถวได้จาก $[d, e, f]$ ส่วนตัวคูณ $2$ นั้นเกิดขึ้นมาจากการที่ $b$ ปรากฏชุดแรกโดยมี $a$ อยู่ทางซ้ายสุด และปรากฏครั้งที่สองโดยตัวซ้ายสุดคือ $b$
ด้วยข้อสังเกตเช่นนี้ เราจึงจัดรูปได้ว่า
\[\begin{align} \sum S &= 3 \Big( 1a + 2b + 3c \Big) + \Big( 3 \cdot 4 \Big)d + 4 \Big( 2e + 1f \Big) \\ &= \left( \sum_{i=0}^{k-1} k(i{+}1) \cdot X[i] \right) + \Big( k(n{-}k) \cdot \min(X) \Big) + \left( \sum_{i=n-1}^{k+1} (n{-}k)(n{-}i) \cdot X[i] \right) \end{align}\]โดยที่ $k$ คือดัชนีของค่าที่ต่ำที่สุดในอาเรย์ สำหรับในที่นี้ค่าต่ำสุดคือตัว $d$ ซึ่งทำให้ได้ $k=3$ นั่นเอง
ดังนั้นจึงทำให้เราได้โค้ดเช่นนี้กลับมา
def total_power(xs):
if not xs:
return 0
min_value = min(xs)
k = xs.index(min_value)
left, right = xs[:k], xs[k+1:]
count = ( (len(right)+1) * sum(i*x for i, x in enumerate(left, 1))
+ (len(left)+1) * (len(right)+1) * min_value
+ (len(left)+1) * sum(i*x for i, x in enumerate(right[::-1], 1)) )
return count*min_value + total_power(left) + total_power(right)
เราใช้เวลา $O(n)$ ในการเรียกตัวเองหนึ่งชั้น ซึ่งถ้าเราสามารถแบ่งปัญหาออกเป็นพาทิชันบนล่างที่มีขนาดใหญ่พอๆ กันได้ ในแต่ละข้างที่เราเรียกตัวเองลึกลงไปขนาดของปัญหาจะเหลือประมาณครึ่งหนึ่ง เราจะเรียกตัวเองลึกลงไปเพียงแค่ $O(\log n)$ ชั้น จึงทำให้ได้ว่าอัลกอริทึมนี้ใช้เวลารวมเป็น $O(n \log n)$ หากเราโชคดีเจอข้อมูลนำเข้าแบบดังกล่าว
แต่ถ้าเราโชคร้าย เจอข้อมูลนำเข้าที่แบ่งพาทิชันได้ขนาดแตกต่างกันมากๆ เช่น ได้พาทิชันฝั่งบนเป็นอาเรย์ว่าง ส่วนพาทิชันฝั่งล่างมีขนาดใหญ่เกือบเท่าปัญหาเดิมเลย จะเห็นว่าเราต้องเรียกตัวเองลึกลงไป $O(n)$ ครั้งในข้อมูลนำเข้าอยู่ดี จนทำให้เวลาแย่ลงเป็น $O(n^2)$ ในที่สุด (อารมณ์เดียวกับการเรียงแบบเร็วที่ทำงานได้ช้าบนข้อมูลที่เรียงกลับด้าน)
กำหนดการพลวัต $O(n \log n)$ ทุกครั้ง
แล้วเราจะสามารถแก้ปัญหาได้เร็วแบบที่เชื่อถือไว้ใจได้ไม่ว่าจะเจอข้อมูลนำเข้าแบบไหนหรือเปล่า … ลองแกล้งๆ ลืมส่วนที่ต้องคูณค่า min
ในแต่ละแถวไปก่อน จากอาเรย์ $X=[a,b,c,d,e,f]$ เปลี่ยนไปเขียนแจกแจงผลรวมจากซ้ายสุดถึงแค่ตัว $e$ แบบนี้
หรี่ตามองดูจะเห็นปัญหาถูกแบ่งเป็นชั้นๆ แต่ละชั้นเป็นรูปสามเหลี่ยมเฉียงขึ้นมุมขวาบน ให้ $T[t]$ เป็นผลรวมของสามเหลี่ยมในชั้นที่ $t$ จะเห็นว่า
\[\begin{align} T[0] &= a \\ T[1] &= a + 2b \\ T[2] &= a + 2b + 3c \\ T[3] &= a + 2b + 3c + 4d \\ T[4] &= a + 2b + 3c + 4d + 5d \end{align}\]ดังนั้นเมื่อต้องการเพิ่ม $f$ ซึ่งก็คือสมาชิกตัวสุดท้ายเข้าไป เราก็แค่เอาคำตอบจากชั้นก่อนหน้ามาบวก $f$ เข้าไปในปริมาณที่ถูกต้อง ซึ่งก็คือ $T[5] = T[4] + 6f$ ก็พอแล้ว
และเราจะได้ว่าผลรวมพลังทั้งหมดที่โจทย์ต้องการ ก็คือ $P(X) = \sum_{t=0}^{n-1} T[t]$ นั่นเอง
ทีนี้ถ้าเรากลับมาคำนึงถึงปัญหาต้นฉบับ ที่แต่ละแถวนั้นต้องติดตัวคูณ min
เข้าไปด้วย สนใจเฉพาะชั้นสุดท้ายตอนกำลังจะเพิ่ม $f$ เข้าไป จะได้แผนภาพเช่นนี้
โดยที่ $m_i$ นั้นบอกค่าที่ต่ำที่สุดในแถวหนึ่งๆ ซึ่ง ณ ขณะนี้ $m_i = \min(X[i\ldots4])$ และทำให้ได้ว่า
\[\begin{align} T[4] &= m_0a + (m_0+m_1)b + \cdots + \sum_{i=0}^4 m_i e \end{align}\]เราต้องการเพิ่ม $f$ เข้าไปเพื่อสร้าง $T[5]$ ลองคิดเคสง่ายก่อน สมมติว่า $f$ นั้นมีค่ามากกว่าทุก $m_i$ เลย ดังนั้นการเพิ่ม $f$ เข้าไปในแต่ละแถวที่มีอยู่จึงไม่ไปเปลี่ยนแปลงค่า $m_i$ จึงได้ว่าตัวคูณที่เหมาะสมสำหรับ $f$ คือ
\[T[5] = T[4] + (m_0 + m_1 + m_2 + m_3 + m_4 + m_5)f\]โดยที่ $m_5=f$ นั่นก็เพราะว่าในแถวล่างสุดที่เพิ่มเข้ามาใหม่มีเพียง $f$ เป็นสมาชิกตัวเดียวนั่นเอง
แล้วเราจะรู้ได้เร็วแค่ไหนว่า $f$ นั้นมีค่ามากกว่า $m_i$ ทุกตัวก่อนหน้า? ตรงนี้ต้องอาศัยข้อสังเกตสำคัญคือ
\[m_0 \le m_1 \le m_2 \le m_3 \le m_4\]นั่นก็เพราะว่าที่แถวที่ $i$ เราวิ่งหาค่า min
จากสมาชิก ณ ตำแหน่งดัชนี $i$ ไปถึงด้านขวาสุดเท่าที่มีในตอนนั้น ดังนั้นตอนเพิ่มสมาชิกตัวที่ $j$ ตรงกลางทาง หากสมาชิกที่เพิ่มเข้ามานั้นมีค่าน้อยกว่า $m_i$ บางตัว สมาชิกตัวนั้นจะไปเปลี่ยนค่า $m_i$ ทุกตัวตั้งแต่แถวที่ $j$ เป็นต้นไปด้วยนั่นเอง
เมื่อเรารู้ว่า $m_i$ นั้นเรียงจากน้อยไปมากเสมอ เราจึงสามารถใช้การค้นหาทวิภาคที่เร็ว $O(\log n)$ มาเพื่อตอบคำถามได้ว่า $f$ มีค่ามากกว่าทุกตัวหรือไม่ และถ้ามันไม่ได้มีค่ามากที่สุด มันจะมีค่าน้อยกว่า $m_i$ ตัวไหนบ้าง
ซึ่งก็คือเรากำลังจะพิจารณาเคสยากที่ $f$ ไม่ได้มีค่ามากที่สุดแล้ว สมมติให้ดัชนีที่คืนมาจากการค้นหาทวิภาคคือ $k=3$ หรือก็คือ $m_2 < f \le m_3$ เราจะกลับไปเขียนแผนภาพสามเหลี่ยมใหม่ได้ว่า
\[\begin{array}{c:cccccc} \min & a & b & c & d & e & f \\ \hline m_0 & \color{blue}a & \color{blue}b & \color{blue}c & d & e & \color{red}f \\ m_1 && \color{blue}b & \color{blue}c & d & e & \color{red}f \\ m_2 &&& \color{blue}c & d & e & \color{red}f \\ {\color{red}f} &&&& \color{green}d & \color{green}e & \color{red}f \\ {\color{red}f} &&&&& \color{green}e & \color{red}f \\ {\color{red}f} &&&&&& \color{red}f \end{array}\]จะเห็นว่าเราสามารถแบ่งบริเวณต่างๆ สำหรับสามเหลี่ยมที่จะเอาไปคำนวณ $T[5]$ ได้เป็นสี่โซน ดังนี้
- โซนสามเหลี่ยมสีน้ำเงินด้านซ้าย จาก $a$ ไปจนถึง $c$ (กินความกว้าง $k$ ช่อง) เนื่องจากสามเหลี่ยมนี้แต่ละแถวยังคูณด้วยค่าต่ำสุดประจำแถวเหมือนเดิม จึงย้อนกลับไปใช้คำตอบจาก $T[2]$ ได้เลย
- โซนสามเหลี่ยมสีเขียวด้านล่าง จาก $d$ ไปจนถึง $e$ ซึ่งเราสามารถคำนวณสามเหลี่ยมนั้นได้เร็วใน $O(1)$ แล้วเอามาคูณกับ $f$ ซึ่งเป็นค่าต่ำสุดตัวใหม่ที่กินพื้นที่หลายแถวก็จบแล้ว
- โซนสี่เหลี่ยมสีดำตรงกลาง แนวนอนจาก $d$ ไปจนถึง $e$ และแนวตั้งจาก $m_0$ ไปจนถึง $m_2$ ค่าของก้อนนี้สามารถคำนวณได้จาก $\left(\sum_{i=k}^{4}X[i]\right)\left(\sum_{i=0}^{k-1} m_i\right)$ ซึ่งก็คือหาพื้นที่ของสี่เหลี่ยมด้วยการเอาความกว้างคูณความยาวนั่นเอง โดยในแต่ละด้านของสี่เหลี่ยมนั้นเราสามารถใช้เทคนิคคำนวณค่าทิ้งไว้แล้วค่อยค้นหาผลรวมได้อย่างรวดเร็วใน $O(1)$ เช่นเดิม
- โซนเส้นตรงสีแดงด้านขวา ก็คือ $\sum_{i=0}^5m_if$ ที่เราอัพเดทค่า $m_i$ บางตัวไปเรียบร้อยแล้ว … อนึ่งเราสามารถมองให้โซนนี้หายไปได้อีกด้วย โดยขยายขนาดของโซนสี่เหลี่ยมสีดำและโซนสามเหลี่ยมสีเขียวมาด้านขวาอีกหนึ่งช่องจนคลุม $f$ ไปด้วยนั่นเอง
ดังนั้นในแต่ละครั้งที่เราจะเพิ่มสมาชิกตัวใหม่เข้าไป เราจะกินเวลาเพียง $O(\log n)$ ไม่มากไปกว่านั้น และเราก็ต้องการเพิ่มสมาชิกเป็นจำนวน $O(n)$ ครั้ง จึงทำให้ใช้เวลารวมเป็น $O(n \log n)$ เท่านั้น
from bisect import bisect_left
class MonotoneBisector(object):
def __init__(self):
self.idxs = [-1]
self.mins = [-1e400]
self.sums = [0]
self._seen_elements = 0
def add(self, x):
i = bisect_left(self.mins, x)
del self.idxs[i:], self.mins[i:], self.sums[i:]
self.idxs += [self._seen_elements]
self.mins += [x]
self.sums += [self.sums[-1] + x*(self.idxs[-1]-self.idxs[-2])]
self._seen_elements += 1
return self.idxs[-2]
class PowerDP(object):
_slice_line = lambda s, k: s._line[-1] - s._line[k]
_slice_rect = lambda s, k: k * s._slice_line(k)
_slice_trig = lambda s, k: s._trig[-1] - s._trig[k] - s._slice_rect(k)
_left = lambda s, k: s._right[k]
_rect = lambda s, k: s._mono.sums[-2] * s._slice_line(k)
_down = lambda s, k: s._mono.mins[-1] * s._slice_trig(k)
def __init__(self):
self._mono = MonotoneBisector()
self._line = [0]
self._trig = [0]
self._right = [0]
def add(self, x):
k = 1 + self._mono.add(x)
self._line += [self._line[-1] + x]
self._trig += [self._trig[-1] + x*len(self._trig)]
self._right += [self._left(k) + self._rect(k) + self._down(k)]
return self._right[-1]
def total_power(xs):
dp = PowerDP()
return sum(dp.add(x) for x in xs)
author