import os

import gymnasium as gym
import matplotlib.pyplot as plt
from stable_baselines3 import PPO  # WORKS GREAT! even better 8sec done!
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy

render_mode = None
# render_mode = "human"  # visualization (renders every step, so training is much slower)

LOG_DIR = "./logs/"  # Monitor writes per-episode stats here; the plot below reads them back
MODEL_PATH = "ppo_pendulum_model"  # SB3 adds the ".zip" suffix on save/load
TOTAL_TIMESTEPS = 2_000_000  # 1 step = one action! (not episode!)

# Create the log directory before wrapping the env. When the directory is
# missing, SB3's Monitor names its file "<dir>/.monitor.csv" (a hidden file),
# which load_results' "*monitor.csv" glob does not match, so the plotting step
# would find no results on the first run.
os.makedirs(LOG_DIR, exist_ok=True)
env = Monitor(gym.make("Pendulum-v1", render_mode=render_mode), LOG_DIR)

# Resume from a previously saved model if one exists; otherwise start fresh.
# Only a missing file triggers the fallback. Any other load error (corrupt file,
# mismatched spaces, Ctrl-C) propagates instead of silently training a new model
# that would then overwrite the old one on save.
try:
    algo = PPO.load(MODEL_PATH, env=env)
except FileNotFoundError:
    print("Loading model failed. Starting from scratch.")
    algo = PPO("MlpPolicy", env, verbose=0)  # based on PG (policy gradient)
algo.learn(total_timesteps=TOTAL_TIMESTEPS)
algo.save(MODEL_PATH)
env.close()  # release the Monitor's CSV handle (and the render window, if any)

# Plot the per-episode reward recorded by Monitor against cumulative timesteps.
x, y = ts2xy(load_results(LOG_DIR), "timesteps")
plt.plot(x, y)
plt.xlabel("Timesteps")
plt.ylabel("Episode reward")
plt.title("PPO on Pendulum-v1")
plt.grid()
plt.show()
