import gymnasium as gym
import torch
from random import random

# --- Run configuration ---
episodes = 10  # how many episodes are played

# "human" shows the pole in a window; switch to None to run without visualization.
render_mode = "human"
env = gym.make("CartPole-v1", render_mode=render_mode)

# Problem dimensions
state_dimension = 4    # pole state: (position, velocity, angle, angular-velocity)
action_space_dims = 2  # left / right

def policy_random(state):
	"""Return a randomly sampled action (0 or 1); ``state`` is ignored."""
	random_action = env.action_space.sample()
	return random_action

def policy(state, step):
	"""Choose the push direction that moves the cart under the falling pole.

	Args:
		state: Four-element observation
			(position, velocity, angle, angular_velocity); any unpackable
			sequence works (tuple, list, 1-D tensor).
		step: Counter supplied by the caller; currently unused.

	Returns:
		0 to push the cart left, 1 to push it right.
	"""
	# Unpack the state vector/tensor; only the pole's tilt and rotation are used.
	_position, _velocity, angle, angular_velocity = state
	left = 0
	right = 1
	# A positive sum means the pole leans or rotates to the right, so the cart
	# has to follow it to the right; otherwise (including the exactly-upright
	# case) push left, which was the previous constant behavior.
	if angle + angular_velocity > 0:
		return right
	return left

gewinn_liste = []  # accumulated reward of every episode, kept for plotting
for step in range(episodes):  # NOTE: `step` is the episode index here
	state = torch.tensor(env.reset()[0])  # reset() returns (observation, info)
	done = False
	prev_state = None
	gain = 0
	while not done:  # play one episode
		action = policy(state, step)
		# Remember the state the action was chosen from *before* it is
		# overwritten by the result of env.step().
		prev_state = state
		state, reward, terminated, truncated, info = env.step(action)
		state = torch.tensor(state)
		done = terminated or truncated  # pole is down or time is up!
		gain += reward
	gewinn_liste.append(gain)

# default=0 keeps the script from crashing when no episode was played.
bester_versuch = max(gewinn_liste, default=0)
print("Bester Versuch:", bester_versuch)

# Plot the values collected in gewinn_liste, one point per attempt.
import matplotlib.pyplot as plt

axes = plt.gca()  # current axes; created on demand if none exists yet
axes.plot(gewinn_liste)
axes.set_xlabel("Versuch#")  # German: attempt number
axes.set_ylabel("Geschaffte Schritte")  # German: steps achieved
plt.show()