はじめに
テキストReinforcement Learning An Introduction 第2版の4章で紹介されている以下3つの手法をpythonで実装する。
- Iterative Policy Evaluation
- Policy Iteration
- Value Iteration
コードの場所
今回の全コードはここにある。
Iterative Policy Evaluation
この手法は状態価値関数$V(s)$についての次のベルマン方程式を使う。 \begin{equation} V(s)=\sum_{s^{\prime},a} P(s^{\prime}|s,a)\pi(a|s)\left[r(s,a,s^{\prime})+\gamma V(s^{\prime})\right] \label{eq1} \end{equation} この式の導出はここで行った。テキストに掲載されているIterative Policy Evaluationのアルゴリズムは以下の通りである(テキストPDF版から引用した)。 policy($\pi(a|s)$)を固定し、式(\ref{eq1})を反復法で解いているだけである。p.77(PDF版でなくハードカバー版)の図4.1の左側の列を再現するコードを以下に示す。 等確率($\pi(a|s)=0.25$)で上下左右を選択するpolicyを採用したものである。
#!/usr/bin/env python # -*- coding: utf-8 -*- from common import * # noqa def update(v, state, gamma, reward): updated_v = 0.0 for action in ACTIONS.values(): next_state = state + action if not is_on_grid(next_state): next_state = state updated_v += 0.25 * (reward + gamma * v[next_state]) return updated_v def main(): v = initialize_value_function(ROWS, COLS) k = 0 while True: delta = 0 for state in state_generator(ROWS, COLS): if is_terminal_state(state): continue tmp = v[state] v[state] = update(v, state, GAMMA, REWARD) delta = max(delta, abs(tmp - v[state])) k += 1 if delta < THRESHOLD: break print("iteration size: {}".format(k)) display_value_function(v) if __name__ == "__main__": main()実行結果は以下である。
iteration size: 88 value funtion: [[ 0. -13.99330608 -19.99037659 -21.98940765] [-13.99330608 -17.99178568 -19.99108113 -19.99118312] [-19.99037659 -19.99108113 -17.99247411 -13.99438108] [-21.98940765 -19.99118312 -13.99438108 0. ]]小数点以下を四捨五入すれば、テキストの$k=\infty$の場合の数値と一致する。
Policy Iteration
この手法は、Policy Evaluationに対しては式(\ref{eq1})を、Policy Improvementに対しては次のベルマン最適方程式を使う。 \begin{equation} V^{*}(s)=\max_{a}\sum_{s^{\prime}} P(s^{\prime}|s,a)\left[r(s,a,s^{\prime})+\gamma V^{*}(s^{\prime})\right] \label{eq2} \end{equation} ただし、これをそのまま使うのではなく、次のようにしてpolicy更新のために利用する。 \begin{equation} \pi(a|s)={\rm arg}\max_{a}\sum_{s^{\prime}} P(s^{\prime}|s,a)\left[r(s,a,s^{\prime})+\gamma V(s^{\prime})\right] \label{eq3} \end{equation} 右辺を最大にするものが$N$個ある場合、各行動の実現確率を$1/N$とする。最大にしない行動には確率0を割り振る。 テキストにあるアルゴリズムは以下の通りである(PDF版から引用した)。 p.77(PDF版でなくハードカバー版)の図4.1の右側の列を再現するコードを以下に示す。
#!/usr/bin/env python # -*- coding: utf-8 -*- import numpy as np from common import * # noqa def evaluate_policy(rows, cols, v, policy, gamma, reward, threshold): while True: delta = 0 for state in state_generator(rows, cols): if is_terminal_state(state): continue tmp = v[state] v[state] = update_value_function(v, policy, state, gamma, reward) delta = max(delta, abs(tmp - v[state])) if delta < threshold: break def update_value_function(v, policy, state, gamma, reward): action_prob = policy[state] updated_v = 0 for action_key, prob in zip(ACTIONS.keys(), action_prob): next_state = state + ACTIONS[action_key] if not is_on_grid(next_state): next_state = state updated_v += prob * (reward + gamma * v[next_state]) return updated_v def update_policy(state, reward, gamma, v): results = {} for key in ACTIONS: next_state = ACTIONS[key] + state if not is_on_grid(next_state): next_state = state results[key] = reward + gamma * v[next_state] max_vs = max(results.values()) return [k for k, val in results.items() if val == max_vs] def improve_policy(rows, cols, policy, reward, gamma, v): is_stable = True for state in state_generator(rows, cols): if is_terminal_state(state): continue old_action_prob = policy[state].copy() new_policy = update_policy(state, reward, gamma, v) overwrite_policy(state, new_policy, policy) if not np.all(old_action_prob == policy[state]): is_stable = False return is_stable def main(): v = initialize_value_function(ROWS, COLS) policy = initialize_policy(ROWS, COLS) while True: evaluate_policy(ROWS, COLS, v, policy, GAMMA, REWARD, THRESHOLD) is_stable = improve_policy(ROWS, COLS, policy, REWARD, GAMMA, v) if is_stable: break display_policy(policy) display_value_function(v) if __name__ == "__main__": main()実行結果は以下の通り。
policy --- 0 1 left 0 2 left 0 3 down:left 1 0 up 1 1 up:left 1 2 up:right:down:left 1 3 down 2 0 up 2 1 up:right:down:left 2 2 right:down 2 3 down 3 0 up:right 3 1 right 3 2 right value function --- [[ 0. -1. -2. -3.] [-1. -2. -3. -2.] [-2. -3. -2. -1.] [-3. -2. -1. 0.]]policyの出力は、(行番号、列番号、矢印の向き)の順に並んでいる。
Value Iteration
本手法は式(\ref{eq2})を用いて状態価値関数$V(s)$を更新する。policy$(\pi(a|s))$の決定には式(\ref{eq3})を使う。 テキストにあるアルゴリズムは以下の通りである(PDF版から引用した)。 $V(s)$を決めた後の$\pi(a|s)$の算出には式(\ref{eq3})を使う。Policy Iterationのときと同じ問題に適用するコードは以下の通り。
#!/usr/bin/env python # -*- coding: utf-8 -*- from common import * # noqa import policy_iteration def update(v, state, gamma, reward): results = {} for key in ACTIONS: next_state = state + ACTIONS[key] if not is_on_grid(next_state): next_state = state results[key] = reward + gamma * v[next_state] max_vs = max(results.values()) return max_vs def make_policy(v, rows, cols, reward, gamma): policy = initialize_policy(rows, cols) for state in state_generator(rows, cols): if is_terminal_state(state): continue op = policy_iteration.update_policy(state, reward, gamma, v) overwrite_policy(state, op, policy) return policy def main(): v = initialize_value_function(ROWS, COLS) while True: delta = 0 for state in state_generator(ROWS, COLS): if is_terminal_state(state): continue tmp = v[state] v[state] = update(v, state, GAMMA, REWARD) delta = max(delta, abs(tmp - v[state])) if delta < THRESHOLD: break policy = make_policy(v, ROWS, COLS, REWARD, GAMMA) display_policy(policy) display_value_function(v) if __name__ == "__main__": main()実行結果の以下の通り。
policy --- 0 1 left 0 2 left 0 3 down:left 1 0 up 1 1 up:left 1 2 up:right:down:left 1 3 down 2 0 up 2 1 up:right:down:left 2 2 right:down 2 3 down 3 0 up:right 3 1 right 3 2 right value function --- [[ 0. -1. -2. -3.] [-1. -2. -3. -2.] [-2. -3. -2. -1.] [-3. -2. -1. 0.]]Policy Iterationのときと同じ結果である。
まとめ
「Policy Iteration」と「Value Iteration」の結果は同じになるが、テキストp.77に掲載されている矢印の向きと少し違う。異なる箇所は、2行3列目と3行2列目のマスである。 間違いが分かる方、ご指摘ください。
0 件のコメント:
コメントを投稿