import numpy as np
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
import mpl_toolkits.mplot3d.art3d as art3d
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy.spatial.transform import Rotation

# Cube vertices: the square ring (-1,-1) -> (1,-1) -> (1,1) -> (-1,1) in the
# xy-plane, first at z = -1 (indices 0-3), then at z = +1 (indices 4-7).
_ring = ((-1, -1), (1, -1), (1, 1), (-1, 1))
points = np.array([[x, y, z] for z in (-1, 1) for x, y in _ring])
Z = points

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection="3d")

# Grid over the cube's x/y extent (not used by the drawing below).
sc = [-1, 1]
X, Y = np.meshgrid(sc, sc)

# Mark the eight vertices.
ax.scatter3D(*Z.T, color="red")

# The six quadrilateral faces, each given as four vertex indices into Z.
_faces = (
    (0, 1, 2, 3),  # z = -1
    (4, 5, 6, 7),  # z = +1
    (0, 1, 5, 4),  # y = -1
    (2, 3, 7, 6),  # y = +1
    (1, 2, 6, 5),  # x = +1
    (4, 7, 3, 0),  # x = -1
)
verts = [[Z[i] for i in face] for face in _faces]

ax.add_collection3d(
    Poly3DCollection(
        verts, facecolors="cyan", linewidths=1, edgecolors="red", alpha=0.2
    )
)

# Circumscribed sphere of the +/-1 cube: every vertex (+/-1, +/-1, +/-1) lies
# at distance sqrt(3) from the origin.
r = np.sqrt(3)

# Spherical coordinates: theta_2 is the azimuth (a full turn, [0, 2*pi]) and
# theta_1 is the polar angle, which only needs [0, pi]. Sweeping theta_1 past
# pi would trace the whole sphere a second time, compositing the translucent
# surface over itself and rendering twice as many polygons as needed.
theta_2_0 = np.linspace(0, 2 * np.pi, 100)
theta_1_0 = np.linspace(0, np.pi, 100)

# Equals arccos(1/sqrt(3)) (~54.7 deg), the angle between a cube body diagonal
# such as (1, 1, 1) and a coordinate axis. Not used by the sphere plot itself.
alpha = np.pi / 2 - np.arcsin(1 / np.sqrt(3))

theta_1, theta_2 = np.meshgrid(theta_1_0, theta_2_0)
x2 = np.cos(theta_2) * np.sin(theta_1) * r
y2 = np.sin(theta_2) * np.sin(theta_1) * r
z2 = np.cos(theta_1) * r
ax.plot_surface(x2, y2, z2, color="b", alpha=0.2)

# Draw one great circle of the sphere (radius r), tilted about the x-axis.
# rot is a +pi/8 right-handed rotation about the x-axis.
rot = Rotation.from_rotvec(np.array([np.pi / 8, 0, 0]))
theta_3_0 = np.linspace(0, 2 * np.pi, 100)  # angle along the circle
theta_4_0 = 0  # fixed azimuth 0, so the unrotated circle lies in y == 0
x_c = np.cos(theta_4_0) * np.sin(theta_3_0) * r
y_c = np.sin(theta_4_0) * np.sin(theta_3_0) * r
z_c = np.cos(theta_3_0) * r

# (100, 3) array with one (x, y, z) row per sample point.
mm2 = np.column_stack((x_c, y_c, z_c))

# The pi/8 rotation is applied twice (pi/4 in total). This turns the circle's
# plane normal from (0, 1, 0) into (0, 1, 1)/sqrt(2), i.e. the circle lies in
# the plane y + z = 0. Assuming r is the cube's circumradius sqrt(3), it then
# passes through the four cube vertices with y == -z.
# NOTE(review): if a single pi/8 tilt was intended, plot `dd` instead of `e`.
dd = rot.apply(mm2)
e = rot.apply(dd)

# Per-axis coordinate lists as plain Python floats, so the debug print stays
# readable on NumPy 2 (where np.float64 items repr as "np.float64(...)").
x_t, y_t, z_t = (e[:, k].tolist() for k in range(3))

print("x:", x_t, "y:", y_t, "z:", z_t)
ax.plot(x_t, y_t, z_t, color="r")

# Label the axes and fix a [-2, 2] view box on every axis. The y-limits are
# given high-to-low, which inverts the direction of the y-axis.
ax.set(
    xlabel="x axis",
    ylabel="y axis",
    zlabel="z axis",
    xlim=(-2, 2),
    ylim=(2, -2),
    zlim=(-2, 2),
)

plt.show()