#%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d.art3d as art3d
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle

# A single 5x5-inch figure containing one 3-D axes that everything below
# draws onto.
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(1, 1, 1, projection="3d")

# Side wall of a superellipse prism.
# The parametrisation
#   x = a * sgn(cos t) * |cos t|**(2/p)
#   y = b * sgn(sin t) * |sin t|**(2/p)
# traces the curve |x/a|**p + |y/b|**p = 1 (p = 1 gives a diamond, p = 2 an
# ellipse). Meshing t against z extrudes the curve along z in [-1, 1];
# Z holds the matching heights for plot_surface.
t = np.linspace(0, 2*np.pi, 100)
z = np.linspace(-1, 1, 100)
t, Z = np.meshgrid(t, z)
a = 1  # semi-axis along x (before rotation)
b = 1  # semi-axis along y (before rotation)
p = 1  # superellipse exponent
c, s = np.cos(t), np.sin(t)
x = np.abs(c)**(2/p) * a * np.sign(c)
y = np.abs(s)**(2/p) * b * np.sign(s)

# Rotate the cross-section counter-clockwise by `rot` about the z axis using
# the standard rotation matrix [[cos, -sin], [sin, cos]]. The y' term needs
# cos(rot); sin(rot) would only coincide with it at rot = pi/4.
rot = np.pi/4
x_s = x*np.cos(rot) - y*np.sin(rot)
y_s = x*np.sin(rot) + y*np.cos(rot)


# Lid ("futa") and bottom ("soko") of the prism: translucent green
# rectangles covering the bounding box of the rotated cross-section. They are
# lifted into 3-D at the top and bottom heights of the side wall.
# Width/height are max - min (not 2*max) so the caps stay correct even when
# the cross-section is not centred on the origin.
# NOTE: the caps are the axis-aligned bounding box. For a = b = p = 1 rotated
# by pi/4 this is exactly the (square) cross-section, but not in general.
x_lo, x_hi = np.min(x_s), np.max(x_s)
y_lo, y_hi = np.min(y_s), np.max(y_s)


def _add_cap(z_level):
    """Add one cap rectangle to ``ax`` in the plane z = *z_level*; return it."""
    cap = Rectangle(xy=(x_lo, y_lo), width=x_hi - x_lo, height=y_hi - y_lo,
                    fc='g', ec='g', alpha=0.5)
    ax.add_patch(cap)
    # Turn the 2-D patch into a 3-D one lying flat at height z_level.
    art3d.pathpatch_2d_to_3d(cap, z=z_level, zdir="z")
    return cap


futa = _add_cap(np.max(z))
soko = _add_cap(np.min(z))

# Translucent yellow sphere of radius r centred on the origin.
# NOTE(review): sqrt(3) is the circumradius of the cube [-1, 1]^3. The prism
# drawn here (a = b = p = 1, z in [-1, 1]) has its corners at distance
# sqrt(2). Confirm which radius is intended.
r = np.sqrt(3)
# theta_1 is the polar angle and only needs [0, pi]. Sweeping it over
# [0, 2*pi] would trace the sphere twice, stacking two translucent layers.
theta_2_0 = np.linspace(0, 2*np.pi, 100)  # azimuth
theta_1_0 = np.linspace(0, np.pi, 100)    # polar angle
theta_1, theta_2 = np.meshgrid(theta_1_0, theta_2_0)
x2 = np.cos(theta_2)*np.sin(theta_1) * r
y2 = np.sin(theta_2)*np.sin(theta_1) * r
z2 = np.cos(theta_1) * r
ax.plot_surface(x2, y2, z2, color='y', alpha=0.3)

# Label the three axes and fix every axis range to [-2, 2].
ax.set(
    xlabel='x axis',
    ylabel='y axis',
    zlabel='z axis',
    xlim=(-2, 2),
    ylim=(-2, 2),
    zlim=(-2, 2),
)

# Draw the prism's side wall last, then display the figure.
ax.plot_surface(x_s, y_s, Z, color='g', alpha=0.7)

plt.show()