Processing math: 100%

2019年2月24日日曜日

Reinforcement Learning An Introduction 第2版 4章

はじめに


 テキストReinforcement Learning An Introduction 第2版の4章で紹介されている以下3つの手法をpythonで実装する。
  1. Iterative Policy Evaluation
  2. Policy Iteration
  3. Value Iteration

コードの場所


  今回の全コードはここにある。  

Iterative Policy Evaluation


 この手法は状態価値関数V(s)についての次のベルマン方程式を使う。 \begin{equation} V(s)=\sum_{s^{\prime},a} P(s^{\prime}|s,a)\pi(a|s)\left[r(s,a,s^{\prime})+\gamma V(s^{\prime})\right] \label{eq1} \end{equation} この式の導出はここで行った。テキストに掲載されているIterative Policy Evaluationのアルゴリズムは以下の通りである(テキストPDF版から引用した)。
policy(\pi(a|s))を固定し、式(\ref{eq1})を反復法で解いているだけである。p.77(PDF版でなくハードカバー版)の図4.1の左側の列を再現するコードを以下に示す。 等確率(\pi(a|s)=0.25)で上下左右を選択するpolicyを採用したものである。
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from common import * # noqa
  4.  
  5.  
  6. def update(v, state, gamma, reward):
  7. updated_v = 0.0
  8. for action in ACTIONS.values():
  9. next_state = state + action
  10. if not is_on_grid(next_state):
  11. next_state = state
  12. updated_v += 0.25 * (reward + gamma * v[next_state])
  13. return updated_v
  14.  
  15.  
  16. def main():
  17. v = initialize_value_function(ROWS, COLS)
  18. k = 0
  19. while True:
  20. delta = 0
  21. for state in state_generator(ROWS, COLS):
  22. if is_terminal_state(state):
  23. continue
  24. tmp = v[state]
  25. v[state] = update(v, state, GAMMA, REWARD)
  26. delta = max(delta, abs(tmp - v[state]))
  27. k += 1
  28. if delta < THRESHOLD:
  29. break
  30.  
  31. print("iteration size: {}".format(k))
  32. display_value_function(v)
  33.  
  34.  
  35. if __name__ == "__main__":
  36. main()
実行結果は以下である。
iteration size: 88
value funtion:
[[  0.         -13.99330608 -19.99037659 -21.98940765]
 [-13.99330608 -17.99178568 -19.99108113 -19.99118312]
 [-19.99037659 -19.99108113 -17.99247411 -13.99438108]
 [-21.98940765 -19.99118312 -13.99438108   0.        ]]
小数点以下を四捨五入すれば、テキストのk=\inftyの場合の数値と一致する。

Policy Iteration


 この手法は、Policy Evaluationに対しては式(\ref{eq1})を、Policy Improvementに対しては次のベルマン最適方程式を使う。 \begin{equation} V^{*}(s)=\max_{a}\sum_{s^{\prime}} P(s^{\prime}|s,a)\left[r(s,a,s^{\prime})+\gamma V^{*}(s^{\prime})\right] \label{eq2} \end{equation} ただし、これをそのまま使うのではなく、次のようにしてpolicy更新のために利用する。 \begin{equation} \pi(a|s)={\rm arg}\max_{a}\sum_{s^{\prime}} P(s^{\prime}|s,a)\left[r(s,a,s^{\prime})+\gamma V(s^{\prime})\right] \label{eq3} \end{equation} 右辺を最大にするものがN個ある場合、各行動の実現確率を1/Nとする。最大にしない行動には確率0を割り振る。 テキストにあるアルゴリズムは以下の通りである(PDF版から引用した)。
p.77(PDF版でなくハードカバー版)の図4.1の右側の列を再現するコードを以下に示す。
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import numpy as np
  4. from common import * # noqa
  5.  
  6.  
  7. def evaluate_policy(rows, cols, v, policy, gamma, reward, threshold):
  8. while True:
  9. delta = 0
  10. for state in state_generator(rows, cols):
  11. if is_terminal_state(state):
  12. continue
  13. tmp = v[state]
  14. v[state] = update_value_function(v, policy, state, gamma, reward)
  15. delta = max(delta, abs(tmp - v[state]))
  16. if delta < threshold:
  17. break
  18.  
  19.  
  20. def update_value_function(v, policy, state, gamma, reward):
  21. action_prob = policy[state]
  22. updated_v = 0
  23. for action_key, prob in zip(ACTIONS.keys(), action_prob):
  24. next_state = state + ACTIONS[action_key]
  25. if not is_on_grid(next_state):
  26. next_state = state
  27. updated_v += prob * (reward + gamma * v[next_state])
  28. return updated_v
  29.  
  30.  
  31. def update_policy(state, reward, gamma, v):
  32. results = {}
  33. for key in ACTIONS:
  34. next_state = ACTIONS[key] + state
  35. if not is_on_grid(next_state):
  36. next_state = state
  37. results[key] = reward + gamma * v[next_state]
  38. max_vs = max(results.values())
  39. return [k for k, val in results.items() if val == max_vs]
  40.  
  41.  
  42. def improve_policy(rows, cols, policy, reward, gamma, v):
  43. is_stable = True
  44. for state in state_generator(rows, cols):
  45. if is_terminal_state(state):
  46. continue
  47. old_action_prob = policy[state].copy()
  48. new_policy = update_policy(state, reward, gamma, v)
  49. overwrite_policy(state, new_policy, policy)
  50. if not np.all(old_action_prob == policy[state]):
  51. is_stable = False
  52. return is_stable
  53.  
  54.  
  55. def main():
  56. v = initialize_value_function(ROWS, COLS)
  57. policy = initialize_policy(ROWS, COLS)
  58. while True:
  59. evaluate_policy(ROWS, COLS, v, policy, GAMMA, REWARD, THRESHOLD)
  60. is_stable = improve_policy(ROWS, COLS, policy, REWARD, GAMMA, v)
  61. if is_stable:
  62. break
  63. display_policy(policy)
  64. display_value_function(v)
  65.  
  66.  
  67. if __name__ == "__main__":
  68. main()
実行結果は以下の通り。
policy ---
0 1 left
0 2 left
0 3 down:left
1 0 up
1 1 up:left
1 2 up:right:down:left
1 3 down
2 0 up
2 1 up:right:down:left
2 2 right:down
2 3 down
3 0 up:right
3 1 right
3 2 right
value function ---
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
policyの出力は、(行番号、列番号、矢印の向き)の順に並んでいる。

Value Iteration


 本手法は式(\ref{eq2})を用いて状態価値関数V(s)を更新する。policy(\pi(a|s))の決定には式(\ref{eq3})を使う。 テキストにあるアルゴリズムは以下の通りである(PDF版から引用した)。
V(s)を決めた後の\pi(a|s)の算出には式(\ref{eq3})を使う。Policy Iterationのときと同じ問題に適用するコードは以下の通り。
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from common import * # noqa
  4. import policy_iteration
  5.  
  6.  
  7. def update(v, state, gamma, reward):
  8. results = {}
  9. for key in ACTIONS:
  10. next_state = state + ACTIONS[key]
  11. if not is_on_grid(next_state):
  12. next_state = state
  13. results[key] = reward + gamma * v[next_state]
  14. max_vs = max(results.values())
  15. return max_vs
  16.  
  17.  
  18. def make_policy(v, rows, cols, reward, gamma):
  19. policy = initialize_policy(rows, cols)
  20. for state in state_generator(rows, cols):
  21. if is_terminal_state(state):
  22. continue
  23. op = policy_iteration.update_policy(state, reward, gamma, v)
  24. overwrite_policy(state, op, policy)
  25. return policy
  26.  
  27.  
  28. def main():
  29. v = initialize_value_function(ROWS, COLS)
  30. while True:
  31. delta = 0
  32. for state in state_generator(ROWS, COLS):
  33. if is_terminal_state(state):
  34. continue
  35. tmp = v[state]
  36. v[state] = update(v, state, GAMMA, REWARD)
  37. delta = max(delta, abs(tmp - v[state]))
  38. if delta < THRESHOLD:
  39. break
  40. policy = make_policy(v, ROWS, COLS, REWARD, GAMMA)
  41. display_policy(policy)
  42. display_value_function(v)
  43.  
  44.  
  45. if __name__ == "__main__":
  46. main()
実行結果の以下の通り。
policy ---
0 1 left
0 2 left
0 3 down:left
1 0 up
1 1 up:left
1 2 up:right:down:left
1 3 down
2 0 up
2 1 up:right:down:left
2 2 right:down
2 3 down
3 0 up:right
3 1 right
3 2 right
value function ---
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
Policy Iterationのときと同じ結果である。

まとめ


「Policy Iteration」と「Value Iteration」の結果は同じになるが、テキストp.77に掲載されている矢印の向きと少し違う。異なる箇所は、2行3列目と3行2列目のマスである。 間違いが分かる方、ご指摘ください。