Loading web-font TeX/Math/Italic

2019年2月24日日曜日

Reinforcement Learning An Introduction 第2版 4章

はじめに


 テキストReinforcement Learning An Introduction 第2版の4章で紹介されている以下3つの手法をpythonで実装する。
  1. Iterative Policy Evaluation
  2. Policy Iteration
  3. Value Iteration

コードの場所


  今回の全コードはここにある。  

Iterative Policy Evaluation


 この手法は状態価値関数V(s)についての次のベルマン方程式を使う。 \begin{equation} V(s)=\sum_{s^{\prime},a} P(s^{\prime}|s,a)\pi(a|s)\left[r(s,a,s^{\prime})+\gamma V(s^{\prime})\right] \label{eq1} \end{equation}
この式の導出はここで行った。テキストに掲載されているIterative Policy Evaluationのアルゴリズムは以下の通りである(テキストPDF版から引用した)。
policy(\pi(a|s))を固定し、式(\ref{eq1})を反復法で解いているだけである。p.77(PDF版でなくハードカバー版)の図4.1の左側の列を再現するコードを以下に示す。 等確率(\pi(a|s)=0.25)で上下左右を選択するpolicyを採用したものである。
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from common import * # noqa
  4.  
  5.  
  6. def update(v, state, gamma, reward):
  7. updated_v = 0.0
  8. for action in ACTIONS.values():
  9. next_state = state + action
  10. if not is_on_grid(next_state):
  11. next_state = state
  12. updated_v += 0.25 * (reward + gamma * v[next_state])
  13. return updated_v
  14.  
  15.  
  16. def main():
  17. v = initialize_value_function(ROWS, COLS)
  18. k = 0
  19. while True:
  20. delta = 0
  21. for state in state_generator(ROWS, COLS):
  22. if is_terminal_state(state):
  23. continue
  24. tmp = v[state]
  25. v[state] = update(v, state, GAMMA, REWARD)
  26. delta = max(delta, abs(tmp - v[state]))
  27. k += 1
  28. if delta < THRESHOLD:
  29. break
  30.  
  31. print("iteration size: {}".format(k))
  32. display_value_function(v)
  33.  
  34.  
  35. if __name__ == "__main__":
  36. main()
実行結果は以下である。
iteration size: 88
value funtion:
[[  0.         -13.99330608 -19.99037659 -21.98940765]
 [-13.99330608 -17.99178568 -19.99108113 -19.99118312]
 [-19.99037659 -19.99108113 -17.99247411 -13.99438108]
 [-21.98940765 -19.99118312 -13.99438108   0.        ]]
小数点以下を四捨五入すれば、テキストのk=\inftyの場合の数値と一致する。

Policy Iteration


 この手法は、Policy Evaluationに対しては式(\ref{eq1})を、Policy Improvementに対しては次のベルマン最適方程式を使う。 \begin{equation} V^{*}(s)=\max_{a}\sum_{s^{\prime}} P(s^{\prime}|s,a)\left[r(s,a,s^{\prime})+\gamma V^{*}(s^{\prime})\right] \label{eq2} \end{equation}
ただし、これをそのまま使うのではなく、次のようにしてpolicy更新のために利用する。 \begin{equation} \pi(a|s)={\rm arg}\max_{a}\sum_{s^{\prime}} P(s^{\prime}|s,a)\left[r(s,a,s^{\prime})+\gamma V(s^{\prime})\right] \label{eq3} \end{equation}
右辺を最大にするものがN個ある場合、各行動の実現確率を1/Nとする。最大にしない行動には確率0を割り振る。 テキストにあるアルゴリズムは以下の通りである(PDF版から引用した)。
p.77(PDF版でなくハードカバー版)の図4.1の右側の列を再現するコードを以下に示す。
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import numpy as np
  4. from common import * # noqa
  5.  
  6.  
  7. def evaluate_policy(rows, cols, v, policy, gamma, reward, threshold):
  8. while True:
  9. delta = 0
  10. for state in state_generator(rows, cols):
  11. if is_terminal_state(state):
  12. continue
  13. tmp = v[state]
  14. v[state] = update_value_function(v, policy, state, gamma, reward)
  15. delta = max(delta, abs(tmp - v[state]))
  16. if delta < threshold:
  17. break
  18.  
  19.  
  20. def update_value_function(v, policy, state, gamma, reward):
  21. action_prob = policy[state]
  22. updated_v = 0
  23. for action_key, prob in zip(ACTIONS.keys(), action_prob):
  24. next_state = state + ACTIONS[action_key]
  25. if not is_on_grid(next_state):
  26. next_state = state
  27. updated_v += prob * (reward + gamma * v[next_state])
  28. return updated_v
  29.  
  30.  
  31. def update_policy(state, reward, gamma, v):
  32. results = {}
  33. for key in ACTIONS:
  34. next_state = ACTIONS[key] + state
  35. if not is_on_grid(next_state):
  36. next_state = state
  37. results[key] = reward + gamma * v[next_state]
  38. max_vs = max(results.values())
  39. return [k for k, val in results.items() if val == max_vs]
  40.  
  41.  
  42. def improve_policy(rows, cols, policy, reward, gamma, v):
  43. is_stable = True
  44. for state in state_generator(rows, cols):
  45. if is_terminal_state(state):
  46. continue
  47. old_action_prob = policy[state].copy()
  48. new_policy = update_policy(state, reward, gamma, v)
  49. overwrite_policy(state, new_policy, policy)
  50. if not np.all(old_action_prob == policy[state]):
  51. is_stable = False
  52. return is_stable
  53.  
  54.  
  55. def main():
  56. v = initialize_value_function(ROWS, COLS)
  57. policy = initialize_policy(ROWS, COLS)
  58. while True:
  59. evaluate_policy(ROWS, COLS, v, policy, GAMMA, REWARD, THRESHOLD)
  60. is_stable = improve_policy(ROWS, COLS, policy, REWARD, GAMMA, v)
  61. if is_stable:
  62. break
  63. display_policy(policy)
  64. display_value_function(v)
  65.  
  66.  
  67. if __name__ == "__main__":
  68. main()
実行結果は以下の通り。
policy ---
0 1 left
0 2 left
0 3 down:left
1 0 up
1 1 up:left
1 2 up:right:down:left
1 3 down
2 0 up
2 1 up:right:down:left
2 2 right:down
2 3 down
3 0 up:right
3 1 right
3 2 right
value function ---
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
policyの出力は、(行番号、列番号、矢印の向き)の順に並んでいる。

Value Iteration


 本手法は式(\ref{eq2})を用いて状態価値関数V(s)を更新する。policy(\pi(a|s))の決定には式(\ref{eq3})を使う。 テキストにあるアルゴリズムは以下の通りである(PDF版から引用した)。
V(s)を決めた後の\pi(a|s)の算出には式(\ref{eq3})を使う。Policy Iterationのときと同じ問題に適用するコードは以下の通り。
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from common import * # noqa
  4. import policy_iteration
  5.  
  6.  
  7. def update(v, state, gamma, reward):
  8. results = {}
  9. for key in ACTIONS:
  10. next_state = state + ACTIONS[key]
  11. if not is_on_grid(next_state):
  12. next_state = state
  13. results[key] = reward + gamma * v[next_state]
  14. max_vs = max(results.values())
  15. return max_vs
  16.  
  17.  
  18. def make_policy(v, rows, cols, reward, gamma):
  19. policy = initialize_policy(rows, cols)
  20. for state in state_generator(rows, cols):
  21. if is_terminal_state(state):
  22. continue
  23. op = policy_iteration.update_policy(state, reward, gamma, v)
  24. overwrite_policy(state, op, policy)
  25. return policy
  26.  
  27.  
  28. def main():
  29. v = initialize_value_function(ROWS, COLS)
  30. while True:
  31. delta = 0
  32. for state in state_generator(ROWS, COLS):
  33. if is_terminal_state(state):
  34. continue
  35. tmp = v[state]
  36. v[state] = update(v, state, GAMMA, REWARD)
  37. delta = max(delta, abs(tmp - v[state]))
  38. if delta < THRESHOLD:
  39. break
  40. policy = make_policy(v, ROWS, COLS, REWARD, GAMMA)
  41. display_policy(policy)
  42. display_value_function(v)
  43.  
  44.  
  45. if __name__ == "__main__":
  46. main()
実行結果の以下の通り。
policy ---
0 1 left
0 2 left
0 3 down:left
1 0 up
1 1 up:left
1 2 up:right:down:left
1 3 down
2 0 up
2 1 up:right:down:left
2 2 right:down
2 3 down
3 0 up:right
3 1 right
3 2 right
value function ---
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
Policy Iterationのときと同じ結果である。

まとめ


「Policy Iteration」と「Value Iteration」の結果は同じになるが、テキストp.77に掲載されている矢印の向きと少し違う。異なる箇所は、2行3列目と3行2列目のマスである。 間違いが分かる方、ご指摘ください。