prioritized_experience_replay
PER (Prioritized Experience Replay)
import random

import numpy as np


class PriorityReplayMemory(list):
    """Bounded replay memory stored as an array-backed binary max-heap.

    Items are kept in heap order (no parent compares ``<`` either of its
    children), so ``self[0]`` is always a largest item.  When an append
    pushes the size above ``capacity``, the smallest item is evicted.

    Items must be mutually comparable with ``<``.  For prioritized replay
    they are presumably bare priorities or ``(priority, transition)``
    tuples -- confirm against callers.

    NOTE: only :meth:`append` maintains the heap invariant.  Other inherited
    ``list`` mutators (``insert``, ``extend``, ``+=``, item assignment, ...)
    bypass it.
    """

    def __init__(self, capacity):
        """Create an empty memory holding at most ``capacity`` items."""
        super(PriorityReplayMemory, self).__init__()
        self.capacity = capacity

    def _parent_idx(self, k):
        """Return the index of the parent of heap node ``k`` (-1 for the root)."""
        return (k - 1) // 2

    def _sift_up(self, child_idx):
        """Move the item at ``child_idx`` rootward while its parent is smaller."""
        parent_idx = self._parent_idx(child_idx)
        while parent_idx >= 0 and self[parent_idx] < self[child_idx]:
            self[parent_idx], self[child_idx] = self[child_idx], self[parent_idx]
            child_idx = parent_idx
            parent_idx = self._parent_idx(child_idx)

    def _heap_sort(self):
        """Restore heap order after an item has been added at the end.

        Despite its name this is a single sift-up of the last element, not a
        full heap sort; the name is kept for backward compatibility.
        """
        self._sift_up(len(self) - 1)

    def _evict_min(self):
        """Remove and return the smallest item while keeping heap order.

        In a max-heap the minimum is always a leaf, and the leaves occupy
        indices ``len(self) // 2`` .. ``len(self) - 1``, so only those are
        scanned.  Must only be called on a non-empty memory.
        """
        n = len(self)
        min_idx = min(range(n // 2, n), key=self.__getitem__)
        smallest = self[min_idx]
        last = self.pop()
        if min_idx < len(self):
            # Fill the hole with the former last item.  The hole is a leaf
            # position (it has no children), so the moved item can only
            # violate heap order with its ancestors: sift up, never down.
            self[min_idx] = last
            self._sift_up(min_idx)
        return smallest

    def append(self, item):
        """Add ``item``; if that exceeds ``capacity``, drop the smallest item."""
        super(PriorityReplayMemory, self).append(item)
        self._heap_sort()
        if len(self) > self.capacity:
            # Evict the minimum explicitly: the last array slot is merely
            # some leaf (often the item just added), not necessarily the
            # lowest-priority item.
            self._evict_min()

    def sample(self, k):
        """Return ``k`` distinct stored items chosen uniformly at random.

        NOTE: this ignores priorities (plain ``random.sample``); it is not the
        priority-proportional sampling of PER.  Raises ``ValueError`` if
        ``k`` is negative or exceeds the number of stored items.
        """
        return random.sample(self, k)

    def _print_tree(self, idx=0, depth=0, right=False):
        """Render the subtree rooted at ``idx`` as sideways text (debug aid).

        Each node is printed in a 4-character ``'{: 4d}'`` field, so integer
        items are assumed.  A left child continues on its parent's line; a
        right child starts a new line indented to its depth so it lines up
        under its left sibling.  Every missing child emits a newline.
        """
        if idx > len(self) - 1:
            return '\n'
        # One 4-column field per tree level, matching the value width below.
        indent = ' ' * (4 * depth) if right else ''
        buff = indent + '{: 4d}'.format(self[idx])
        buff += self._print_tree(2 * idx + 1, depth + 1, False)
        buff += self._print_tree(2 * idx + 2, depth + 1, True)
        return buff


if __name__ == '__main__':
    memory = PriorityReplayMemory(2 ** 4 - 1)
    data = list(range(16))
    random.shuffle(data)
    for x in data:
        memory.append(x)
        print(memory)
    print(memory._print_tree())
    print(memory)
    # IPython is only needed for interactive inspection of the demo; keep
    # the module importable (and the demo runnable) without it.
    try:
        from IPython import embed
    except ImportError:
        print('IPython not installed; skipping interactive shell.')
    else:
        embed()
prioritized_experience_replay.txt · 마지막으로 수정됨: 2024/03/23 02:38 저자 127.0.0.1