import gymnasium as gym
import torch
from random import random

# --- Hyperparameters / environment setup ------------------------------------
episodes = 100  # how many episodes (attempts) to play

# Use render_mode = "human" to watch the pole in a window (much slower);
# None runs the simulation headless.
render_mode = None
env = gym.make('CartPole-v1', render_mode=render_mode)

# Observation vector: (position, velocity, angle, angular-velocity)
state_dimension = 4
# Discrete actions: 0 = left, 1 = right
action_space_dims = 2

def policy_random(state):  # Average ≈ 22
	"""Return a random action (0 or 1) sampled from ``env.action_space``.

	The observation *state* is ignored; this is the baseline policy.
	"""
	sampled_action = env.action_space.sample()
	return sampled_action

last = 0  # action returned by the previous policy_zitter() call
def policy_zitter(state):  # Average ≈ 37
	"""Alternate between actions 1 and 0 on successive calls ("jitter").

	The observation *state* is ignored. The toggle is kept in the module-level
	``last`` so it persists across calls (it is never reset between episodes).
	"""
	global last
	last = 1 if last == 0 else 0
	return last

def policy_gegenlenken(state):  # Average ≈ 42
	"""Counter-steer based only on the pole angle.

	Args:
		state: 4-element observation unpacked as
			(position, velocity, angle, angular_velocity).

	Returns:
		0 (left) when ``angle < 0``, otherwise 1 (right); an angle of exactly
		zero therefore yields right.
	"""
	_position, _velocity, pole_angle, _angular_velocity = state
	LEFT, RIGHT = 0, 1
	return LEFT if pole_angle < 0 else RIGHT

def policy_smart(state, *, threshold=1.6):  # Average ≈ 476 kudos to Julian
	"""Counter-steer on the pole angle unless the pole is already swinging back fast.

	By default the policy pushes left (0) for a negative angle and right (1)
	otherwise. That choice is overridden when the angular velocity strongly
	points the other way:

	* angle <  0 and angular_velocity >  threshold  -> right (1)
	* angle >= 0 and angular_velocity < -threshold  -> left (0)

	Args:
		state: 4-element observation unpacked as
			(position, velocity, angle, angular_velocity).
		threshold: angular-velocity magnitude beyond which the override
			applies (strict comparison). Defaults to 1.6, the value the
			original policy hard-coded.

	Returns:
		0 (left) or 1 (right).
	"""
	_position, _velocity, angle, angular_velocity = state
	left, right = 0, 1
	if angle < 0:
		return right if angular_velocity > threshold else left
	return left if angular_velocity < -threshold else right

def policy(state, step):
	"""Select the action for *state* using the currently active strategy.

	``step`` (the episode index passed by the main loop) is accepted for
	interface compatibility but is not used by the active strategy. To try a
	different hand-written policy, point ``active_strategy`` at
	``policy_random``, ``policy_gegenlenken`` or ``policy_zitter`` instead.
	"""
	active_strategy = policy_smart
	return active_strategy(state)

# Play `episodes` episodes and record the accumulated reward of each one
# (plotted below as "Geschaffte Schritte").
gewinn_liste = []  # per-episode total reward, used for the stats/plot below
for step in range(episodes):  # NOTE: `step` is the episode index, not a time step
	state = torch.tensor(env.reset()[0])
	done = False
	prev_state = None  # observation on which the most recent action was chosen
	gain = 0
	while not done:  # play one episode
		action = policy(state, step)
		# Save the current observation BEFORE stepping; assigning it afterwards
		# (as before) made prev_state identical to the new state.
		prev_state = state
		state, reward, terminated, truncated, info = env.step(action)
		state = torch.tensor(state)
		done = terminated or truncated  # pole is down or time is up!
		gain += reward
	gewinn_liste.append(gain)  # for plotting

# Release the environment's resources (e.g. the window in "human" render mode).
env.close()

# Summarize the run: best episode and mean reward per episode.
if gewinn_liste:
	bester_versuch = max(gewinn_liste)
	print("Bester Versuch:", bester_versuch)

	average = sum(gewinn_liste) / len(gewinn_liste)
	print("Average:", average)
else:
	# With episodes == 0 the list is empty: max() would raise ValueError and
	# the mean would divide by zero.
	print("Keine Episoden gespielt.")

# plot: one point per episode
import matplotlib.pyplot as plt
plt.plot(gewinn_liste)
plt.xlabel("Versuch#")
plt.ylabel("Geschaffte Schritte")
plt.show()