import sys
def min_total_cost(arr, k):
    """Return the minimum total cost of splitting *arr* into *k* contiguous groups.

    The cost of one group is its largest element minus its smallest; the
    total is the sum of the group costs.

    Assumes *arr* is sorted in non-decreasing order: the gap-based formula
    below is only valid for sorted input (TODO confirm callers guarantee it).

    Args:
        arr: sequence of integers, expected non-decreasing.
        k: number of groups; values >= len(arr) give a total of 0.

    Returns:
        The minimum achievable total cost (0 for an empty *arr*).
    """
    if k == 1:
        # A single group costs max - min regardless of order.
        return max(arr, default=0) - min(arr, default=0)
    # For sorted input one group costs the sum of all adjacent gaps; each of
    # the k-1 cuts removes exactly one gap, so drop the k-1 largest gaps.
    gaps = sorted((b - a for a, b in zip(arr, arr[1:])), reverse=True)
    return sum(gaps[k - 1:])


def main():
    """Read ``n k`` and the heights from stdin and print the minimum cost."""
    readline = sys.stdin.readline
    n, k = map(int, readline().split())
    arr = list(map(int, readline().split()))
    print(min_total_cost(arr, k))


if __name__ == "__main__":
    main()