2019年2月24日日曜日

Reinforcement Learning An Introduction 第2版 4章

はじめに


 テキストReinforcement Learning An Introduction 第2版の4章で紹介されている以下3つの手法をpythonで実装する。
  1. Iterative Policy Evaluation
  2. Policy Iteration
  3. Value Iteration

コードの場所


  今回の全コードはここにある。  

Iterative Policy Evaluation


 この手法は状態価値関数$V(s)$についての次のベルマン方程式を使う。 \begin{equation} V(s)=\sum_{s^{\prime},a} P(s^{\prime}|s,a)\pi(a|s)\left[r(s,a,s^{\prime})+\gamma V(s^{\prime})\right] \label{eq1} \end{equation} この式の導出はここで行った。テキストに掲載されているIterative Policy Evaluationのアルゴリズムは以下の通りである(テキストPDF版から引用した)。
policy($\pi(a|s)$)を固定し、式(\ref{eq1})を反復法で解いているだけである。p.77(PDF版でなくハードカバー版)の図4.1の左側の列を再現するコードを以下に示す。 等確率($\pi(a|s)=0.25$)で上下左右を選択するpolicyを採用したものである。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from common import *  # noqa


def update(v, state, gamma, reward):
    updated_v = 0.0 
    for action in ACTIONS.values():
        next_state = state + action
        if not is_on_grid(next_state):
            next_state = state
        updated_v += 0.25 * (reward + gamma * v[next_state])
    return updated_v


def main():
    v = initialize_value_function(ROWS, COLS)
    k = 0 
    while True:
        delta = 0 
        for state in state_generator(ROWS, COLS):
            if is_terminal_state(state):
                continue
            tmp = v[state]
            v[state] = update(v, state, GAMMA, REWARD)
            delta = max(delta, abs(tmp - v[state]))
        k += 1
        if delta < THRESHOLD:
            break

    print("iteration size: {}".format(k))
    display_value_function(v)


if __name__ == "__main__":
    main()
実行結果は以下である。
iteration size: 88
value funtion:
[[  0.         -13.99330608 -19.99037659 -21.98940765]
 [-13.99330608 -17.99178568 -19.99108113 -19.99118312]
 [-19.99037659 -19.99108113 -17.99247411 -13.99438108]
 [-21.98940765 -19.99118312 -13.99438108   0.        ]]
小数点以下を四捨五入すれば、テキストの$k=\infty$の場合の数値と一致する。

Policy Iteration


 この手法は、Policy Evaluationに対しては式(\ref{eq1})を、Policy Improvementに対しては次のベルマン最適方程式を使う。 \begin{equation} V^{*}(s)=\max_{a}\sum_{s^{\prime}} P(s^{\prime}|s,a)\left[r(s,a,s^{\prime})+\gamma V^{*}(s^{\prime})\right] \label{eq2} \end{equation} ただし、これをそのまま使うのではなく、次のようにしてpolicy更新のために利用する。 \begin{equation} \pi(a|s)={\rm arg}\max_{a}\sum_{s^{\prime}} P(s^{\prime}|s,a)\left[r(s,a,s^{\prime})+\gamma V(s^{\prime})\right] \label{eq3} \end{equation} 右辺を最大にするものが$N$個ある場合、各行動の実現確率を$1/N$とする。最大にしない行動には確率0を割り振る。 テキストにあるアルゴリズムは以下の通りである(PDF版から引用した)。
p.77(PDF版でなくハードカバー版)の図4.1の右側の列を再現するコードを以下に示す。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from common import *  # noqa


def evaluate_policy(rows, cols, v, policy, gamma, reward, threshold):
    while True:
        delta = 0 
        for state in state_generator(rows, cols):
            if is_terminal_state(state):
                continue
            tmp = v[state]
            v[state] = update_value_function(v, policy, state, gamma, reward)
            delta = max(delta, abs(tmp - v[state]))
        if delta < threshold:
            break


def update_value_function(v, policy, state, gamma, reward):
    action_prob = policy[state]
    updated_v = 0 
    for action_key, prob in zip(ACTIONS.keys(), action_prob):
        next_state = state + ACTIONS[action_key]
        if not is_on_grid(next_state):
            next_state = state
        updated_v += prob * (reward + gamma * v[next_state])
    return updated_v


def update_policy(state, reward, gamma, v): 
    results = {}
    for key in ACTIONS:
        next_state = ACTIONS[key] + state
        if not is_on_grid(next_state):
            next_state = state
        results[key] = reward + gamma * v[next_state]
    max_vs = max(results.values())
    return [k for k, val in results.items() if val == max_vs]


def improve_policy(rows, cols, policy, reward, gamma, v): 
    is_stable = True
    for state in state_generator(rows, cols):
        if is_terminal_state(state):
            continue
        old_action_prob = policy[state].copy()
        new_policy = update_policy(state, reward, gamma, v)
        overwrite_policy(state, new_policy, policy)
        if not np.all(old_action_prob == policy[state]):
            is_stable = False
    return is_stable


def main():
    v = initialize_value_function(ROWS, COLS)
    policy = initialize_policy(ROWS, COLS)
    while True:
        evaluate_policy(ROWS, COLS, v, policy, GAMMA, REWARD, THRESHOLD)
        is_stable = improve_policy(ROWS, COLS, policy, REWARD, GAMMA, v)
        if is_stable:
            break
    display_policy(policy)
    display_value_function(v)


if __name__ == "__main__":
    main()
実行結果は以下の通り。
policy ---
0 1 left
0 2 left
0 3 down:left
1 0 up
1 1 up:left
1 2 up:right:down:left
1 3 down
2 0 up
2 1 up:right:down:left
2 2 right:down
2 3 down
3 0 up:right
3 1 right
3 2 right
value function ---
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
policyの出力は、(行番号、列番号、矢印の向き)の順に並んでいる。

Value Iteration


 本手法は式(\ref{eq2})を用いて状態価値関数$V(s)$を更新する。policy$(\pi(a|s))$の決定には式(\ref{eq3})を使う。 テキストにあるアルゴリズムは以下の通りである(PDF版から引用した)。
$V(s)$を決めた後の$\pi(a|s)$の算出には式(\ref{eq3})を使う。Policy Iterationのときと同じ問題に適用するコードは以下の通り。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from common import *  # noqa
import policy_iteration


def update(v, state, gamma, reward):
    results = {}
    for key in ACTIONS:
        next_state = state + ACTIONS[key]
        if not is_on_grid(next_state):
            next_state = state
        results[key] = reward + gamma * v[next_state]
    max_vs = max(results.values())
    return max_vs


def make_policy(v, rows, cols, reward, gamma):
    policy = initialize_policy(rows, cols)
    for state in state_generator(rows, cols):
        if is_terminal_state(state):
            continue
        op = policy_iteration.update_policy(state, reward, gamma, v)
        overwrite_policy(state, op, policy)
    return policy


def main():
    v = initialize_value_function(ROWS, COLS)
    while True:
        delta = 0 
        for state in state_generator(ROWS, COLS):
            if is_terminal_state(state):
                continue
            tmp = v[state]
            v[state] = update(v, state, GAMMA, REWARD)
            delta = max(delta, abs(tmp - v[state]))
        if delta < THRESHOLD:
            break
    policy = make_policy(v, ROWS, COLS, REWARD, GAMMA)
    display_policy(policy)
    display_value_function(v)


if __name__ == "__main__":
    main()
実行結果の以下の通り。
policy ---
0 1 left
0 2 left
0 3 down:left
1 0 up
1 1 up:left
1 2 up:right:down:left
1 3 down
2 0 up
2 1 up:right:down:left
2 2 right:down
2 3 down
3 0 up:right
3 1 right
3 2 right
value function ---
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
Policy Iterationのときと同じ結果である。

まとめ


「Policy Iteration」と「Value Iteration」の結果は同じになるが、テキストp.77に掲載されている矢印の向きと少し違う。異なる箇所は、2行3列目と3行2列目のマスである。 間違いが分かる方、ご指摘ください。