SR-ARPOD/stellar/visualizer/visualizer_all.py
2026-04-01 22:48:53 +08:00

406 lines
16 KiB
Python

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import style
import matplotlib
from matplotlib import cm
from mpl_toolkits import mplot3d
import numpy as np
import os
from functools import partial
from time import sleep
from itertools import cycle
import matplotlib.colors as mcolors
from matplotlib.pyplot import cycler
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
#import matplotlib.cm
def get_cycle(cmap, N=None, use_index="auto"):
if isinstance(cmap, str):
if use_index == "auto":
if cmap in ['Pastel1', 'Pastel2', 'Paired', 'Accent',
'Dark2', 'Set1', 'Set2', 'Set3',
'tab10', 'tab20', 'tab20b', 'tab20c']:
use_index=True
else:
use_index=False
cmap = matplotlib.cm.get_cmap(cmap)
if not N:
N = cmap.N
if use_index=="auto":
if cmap.N > 100:
use_index=False
elif isinstance(cmap, LinearSegmentedColormap):
use_index=False
elif isinstance(cmap, ListedColormap):
use_index=True
if use_index:
ind = np.arange(int(N)) % cmap.N
return cycler("color",cmap(ind))
else:
colors = cmap(np.linspace(0,1,N))
return cycler("color",colors)
def write2text(chaser, data_dir, file_name, step, is_velocities=False):
dir_path = os.getcwd()
data_path = os.path.join(dir_path, data_dir)
#file_folder = f'vbar-vel{len(os.listdir(data_path))}'
"""
data_dir : folder to store all velocities during all episodes
file_name : text file containing all velocities for an episode
"""
#file_path = os.path.join(data_path, file_folder)
#file_path = os.path.join(file_path, file_name)
file_path=os.path.join(data_path, file_name)
#file name is changed in test loop
if os.path.exists(data_path) == False:
os.makedirs(data_path)
"""
file_id = len(os.listdir(data_path))
file_name = f'chaser{file_id}.txt'
"""
if os.path.exists(file_path) == False:
with open(file_path, 'w') as f:
#f.write('0,0,0\n')
pass
state_arr = np.asarray(chaser.get_state_trace())
print(f'state arr shape {state_arr.shape}')
print('========================')
print('state arr')
print(state_arr)
print('========================')
if is_velocities:
pos_arr = state_arr[step-1:, 3:6]
else:
pos_arr = state_arr[step-1:, 0:3]
lines2write = ['\n'+str(pos[0])+','+str(pos[1])+','+str(pos[2]) for pos in pos_arr]
print(f'WRITING TO {file_path}')
print('----------------------')
print('DATA WRITTING')
print(lines2write)
file = open(file_path, 'a')
file.writelines(lines2write)
file.close()
class render_visual:
def __init__(self, data_dir = 'runs', only_inital = False, terminal_plots = False, run_file = None):
style.use('default')
matplotlib.use('WebAgg')
#matplotlib.use('pgf')
dir_path = os.getcwd()
data_path = os.path.join(dir_path, data_dir)
self.data_path = data_path
self.only_inital = only_inital
self.term_plots = terminal_plots
self.dock_point = np.array([0, 60, 0])
self.theta = 60.0
self.center_termin = False
self.run_file = run_file
self.show_trajectories = True
if os.path.exists(data_path) == False:
os.makedirs(data_path)
"""
file_id = len(os.listdir(data_path))
file_name = f'chaser{file_id}.txt'
file_path = os.path.join(data_path, file_name)
with open(file_path, 'w') as f:
#f.write('0,0,0')
pass
"""
self.fig = plt.figure(figsize = (5,5), dpi=100)
self.ax = plt.axes(projection='3d')
#plt.rcParams["figure.dpi"] = 1000
#self.ax.set_zlim([-1200, 1200])
#self.ax.set_xlim([-1200, 1200])
#self.ax.set_ylim([-1200, 1200])
self.ax.set_zlim([-800, 800])
self.ax.set_xlim([-800, 800])
self.ax.set_ylim([-100, 1200])
"""
Plot target (cylinder) and LOS (cone)
"""
x_center, y_center, radius, height = 0, 0, 60, 300
Xc, Yc, Zc = self.data_for_cylinder_along_z(x_center, y_center, radius, height)
if not only_inital:
self.ax.plot_surface(Xc, Yc, Zc, alpha=0.5)
else:
plot_terminal = input('Terminal visualization [y/n]: ')
center_termin = input('center terminal plot? [y/n]')
if plot_terminal == 'y':
self.term_plots = True
if center_termin == 'y':
self.center_termin = True
else:
pass
print('only initial positions no target visualization')
opt_los = input("Do you want to plot original LOS cone? (y/n): ")
opt_view = input("Do you want view [1] or view [2] (1/2): ")
opt_trajectory = input("Do you want to plot trajectories? (y/n): ")
if opt_trajectory == 'y':
self.show_trajectories = True
elif opt_trajectory == 'n':
self.show_trajectories = False
else:
raise ValueError('Invalid input')
if opt_los == 'y':
cone_length = 100
print(f'plotting original LOS cone length {cone_length}')
#X_los, Y_los, Z_los = self.data_for_cone_along_y(self.dock_point[0], self.dock_point[2], self.theta, self.dock_point[1], length=cone_length)
X_los, Y_los, Z_los = self.parametric_cone_along_y(self.dock_point[0], self.dock_point[2], self.theta, self.dock_point[1], length=cone_length)
else:
cone_length = 800
print(f'plotting LOS proposed with length {self.dock_point[1]}')
#X_los, Y_los, Z_los = self.data_for_cone_along_y(self.dock_point[0], self.dock_point[2], self.theta, self.dock_point[1])
X_los, Y_los, Z_los = self.parametric_cone_along_y(self.dock_point[0], self.dock_point[2], self.theta, self.dock_point[1], length=cone_length)
#self.ax.plot_surface(X_los, Y_los, Z_los, alpha=0.5)
x = np.linspace(-1.0, 1.0, len(X_los)**2)
if not only_inital:
#self.ax.scatter(X_los, Y_los, Z_los, s=0.1, c=x[::-1], cmap=cm.coolwarm, vmin=-1.0, vmax=1.0, alpha=0.8, antialiased=False, linewidths=0.8)
#my_cmap = plt.get_cmap('coolwarm')
cmap = cm.ScalarMappable(cmap='coolwarm_r')
cmap.set_array(x[::-1])
#self.ax.plot_surface(X_los, Y_los, Z_los, cmap=cmap.cmap, alpha=0.4, antialiased=False)
face_colors = self.color_fade_with_y(X_los, Y_los, Z_los, cmap='coolwarm')
#self.ax.contour3D(X_los, Y_los, Z_los, 100, cmap=my_cmap, linewidths=3, alpha=0.1)
else:
print('only initial positions no LOS cone')
#X_slow, Y_slow, Z_slow = self.data_for_slowzone()
#self.ax.plot_surface(X_slow, Y_slow, Z_slow, color='r', alpha = 0.25)
#x=[0,0]
#y=[60,860]
#z=[0,0]
#self.ax.plot3D(x, y, z, 'green')
#self.ax.quiver(0, 60, 0, 0, 860, 0, headlength=1, headwidth=1, headaxislength=1, color=['black'], pivot='tail')
#self.ax.view_init(elev=25, azim=60, roll=0)
if opt_view == '1':
self.ax.view_init(elev=25, azim=60, roll=0)
elif opt_view == '2':
self.ax.view_init(elev=10, azim=50, roll=0)
else:
raise ValueError('Invalid view option')
plt.draw()
def color_fade_with_y(self, X, Y, Z, cmap='coolwarm'):
# Create a figure and axes
# Create the surface plot
# Normalize Y to the range [0, 1] for color mapping
y_normalized = (Y - np.min(Y)) / (np.max(Y) - np.min(Y))
y_normalized = y_normalized * 2.0 - 1
face_colors = plt.cm.coolwarm(y_normalized)
surf = self.ax.plot_surface(X, Y, Z, cmap=cmap, facecolors=face_colors, rstride=5, cstride=5, linewidth=0.7, alpha=0.3, antialiased=False)
#surf = ax.plot_surface(X, Y, Z, cmap=cmap, linewidth=0, antialiased=False)
# Obtain the face colors from the surface plot
# Add a colorbar
#fig.colorbar(surf, shrink=0.5, aspect=5)
# Return the modified face colors
return face_colors
def data_for_slowzone(self):
u = np.linspace(0,2*np.pi, 100)
v = np.linspace(0, np.pi, 100)
r = 500 #500m radius
x_cen, y_cen, z_cen = self.dock_point[0], self.dock_point[1], self.dock_point[2]
x = x_cen + (r * np.outer(np.cos(u), np.sin(v)) )
y = y_cen + (r * np.outer(np.sin(u), np.sin(v)) )
z = z_cen + (r * np.outer(np.ones(np.size(u)), np.cos(v)) )
return x, y, z
def data_for_cylinder_along_z(self, center_x, center_y, radius, height_z):
#height_z
z = np.linspace((height_z/2)*-1, height_z/2, 80)
theta = np.linspace(0, 2*np.pi, 50)
theta_grid, z_grid=np.meshgrid(theta, z)
x_grid = radius*np.cos(theta_grid) + center_x
y_grid = radius*np.sin(theta_grid) + center_y
return x_grid,y_grid,z_grid
def data_for_cone_along_y(self, center_x, center_z, theta, height_y, length = 800):
max_y = height_y + length
print(f'height_y {height_y} max_y {max_y}')
print(f'center_x {center_x} center_z {center_z} theta {theta}')
theta = 30.0
h = length*(1/np.cos(theta))
r = np.sqrt(h**2.0 - length**2.0)
min_iterations = max_y - height_y
Y = np.linspace(height_y-1.5, max_y, 10000)
X = center_x + Y * np.sin(theta*Y)
Z = center_z + Y * np.cos(theta*Y)
return X, Y, Z
def parametric_cone_along_y(self, center_x, center_z, theta, height_y, length = 800):
h = length*(1/np.cos(theta))
radius = length / np.sqrt(3)
alpha = np.linspace(0, 2.*np.pi, 75) # angle around the cone
r = np.linspace(0, radius, 75) # radius from the center of the cone
Alpha, R = np.meshgrid(alpha, r) # create a grid of (theta, r) points
# Convert polar (cone) coordinates to cartesian coordinates
X = R * np.cos(Alpha)
Z = R * np.sin(Alpha) # shifting the cone to have vertex at y = 60
Y = np.tan(np.radians(theta)) * np.sqrt(X**2 + Z**2) + 60 # height of the cone from its vertex
return X, Y, Z
def animate(self, i, file_name = None):
cycol = cycle('rgbycmk')
for run_file in os.listdir(self.data_path):
print(f'plotting {run_file}')
data = os.path.join(self.data_path, run_file)
if os.path.exists(data) == False:
#file_path = os.path.join(data_path, file_name)
with open(data, 'w') as f:
#f.write('0,0,0')
pass
if not self.show_trajectories:
break
if self.run_file != None:
data = os.path.join(self.data_path, self.run_file)
with open(data) as f:
lines = f.readlines()
xs, ys, zs = np.array([]), np.array([]), np.array([])
for idx, line in enumerate(lines):
line = line.strip()
if len(line) > 1:
x, y, z = line.split(',')
x, y, z = np.float64(x), np.float64(y), np.float64(z)
xs = np.append(xs, [x])
ys = np.append(ys, [y])
zs = np.append(zs, [z])
if idx > 1 and self.only_inital and self.term_plots == False:
self.ax.set_xlim([-400, 400])
self.ax.set_ylim([400, 1200])
self.ax.set_zlim([-400, 400])
break
elif self.term_plots == True:
#self.ax.lab
if self.center_termin != True:
self.ax.set_xlim([-400, 400])
self.ax.set_ylim([400, 1200])
self.ax.set_zlim([-400, 400])
else:
self.ax.set_xlim([-500, 500])
self.ax.set_ylim([-440, 560])
self.ax.set_zlim([-500, 500])
if self.term_plots == False:
self.ax.scatter(xs[1], ys[1], zs[1], marker='o', c='purple', s=1.5, alpha=0.99)
#fade = np.linspace(-1.0, 1.0, len(xs))
#self.ax.scatter(xs, ys, zs, marker='o', s=0.05, c=next(cycol), alpha=0.9)
if self.term_plots == False:
if self.run_file != None:
self.ax.plot(xs, ys, zs, c='g', linewidth=0.9, alpha=0.95, markevery=15, antialiased=False)
else:
self.ax.plot(xs[:i], ys[:i], zs[:i], c=next(cycol), linewidth=0.7, alpha=0.5, markevery=15, antialiased=True)
else:
term_scatter = self.ax.scatter(xs[-1], ys[-1], zs[-1], marker='o', c='r', s=0.5)
u = np.linspace(0, 2 * np.pi, 100)
v = np.linspace(0, np.pi, 100)
x_sphere = 20 * np.outer(np.cos(u), np.sin(v))
y_sphere = (20 * np.outer(np.sin(u), np.sin(v))) + 60
z_zphere = 20 * np.outer(np.ones(np.size(u)), np.cos(v))
dock_scatter = self.ax.scatter([0], [60], [0], marker='o', c='r', s=1.5)
#legend1 = self.ax.legend(*term_scatter.legend_elements(), loc="upper right", title="State Legend")
#self.ax.add_artist(legend1)
if False:
legend_elements = [mplot3d.art3d.Line3D([0], [0], [0], marker='o', color='b', label='Final Position', markersize=8),
mplot3d.art3d.Line3D([0], [0], [0], marker='o', color='r', label='Docking Port', markersize=8)]
self.ax.legend(handles=legend_elements, loc='upper right', fontsize=8)
#self.ax.plot_surface(x_sphere, y_sphere, z_zphere, rstride=4, cstride=4, color='g', alpha=0.2)
if self.run_file != None:
self.ax.set_xlim([-400, 400])
self.ax.set_ylim([-200, 1000])
self.ax.set_zlim([-400, 400])
break
self.ax.set_xlabel('X [m]', fontsize=10, labelpad=5)
self.ax.set_ylabel('Y [m]', fontsize=10, labelpad=5)
self.ax.set_zlabel('Z [m]', fontsize=10, labelpad=5)
#self.ax.legend()
if self.center_termin == True:
self.ax.set_xticks(np.arange(-500, 501, 120))
self.ax.set_yticks(np.arange(-440, 561, 120))
self.ax.set_zticks(np.arange(-500, 501, 120))
# Set tick font size
for label in (self.ax.get_xticklabels() + self.ax.get_yticklabels() + self.ax.get_zticklabels()):
label.set_fontsize(8)
#self.ax.legend()
#self.ax.tick_params(axis='both', which='major', labelsize=5)
#self.ax.tick_params(axis='both', which='minor', labelsize=5)
plt.draw()
sleep(1.0)
if i > 125:
return
def render_animation(self, file_name = None):
"""
fig = plt.figure(figsize = (8,8))
ax = plt.axes(projection='3d')
ax.set_zlim([-1000, 1000])
ax.set_xlim([-1000, 1000])
ax.set_ylim([-1000, 1000])
"""
#plt.ion()
#data = os.path.join(self.data_path, file_name)
ani = animation.FuncAnimation(self.fig, partial(self.animate, file_name=file_name), repeat=False, cache_frame_data=True, interval=100)
#print(ani.to_jshtml())
#plt.ion()
plt.show(block=False)
#plt.savefig('myfigure.eps', format='eps', dpi=300)
#plt.show()
#sleep(1.0)
def save(self):
self.fig.show()
self.fig.savefig("1003dvbar.eps", format="eps",transparent=True)
#plt.savefig('1003dvbar.png')