はじめに
テキストReinforcement Learning An Introduction 第2版の4章で紹介されている以下3つの手法をpythonで実装する。
- Iterative Policy Evaluation
- Policy Iteration
- Value Iteration
コードの場所
今回の全コードはここにある。
Iterative Policy Evaluation
この手法は状態価値関数$V(s)$についての次のベルマン方程式を使う。 \begin{equation} V(s)=\sum_{s^{\prime},a} P(s^{\prime}|s,a)\pi(a|s)\left[r(s,a,s^{\prime})+\gamma V(s^{\prime})\right] \label{eq1} \end{equation} この式の導出はここで行った。テキストに掲載されているIterative Policy Evaluationのアルゴリズムは以下の通りである(テキストPDF版から引用した)。 policy($\pi(a|s)$)を固定し、式(\ref{eq1})を反復法で解いているだけである。p.77(PDF版でなくハードカバー版)の図4.1の左側の列を再現するコードを以下に示す。 等確率($\pi(a|s)=0.25$)で上下左右を選択するpolicyを採用したものである。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from common import * # noqa
def update(v, state, gamma, reward):
updated_v = 0.0
for action in ACTIONS.values():
next_state = state + action
if not is_on_grid(next_state):
next_state = state
updated_v += 0.25 * (reward + gamma * v[next_state])
return updated_v
def main():
v = initialize_value_function(ROWS, COLS)
k = 0
while True:
delta = 0
for state in state_generator(ROWS, COLS):
if is_terminal_state(state):
continue
tmp = v[state]
v[state] = update(v, state, GAMMA, REWARD)
delta = max(delta, abs(tmp - v[state]))
k += 1
if delta < THRESHOLD:
break
print("iteration size: {}".format(k))
display_value_function(v)
if __name__ == "__main__":
main()
実行結果は以下である。
iteration size: 88 value funtion: [[ 0. -13.99330608 -19.99037659 -21.98940765] [-13.99330608 -17.99178568 -19.99108113 -19.99118312] [-19.99037659 -19.99108113 -17.99247411 -13.99438108] [-21.98940765 -19.99118312 -13.99438108 0. ]]小数点以下を四捨五入すれば、テキストの$k=\infty$の場合の数値と一致する。
Policy Iteration
この手法は、Policy Evaluationに対しては式(\ref{eq1})を、Policy Improvementに対しては次のベルマン最適方程式を使う。 \begin{equation} V^{*}(s)=\max_{a}\sum_{s^{\prime}} P(s^{\prime}|s,a)\left[r(s,a,s^{\prime})+\gamma V^{*}(s^{\prime})\right] \label{eq2} \end{equation} ただし、これをそのまま使うのではなく、次のようにしてpolicy更新のために利用する。 \begin{equation} \pi(a|s)={\rm arg}\max_{a}\sum_{s^{\prime}} P(s^{\prime}|s,a)\left[r(s,a,s^{\prime})+\gamma V(s^{\prime})\right] \label{eq3} \end{equation} 右辺を最大にするものが$N$個ある場合、各行動の実現確率を$1/N$とする。最大にしない行動には確率0を割り振る。 テキストにあるアルゴリズムは以下の通りである(PDF版から引用した)。 p.77(PDF版でなくハードカバー版)の図4.1の右側の列を再現するコードを以下に示す。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from common import * # noqa
def evaluate_policy(rows, cols, v, policy, gamma, reward, threshold):
while True:
delta = 0
for state in state_generator(rows, cols):
if is_terminal_state(state):
continue
tmp = v[state]
v[state] = update_value_function(v, policy, state, gamma, reward)
delta = max(delta, abs(tmp - v[state]))
if delta < threshold:
break
def update_value_function(v, policy, state, gamma, reward):
action_prob = policy[state]
updated_v = 0
for action_key, prob in zip(ACTIONS.keys(), action_prob):
next_state = state + ACTIONS[action_key]
if not is_on_grid(next_state):
next_state = state
updated_v += prob * (reward + gamma * v[next_state])
return updated_v
def update_policy(state, reward, gamma, v):
results = {}
for key in ACTIONS:
next_state = ACTIONS[key] + state
if not is_on_grid(next_state):
next_state = state
results[key] = reward + gamma * v[next_state]
max_vs = max(results.values())
return [k for k, val in results.items() if val == max_vs]
def improve_policy(rows, cols, policy, reward, gamma, v):
is_stable = True
for state in state_generator(rows, cols):
if is_terminal_state(state):
continue
old_action_prob = policy[state].copy()
new_policy = update_policy(state, reward, gamma, v)
overwrite_policy(state, new_policy, policy)
if not np.all(old_action_prob == policy[state]):
is_stable = False
return is_stable
def main():
v = initialize_value_function(ROWS, COLS)
policy = initialize_policy(ROWS, COLS)
while True:
evaluate_policy(ROWS, COLS, v, policy, GAMMA, REWARD, THRESHOLD)
is_stable = improve_policy(ROWS, COLS, policy, REWARD, GAMMA, v)
if is_stable:
break
display_policy(policy)
display_value_function(v)
if __name__ == "__main__":
main()
実行結果は以下の通り。
policy --- 0 1 left 0 2 left 0 3 down:left 1 0 up 1 1 up:left 1 2 up:right:down:left 1 3 down 2 0 up 2 1 up:right:down:left 2 2 right:down 2 3 down 3 0 up:right 3 1 right 3 2 right value function --- [[ 0. -1. -2. -3.] [-1. -2. -3. -2.] [-2. -3. -2. -1.] [-3. -2. -1. 0.]]policyの出力は、(行番号、列番号、矢印の向き)の順に並んでいる。
Value Iteration
本手法は式(\ref{eq2})を用いて状態価値関数$V(s)$を更新する。policy$(\pi(a|s))$の決定には式(\ref{eq3})を使う。 テキストにあるアルゴリズムは以下の通りである(PDF版から引用した)。 $V(s)$を決めた後の$\pi(a|s)$の算出には式(\ref{eq3})を使う。Policy Iterationのときと同じ問題に適用するコードは以下の通り。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from common import * # noqa
import policy_iteration
def update(v, state, gamma, reward):
results = {}
for key in ACTIONS:
next_state = state + ACTIONS[key]
if not is_on_grid(next_state):
next_state = state
results[key] = reward + gamma * v[next_state]
max_vs = max(results.values())
return max_vs
def make_policy(v, rows, cols, reward, gamma):
policy = initialize_policy(rows, cols)
for state in state_generator(rows, cols):
if is_terminal_state(state):
continue
op = policy_iteration.update_policy(state, reward, gamma, v)
overwrite_policy(state, op, policy)
return policy
def main():
v = initialize_value_function(ROWS, COLS)
while True:
delta = 0
for state in state_generator(ROWS, COLS):
if is_terminal_state(state):
continue
tmp = v[state]
v[state] = update(v, state, GAMMA, REWARD)
delta = max(delta, abs(tmp - v[state]))
if delta < THRESHOLD:
break
policy = make_policy(v, ROWS, COLS, REWARD, GAMMA)
display_policy(policy)
display_value_function(v)
if __name__ == "__main__":
main()
実行結果の以下の通り。
policy --- 0 1 left 0 2 left 0 3 down:left 1 0 up 1 1 up:left 1 2 up:right:down:left 1 3 down 2 0 up 2 1 up:right:down:left 2 2 right:down 2 3 down 3 0 up:right 3 1 right 3 2 right value function --- [[ 0. -1. -2. -3.] [-1. -2. -3. -2.] [-2. -3. -2. -1.] [-3. -2. -1. 0.]]Policy Iterationのときと同じ結果である。
まとめ
「Policy Iteration」と「Value Iteration」の結果は同じになるが、テキストp.77に掲載されている矢印の向きと少し違う。異なる箇所は、2行3列目と3行2列目のマスである。 間違いが分かる方、ご指摘ください。


