Skip to content

Commit dd9af6e

Browse files
committed
원숭이투자 실행 개선
1 parent 7cb2a40 commit dd9af6e

File tree

4 files changed

+19
-7
lines changed

4 files changed

+19
-7
lines changed

agent.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -65,8 +65,10 @@ def reset(self):
6565
self.ratio_hold = 0
6666
self.ratio_portfolio_value = 0
6767

68-
def reset_exploration(self):
69-
self.exploration_base = 0.5 + np.random.rand() / 2
68+
def reset_exploration(self, alpha=None):
69+
if alpha is None:
70+
alpha = np.random.rand() / 2
71+
self.exploration_base = 0.5 + alpha
7072

7173
def set_balance(self, balance):
7274
self.initial_balance = balance

learners.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -287,6 +287,7 @@ def run(
287287
self.agent.reset_exploration()
288288
else:
289289
epsilon = start_epsilon
290+
self.agent.reset_exploration(alpha=0)
290291

291292
while True:
292293
# 샘플 생성

main.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,9 @@
1414
parser.add_argument('--stock_code', nargs='+')
1515
parser.add_argument('--ver', choices=['v1', 'v2'], default='v2')
1616
parser.add_argument('--rl_method',
17-
choices=['dqn', 'pg', 'ac', 'a2c', 'a3c'])
17+
choices=['dqn', 'pg', 'ac', 'a2c', 'a3c', 'monkey'])
1818
parser.add_argument('--net',
19-
choices=['dnn', 'lstm', 'cnn'], default='dnn')
19+
choices=['dnn', 'lstm', 'cnn', 'monkey'], default='dnn')
2020
parser.add_argument('--num_steps', type=int, default=1)
2121
parser.add_argument('--lr', type=float, default=0.01)
2222
parser.add_argument('--discount_factor', type=float, default=0.9)
@@ -63,8 +63,8 @@
6363

6464
# 로그, Keras Backend 설정을 먼저하고 RLTrader 모듈들을 이후에 임포트해야 함
6565
from agent import Agent
66-
from learners import DQNLearner, PolicyGradientLearner, \
67-
ActorCriticLearner, A2CLearner, A3CLearner
66+
from learners import ReinforcementLearner, DQNLearner, \
67+
PolicyGradientLearner, ActorCriticLearner, A2CLearner, A3CLearner
6868

6969
# 모델 경로 준비
7070
value_network_path = ''
@@ -130,6 +130,13 @@
130130
learner = A2CLearner(**{**common_params,
131131
'value_network_path': value_network_path,
132132
'policy_network_path': policy_network_path})
133+
elif args.rl_method == 'monkey':
134+
args.net = args.rl_method
135+
args.num_epoches = 1
136+
args.discount_factor = None
137+
args.start_epsilon = 1
138+
args.learning = False
139+
learner = ReinforcementLearner(**common_params)
133140
if learner is not None:
134141
learner.run(balance=args.balance,
135142
num_epoches=args.num_epoches,

networks.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,9 @@ def set_session(sess): pass
2323
from tensorflow.keras.backend import set_session
2424
import tensorflow as tf
2525
graph = tf.get_default_graph()
26-
sess = tf.compat.v1.Session()
26+
config = tf.compat.v1.ConfigProto()
27+
config.gpu_options.allow_growth = True
28+
sess = tf.compat.v1.Session(config=config)
2729
elif os.environ['KERAS_BACKEND'] == 'plaidml.keras.backend':
2830
from keras.models import Model
2931
from keras.layers import Input, Dense, LSTM, Conv2D, \

0 commit comments

Comments (0)