Random decomposition of a list of elements in Python

Recently I decided to improve my Python programming skills by implementing some machine learning algorithms. One of the first problems to solve is splitting the data set into two separate sets: a training set and a validation set.

The decomposition algorithm should divide the input set (represented by a Python list) into two sets in a given proportion. For example, a proportion of 0.5 splits the input into two sets of sizes 0.5 * n and (1 - 0.5) * n, where n is the size of the input set. The algorithm should work in a pseudo-random way: we want enough control over the randomness to reproduce the sequence of numbers drawn from the random number generator. Thankfully, the generator in Python's random module lets us set a seed.
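To illustrate what the seed buys us, here is a minimal sketch. The helper name draw and the values 42 and 100 are my own choices for illustration, not part of the algorithm.

```python
import random

def draw(seed, count=5):
    """Return `count` pseudo-random integers drawn after seeding with `seed`."""
    random.seed(seed)  # reset the module-level generator to a known state
    return [random.randrange(100) for _ in range(count)]

first = draw(42)
second = draw(42)

# The same seed reproduces exactly the same sequence.
print(first == second)
```

Note that random.seed() changes the state of the shared module-level generator, so any other code that uses the random module is affected as well.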

The algorithm will look like this:

  1. Input: input_data, proportion, seed
  2. Seed the random number generator with seed
  3. Determine how many elements should end up in each output set: n_left = proportion * len(input_data), rounded down to an integer, and n_right = len(input_data) - n_left
  4. Create the output sets (lists): left_data = input_data[:], right_data = []. left_data starts as a copy of the input data; the algorithm will then randomly move n_right elements from left_data to right_data
  5. Repeat n_right times: pick a random index rand_idx from the half-open range [0, len(left_data)), then move left_data[rand_idx] to the end of right_data
  6. Return (left_data, right_data)
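The steps above can be sketched literally as follows. This is only a rough, deliberately naive version: the name split_naive is hypothetical, and I assume that proportion lies between 0 and 1.

```python
import random

def split_naive(input_data, proportion, seed=None):
    """Naive split that follows the listed steps one by one.

    Assumes 0 <= proportion <= 1.
    """
    if seed is not None:
        random.seed(seed)                       # step 2

    n_left = int(len(input_data) * proportion)  # step 3
    n_right = len(input_data) - n_left

    left_data = input_data[:]                   # step 4: work on a copy
    right_data = []

    for _ in range(n_right):                    # step 5
        rand_idx = random.randrange(len(left_data))
        # list.pop(i) shifts every later element down by one,
        # so each removal costs O(n).
        right_data.append(left_data.pop(rand_idx))

    return left_data, right_data                # step 6

left, right = split_naive(list(range(10)), 0.5, seed=3)
print(len(left), len(right))
```

Whatever the seed, the two lists together always contain exactly the elements of the input; only the assignment of elements to each side changes.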

The complexity of this algorithm is O(n^2), which is really bad. Intuition tells us that O(n) is achievable here. Removing the element at left_data[rand_idx] costs O(n), because every element after it has to shift down by one position, and appending it to right_data costs O(1). If we reduce the removal to O(1), the whole algorithm becomes O(n). A simple trick does this: overwrite the element at rand_idx with the last element of the list, then pop() the last element. The trick works here because the order of elements does not need to be preserved.
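In isolation, the trick might look like the sketch below. The helper name swap_remove is my own; it is not a built-in list method.

```python
def swap_remove(items, idx):
    """Remove and return items[idx] in O(1), without preserving order."""
    # Swap the chosen element with the last one...
    items[idx], items[-1] = items[-1], items[idx]
    # ...then drop it from the end, which does not shift anything.
    return items.pop()

data = ['a', 'b', 'c', 'd', 'e']
removed = swap_remove(data, 1)

print(removed)  # b
print(data)     # ['a', 'e', 'c', 'd']
```

The last element 'e' has taken the place of 'b', so the original order is lost, which is fine for a random split.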

The final Python 3.6 code looks like this:

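import random

def split(input_data, proportion, **options):
    """
    Splits list according to given proportion.

    Pass seed=<number> to make the split reproducible.
    """

    if 'seed' in options:
        random.seed(options['seed'])

    # int() rounds down, so the left set gets floor(n * proportion) elements.
    n_left = int(len(input_data) * proportion)
    left_data = input_data[:]  # copy, so the caller's list is not modified
    right_data = []

    # Using > instead of != stops the loop cleanly when proportion > 1.
    while len(left_data) > n_left:
        rand_idx = random.randrange(len(left_data))
        right_data.append(left_data[rand_idx])
        # O(1) removal: overwrite the picked slot with the last element,
        # then drop the last element.
        left_data[rand_idx] = left_data[-1]
        left_data.pop()

    # Sorting only makes the output easier to read. It costs O(n log n),
    # so drop these two lines if strictly linear time is needed.
    left_data.sort()
    right_data.sort()

    return (left_data, right_data)

if __name__ == '__main__':
    left, right = split([1, 2, 4, 6, 5, 9, 11, 22, 45], 0.6, seed=1)

    print('left', left)
    print('right', right)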

Output:

left [5, 6, 9, 11, 22]
right [1, 2, 4, 45]