import numpy as np
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
import mpl_toolkits.mplot3d.art3d as art3d
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy.spatial.transform import Rotation
# Select a Japanese font: the axis labels and title drawn below contain
# Japanese text (e.g. "x軸", "第6問").
plt.rcParams.update({'font.family': 'IPAmjMincho'})

def descartes_3d(ax, ran_x, ran_y, ran_z, ax_title,
                 x_label="x軸", y_label="y軸", z_label="z軸"):
    """Configure a 3D axes with labels, limits, a title and a grid.

    Parameters
    ----------
    ax : 3D axes object
        Target axes; receives set_[xyz]label, set_[xyz]lim, set_title, grid.
    ran_x, ran_y, ran_z : sequence
        Axis ranges; only items [0] and [1] are used, passed in that order
        as the lower/upper arguments of the matching set_*lim call.
    ax_title : str
        Title text, drawn with font size 16.
    x_label, y_label, z_label : str
        Axis label texts, drawn with font size 12.
    """
    label_setters = (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel)
    for apply_label, text in zip(label_setters, (x_label, y_label, z_label)):
        apply_label(text, fontsize=12)

    limit_setters = (ax.set_xlim, ax.set_ylim, ax.set_zlim)
    for apply_limits, bounds in zip(limit_setters, (ran_x, ran_y, ran_z)):
        apply_limits(bounds[0], bounds[1])

    ax.set_title(ax_title, fontsize=16)
    ax.grid()

# Vertices of the cube [-1, 1]^3. The bottom face (z = -1) comes first, then
# the top face (z = 1). Each face is listed counter-clockwise as seen from +z,
# so vertex i + 4 sits directly above vertex i.
square_xy = [(-1, -1), (1, -1), (1, 1), (-1, 1)]
points = np.array([[x, y, z] for z in (-1, 1) for x, y in square_xy])
Z = points

# One figure with a single 3D axes, then the cube [-1, 1]^3 drawn from Z.
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111, projection='3d')
title1 = "Tokyo University 2023 Math 第6問(1) NO4"

descartes_3d(ax, [-2, 2], [-2, 2], [-2, 2], title1)

# Mark the eight cube vertices.
ax.scatter3D(Z[:, 0], Z[:, 1], Z[:, 2], color="red")

# Cube faces as quadruples of row indices into Z. Each face is listed in
# boundary order so it is drawn as a proper (non-self-intersecting) quad.
face_indices = [
    (0, 1, 2, 3),  # z = -1 (bottom)
    (4, 5, 6, 7),  # z = 1 (top)
    (0, 1, 5, 4),  # y = -1
    (2, 3, 7, 6),  # y = 1
    (1, 2, 6, 5),  # x = 1
    (4, 7, 3, 0),  # x = -1
]
verts = [[Z[i] for i in face] for face in face_indices]

# Translucent cyan faces with thin red edges.
ax.add_collection3d(Poly3DCollection(verts, facecolors='cyan', linewidths=0.3,
                                     edgecolors='red', alpha=0.3))

# Spherical cap on the sphere of radius r = sqrt(3), i.e. the sphere through
# all eight cube vertices. It spans polar angles [0, alpha] from the +z axis.
# Since cos(alpha) = cos(pi/2 - arcsin(1/sqrt(3))) = 1/sqrt(3), the cap's rim
# lies at height r * cos(alpha) = 1, the plane of the cube's top face.
alpha = np.pi/2 - np.arcsin(1/np.sqrt(3))
r = np.sqrt(3)  # sphere radius

theta_2_0 = np.linspace(0, 2*np.pi, 100)  # azimuth samples
theta_1_0 = np.linspace(0, alpha, 100)    # polar-angle samples

# 2D parameter grids: rows follow the azimuth, columns follow the polar angle.
theta_1, theta_2 = np.meshgrid(theta_1_0, theta_2_0)

# Spherical -> Cartesian coordinates.
sin_polar = np.sin(theta_1)
x2 = np.cos(theta_2) * sin_polar * r
y2 = np.sin(theta_2) * sin_polar * r
z2 = np.cos(theta_1) * r

# Draw the spherical cap (polar angles 0..alpha on the sphere of radius r).
ax.plot_surface(x2, y2, z2, color='blue', alpha=0.8)

# Build one great-circle arc on the sphere of radius r. It joins the two
# adjacent top vertices (-1, -1, 1) and (1, -1, 1) of the cube.
#
# Start in the xz-plane (azimuth 0), with the polar angle running over
# [-half_angle, half_angle]. Here half_angle = pi/2 - alpha2, which equals
# arcsin(1/sqrt(3)) for the alpha defined earlier.
alpha2 = alpha
n_arc = 100  # number of sample points along the arc
half_angle = np.pi/2 - alpha2
theta_3_0 = np.linspace(-half_angle, half_angle, n_arc)
theta_4_0 = 0  # azimuth of the starting plane, so y_c is identically 0
x_c = np.cos(theta_4_0) * np.sin(theta_3_0) * r
y_c = np.sin(theta_4_0) * np.sin(theta_3_0) * r
z_c = np.cos(theta_3_0) * r

# (n_arc, 3) array with one (x, y, z) point per row. Its shape follows n_arc
# automatically, so the sample count can change without a matching reshape.
mm2 = np.column_stack((x_c, y_c, z_c))

# Tilt the arc by pi/4 about the x axis. The endpoints (+-1, 0, sqrt(2)) then
# map to (+-1, -1, 1), which are cube vertices.
rot = Rotation.from_rotvec(np.array([np.pi/4, 0, 0]))
e = rot.apply(mm2)

# Draw the base arc `e` plus copies rotated by k quarter turns about the z axis
# (k = 0..3). The four red arcs join each pair of adjacent top-face vertices.
for k in range(4):
    # Rotate the *base* arc by the absolute angle k * pi/2. The previous
    # version rotated the previous iterate by k * pi/2, which accumulated to
    # 0, pi/2, 3pi/2, 3pi and covered all four sides only by coincidence.
    quarter_turns = Rotation.from_rotvec(np.array([0, 0, np.pi/2 * k]))
    arc = quarter_turns.apply(e)
    ax.plot(arc[:, 0], arc[:, 1], arc[:, 2], color="r", lw=3)

# Final axis setup, applied after descartes_3d. The labels are replaced with
# "x 軸"/"y 軸"/"z 軸" at the default font size, where descartes_3d used
# fontsize=12. The y limits are given as (2, -2), which reverses the
# direction of the y axis.
axis_labels = (
    (ax.set_xlabel, 'x 軸'),
    (ax.set_ylabel, 'y 軸'),
    (ax.set_zlabel, 'z 軸'),
)
for apply_label, text in axis_labels:
    apply_label(text)

ax.set_xlim(-2, 2)
ax.set_ylim(2, -2)
ax.set_zlim(-2, 2)

plt.show()