사용자 도구

사이트 도구


prioritized_experience_replay

PER (Prioritized Experience Replay)

import random
 
from IPython import embed
import numpy as np
 
 
class PriorityReplayMemory(list):
    """Bounded memory stored as an array-backed binary max-heap.

    The node at index ``k`` has children ``2k + 1`` and ``2k + 2``. The
    largest item (by ``<``) is kept at index 0. Items therefore only need to
    be mutually comparable with ``<``, e.g. plain priorities or
    ``(priority, payload)`` tuples.

    When an ``append`` pushes the size past ``capacity``, the smallest item
    is evicted, so the memory always retains the ``capacity`` largest items
    seen so far.

    Only ``append`` maintains the heap invariant. Other inherited ``list``
    mutators (``extend``, ``insert``, item assignment, ...) bypass it.
    """

    def __init__(self, capacity):
        """Create an empty memory that retains at most *capacity* items."""
        super(PriorityReplayMemory, self).__init__()
        self.capacity = capacity

    def _parent_idx(self, k):
        """Return the array index of the parent of heap node *k* (-1 for the root)."""
        return (k - 1) // 2

    def _heap_sort(self, child_idx=None):
        """Sift the item at *child_idx* up until its parent is not smaller.

        Despite its name this performs a single sift-up pass, not a full heap
        sort. *child_idx* defaults to the last element, i.e. the item that
        ``append`` has just added.
        """
        if child_idx is None:
            child_idx = len(self) - 1
        parent_idx = self._parent_idx(child_idx)
        # parent_idx becomes -1 once child_idx reaches the root.
        while parent_idx >= 0 and self[parent_idx] < self[child_idx]:
            self[parent_idx], self[child_idx] = self[child_idx], self[parent_idx]
            child_idx = parent_idx
            parent_idx = self._parent_idx(child_idx)

    def _evict_min(self):
        """Remove the smallest item while keeping the max-heap valid."""
        # In a max-heap every internal node is >= its children, so the minimum
        # is always a leaf. Leaves occupy indices len//2 .. len-1.
        min_idx = min(range(len(self) // 2, len(self)), key=self.__getitem__)
        last = self.pop()
        if min_idx < len(self):
            # Fill the hole with the former last item. The slot is a leaf, so
            # only a sift-up can be needed to restore the heap property.
            self[min_idx] = last
            self._heap_sort(min_idx)

    def append(self, item):
        """Insert *item*; if that exceeds ``capacity``, evict the smallest item.

        When *item* itself is the smallest, it is the one evicted.
        """
        super(PriorityReplayMemory, self).append(item)
        self._heap_sort()
        if len(self) > self.capacity:
            self._evict_min()

    def sample(self, k):
        """Return *k* distinct items chosen uniformly at random (``random.sample``).

        Raises ``ValueError`` (from ``random.sample``) if *k* is negative or
        larger than ``len(self)``.

        NOTE(review): selection ignores item priority. Prioritized replay
        usually samples in proportion to priority; confirm uniform is intended.
        """
        return random.sample(self, k)

    @staticmethod
    def _format_node(item):
        """Format one node: ints as ``'{: 4d}'``, anything else right-aligned via ``str``."""
        try:
            return '{: 4d}'.format(item)
        except (ValueError, TypeError):
            # Floats/str raise ValueError for the 'd' code; tuples and other
            # objects raise TypeError for a non-empty format spec.
            return '{:>4}'.format(str(item))

    def _print_tree(self, idx=0, depth=0, right=False):
        """Return a text rendering of the subtree rooted at *idx* (debug aid).

        A node's left child continues on the same line. Its right child
        starts on a new line, indented by ``depth`` four-space steps. Every
        missing child contributes a newline.
        """
        if idx > len(self) - 1:
            return '\n'
        indent = '    ' * depth if right else ''
        buff = indent + self._format_node(self[idx])
        buff += self._print_tree(2 * idx + 1, depth + 1, False)
        buff += self._print_tree(2 * idx + 2, depth + 1, True)
        return buff
 
 
 
if __name__ == '__main__':
    # Demo: the memory holds 15 items but receives 16 distinct values, so the
    # final append forces exactly one eviction. After every append, print the
    # backing array followed by its tree rendering.
    memory = PriorityReplayMemory((1 << 4) - 1)

    values = list(range(memory.capacity + 1))
    random.shuffle(values)
    for value in values:
        memory.append(value)
        print(memory)
        print(memory._print_tree())

    print(memory)
    # Open an interactive IPython shell with `memory` in scope for inspection.
    embed()
prioritized_experience_replay.txt · 마지막으로 수정됨: 2024/03/23 02:38 저자 127.0.0.1