はじめに
テキストReinforcement Learning An Introduction 第2版の4章で紹介されている以下3つの手法をpythonで実装する。
- Iterative Policy Evaluation
- Policy Iteration
- Value Iteration
コードの場所
今回の全コードはここにある。
Iterative Policy Evaluation
この手法は状態価値関数V(s)についての次のベルマン方程式を使う。 \begin{equation} V(s)=\sum_{s^{\prime},a} P(s^{\prime}|s,a)\pi(a|s)\left[r(s,a,s^{\prime})+\gamma V(s^{\prime})\right] \label{eq1} \end{equation}
この式の導出はここで行った。テキストに掲載されているIterative Policy Evaluationのアルゴリズムは以下の通りである(テキストPDF版から引用した)。
policy(\pi(a|s))を固定し、式(\ref{eq1})を反復法で解いているだけである。p.77(PDF版でなくハードカバー版)の図4.1の左側の列を再現するコードを以下に示す。
等確率(\pi(a|s)=0.25)で上下左右を選択するpolicyを採用したものである。
実行結果は以下である。
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- from common import * # noqa
- def update(v, state, gamma, reward):
- updated_v = 0.0
- for action in ACTIONS.values():
- next_state = state + action
- if not is_on_grid(next_state):
- next_state = state
- updated_v += 0.25 * (reward + gamma * v[next_state])
- return updated_v
- def main():
- v = initialize_value_function(ROWS, COLS)
- k = 0
- while True:
- delta = 0
- for state in state_generator(ROWS, COLS):
- if is_terminal_state(state):
- continue
- tmp = v[state]
- v[state] = update(v, state, GAMMA, REWARD)
- delta = max(delta, abs(tmp - v[state]))
- k += 1
- if delta < THRESHOLD:
- break
- print("iteration size: {}".format(k))
- display_value_function(v)
- if __name__ == "__main__":
- main()
iteration size: 88 value funtion: [[ 0. -13.99330608 -19.99037659 -21.98940765] [-13.99330608 -17.99178568 -19.99108113 -19.99118312] [-19.99037659 -19.99108113 -17.99247411 -13.99438108] [-21.98940765 -19.99118312 -13.99438108 0. ]]小数点以下を四捨五入すれば、テキストのk=\inftyの場合の数値と一致する。
Policy Iteration
この手法は、Policy Evaluationに対しては式(\ref{eq1})を、Policy Improvementに対しては次のベルマン最適方程式を使う。 \begin{equation} V^{*}(s)=\max_{a}\sum_{s^{\prime}} P(s^{\prime}|s,a)\left[r(s,a,s^{\prime})+\gamma V^{*}(s^{\prime})\right] \label{eq2} \end{equation}
ただし、これをそのまま使うのではなく、次のようにしてpolicy更新のために利用する。
\begin{equation}
\pi(a|s)={\rm arg}\max_{a}\sum_{s^{\prime}} P(s^{\prime}|s,a)\left[r(s,a,s^{\prime})+\gamma V(s^{\prime})\right]
\label{eq3}
\end{equation}
右辺を最大にするものがN個ある場合、各行動の実現確率を1/Nとする。最大にしない行動には確率0を割り振る。
テキストにあるアルゴリズムは以下の通りである(PDF版から引用した)。
p.77(PDF版でなくハードカバー版)の図4.1の右側の列を再現するコードを以下に示す。
実行結果は以下の通り。
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- import numpy as np
- from common import * # noqa
- def evaluate_policy(rows, cols, v, policy, gamma, reward, threshold):
- while True:
- delta = 0
- for state in state_generator(rows, cols):
- if is_terminal_state(state):
- continue
- tmp = v[state]
- v[state] = update_value_function(v, policy, state, gamma, reward)
- delta = max(delta, abs(tmp - v[state]))
- if delta < threshold:
- break
- def update_value_function(v, policy, state, gamma, reward):
- action_prob = policy[state]
- updated_v = 0
- for action_key, prob in zip(ACTIONS.keys(), action_prob):
- next_state = state + ACTIONS[action_key]
- if not is_on_grid(next_state):
- next_state = state
- updated_v += prob * (reward + gamma * v[next_state])
- return updated_v
- def update_policy(state, reward, gamma, v):
- results = {}
- for key in ACTIONS:
- next_state = ACTIONS[key] + state
- if not is_on_grid(next_state):
- next_state = state
- results[key] = reward + gamma * v[next_state]
- max_vs = max(results.values())
- return [k for k, val in results.items() if val == max_vs]
- def improve_policy(rows, cols, policy, reward, gamma, v):
- is_stable = True
- for state in state_generator(rows, cols):
- if is_terminal_state(state):
- continue
- old_action_prob = policy[state].copy()
- new_policy = update_policy(state, reward, gamma, v)
- overwrite_policy(state, new_policy, policy)
- if not np.all(old_action_prob == policy[state]):
- is_stable = False
- return is_stable
- def main():
- v = initialize_value_function(ROWS, COLS)
- policy = initialize_policy(ROWS, COLS)
- while True:
- evaluate_policy(ROWS, COLS, v, policy, GAMMA, REWARD, THRESHOLD)
- is_stable = improve_policy(ROWS, COLS, policy, REWARD, GAMMA, v)
- if is_stable:
- break
- display_policy(policy)
- display_value_function(v)
- if __name__ == "__main__":
- main()
policy --- 0 1 left 0 2 left 0 3 down:left 1 0 up 1 1 up:left 1 2 up:right:down:left 1 3 down 2 0 up 2 1 up:right:down:left 2 2 right:down 2 3 down 3 0 up:right 3 1 right 3 2 right value function --- [[ 0. -1. -2. -3.] [-1. -2. -3. -2.] [-2. -3. -2. -1.] [-3. -2. -1. 0.]]policyの出力は、(行番号、列番号、矢印の向き)の順に並んでいる。
Value Iteration
本手法は式(\ref{eq2})を用いて状態価値関数V(s)を更新する。policy(\pi(a|s))の決定には式(\ref{eq3})を使う。 テキストにあるアルゴリズムは以下の通りである(PDF版から引用した)。 V(s)を決めた後の\pi(a|s)の算出には式(\ref{eq3})を使う。Policy Iterationのときと同じ問題に適用するコードは以下の通り。
実行結果の以下の通り。
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- from common import * # noqa
- import policy_iteration
- def update(v, state, gamma, reward):
- results = {}
- for key in ACTIONS:
- next_state = state + ACTIONS[key]
- if not is_on_grid(next_state):
- next_state = state
- results[key] = reward + gamma * v[next_state]
- max_vs = max(results.values())
- return max_vs
- def make_policy(v, rows, cols, reward, gamma):
- policy = initialize_policy(rows, cols)
- for state in state_generator(rows, cols):
- if is_terminal_state(state):
- continue
- op = policy_iteration.update_policy(state, reward, gamma, v)
- overwrite_policy(state, op, policy)
- return policy
- def main():
- v = initialize_value_function(ROWS, COLS)
- while True:
- delta = 0
- for state in state_generator(ROWS, COLS):
- if is_terminal_state(state):
- continue
- tmp = v[state]
- v[state] = update(v, state, GAMMA, REWARD)
- delta = max(delta, abs(tmp - v[state]))
- if delta < THRESHOLD:
- break
- policy = make_policy(v, ROWS, COLS, REWARD, GAMMA)
- display_policy(policy)
- display_value_function(v)
- if __name__ == "__main__":
- main()
policy --- 0 1 left 0 2 left 0 3 down:left 1 0 up 1 1 up:left 1 2 up:right:down:left 1 3 down 2 0 up 2 1 up:right:down:left 2 2 right:down 2 3 down 3 0 up:right 3 1 right 3 2 right value function --- [[ 0. -1. -2. -3.] [-1. -2. -3. -2.] [-2. -3. -2. -1.] [-3. -2. -1. 0.]]Policy Iterationのときと同じ結果である。
まとめ
「Policy Iteration」と「Value Iteration」の結果は同じになるが、テキストp.77に掲載されている矢印の向きと少し違う。異なる箇所は、2行3列目と3行2列目のマスである。 間違いが分かる方、ご指摘ください。