import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def descartes_3d(ax, ran_x, ran_y, ran_z, ax_title, x_label="x", y_label="y", z_label="z"):
    """Configure a 3D axes as a labelled Cartesian frame.

    Args:
        ax: the 3D axes to configure.
        ran_x, ran_y, ran_z: indexable ranges; element 0 is the lower limit
            and element 1 the upper limit of the corresponding axis.
        ax_title: title drawn above the axes (font size 16).
        x_label, y_label, z_label: axis label texts (font size 12).
    """
    label_setters = (
        (ax.set_xlabel, x_label),
        (ax.set_ylabel, y_label),
        (ax.set_zlabel, z_label),
    )
    for set_label, text in label_setters:
        set_label(text, fontsize=12)

    limit_setters = (
        (ax.set_xlim, ran_x),
        (ax.set_ylim, ran_y),
        (ax.set_zlim, ran_z),
    )
    for set_limits, bounds in limit_setters:
        set_limits(bounds[0], bounds[1])

    ax.set_title(ax_title, fontsize=16)
    ax.grid()

def connect_line_3d(point_a, point_b, sp="-", color="black", axes=None):
    """Draw the straight segment from point_a to point_b on a 3D axes.

    Args:
        point_a, point_b: (x, y, z) sequences; only indices 0-2 are read.
        sp: matplotlib format string for the line (default solid "-").
        color: line colour forwarded to ``plot``.
        axes: axes to draw on. When omitted, the module-level ``ax`` is used,
            which keeps existing callers that rely on that global working.
    """
    # Look up the global `ax` only when no explicit axes was supplied.
    target = ax if axes is None else axes
    xs = [point_a[0], point_b[0]]
    ys = [point_a[1], point_b[1]]
    zs = [point_a[2], point_b[2]]
    target.plot(xs, ys, zs, sp, color=color)

# One 6x6-inch figure holding a single 3D axes. `ax` keeps this name because
# connect_line_3d reads it as a module-level global.
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111, projection="3d")

title1 = "Touhoku University 2024 Math qu.5 NO2"
# All three axes share the same interval [ss, ee].
ss, ee = -1, 1.2
cube_range = [ss, ee]
descartes_3d(ax, cube_range, cube_range, cube_range, title1)

# Vertices: O is the origin; A, B and C are the unit points on the x-, y- and
# z-axes; D is the midpoint of AC.
point_O = [0, 0, 0]
point_A = [1, 0, 0]
point_B = [0, 1, 0]
point_C = [0, 0, 1]  # defined for reference; not labelled or drawn here
point_D = [1 / 2, 0, 1 / 2]

# Label the drawn vertices. The small offsets move the "O" and "A" labels
# off the edges that pass through those points.
for label, (px, py, pz), (dx, dy, dz) in (
    ("O", point_O, (0, 0, 0.02)),
    ("A", point_A, (0, -0.07, 0)),
    ("B", point_B, (0, 0, 0)),
    ("D", point_D, (0, 0, 0)),
):
    ax.text(px + dx, py + dy, pz + dz, label, fontsize=10)

# The six edges of tetrahedron O-A-B-D.
for start, end in (
    (point_O, point_A),
    (point_O, point_B),
    (point_O, point_D),
    (point_A, point_B),
    (point_A, point_D),
    (point_B, point_D),
):
    connect_line_3d(start, end, "-", color="black")

# Sweep the plane x = t across [0, 1]. In each slice the plane cuts triangle
# ABD in the segment G-H. Rotating that segment about the x-axis sweeps an
# annulus; its outer circle (red) and inner circle (blue/green) are drawn.
n_slices = 100  # number of planes x = t
n_theta = 100   # samples along each circle

# The angular samples are identical for every circle, so compute them once.
theta = np.linspace(0, 2 * np.pi, n_theta)
sin_theta = np.sin(theta)
cos_theta = np.cos(theta)


def _plot_circle_about_x(axes, t, r, color, **line_kwargs):
    """Plot the circle of radius r about the x-axis lying in the plane x = t."""
    axes.plot(
        np.full_like(theta, t),
        r * sin_theta,
        r * cos_theta,
        "-",
        color=color,
        **line_kwargs,
    )


for t in np.linspace(0, 1, n_slices):
    # G lies on AB. H lies on BD for t < 1/2 and on DA for t >= 1/2.
    point_G = [t, 1 - t, 0]
    point_H = [t, 0, 1 - t] if t >= 1 / 2 else [t, 1 - 2 * t, t]
    connect_line_3d(point_G, point_H, "-", color="black")

    # Outer circle: G, at distance 1 - t, is the point of G-H farthest from
    # the x-axis.
    _plot_circle_about_x(ax, t, 1 - t, "red", alpha=0.3)

    # Inner circle: distance from the x-axis to the nearest point of G-H.
    # G and H both lie on the line y + z = 1 - t, which is (1 - t)/sqrt(2)
    # away from the axis. The foot of that perpendicular falls inside G-H
    # only for t >= 1/3. Before that, the nearest point is H = (t, 1-2t, t),
    # at distance sqrt(5t^2 - 4t + 1).
    if t <= 1 / 3:
        _plot_circle_about_x(ax, t, np.sqrt(5 * t**2 - 4 * t + 1), "blue")
    else:
        _plot_circle_about_x(ax, t, (1 - t) / np.sqrt(2), "green")

plt.show()