#
# KdV_FFT_RK.py :
#
# A python 3 integrator for KdV. It uses FFT for spatial derivatives
# and Runge-Kutta 4 for time evolution, and an integrating factor to
# eliminate the dispersive term.
#
# NOTE: to get the animations working, best avoid jupyter - simply run
# the script straight from a desktop terminal, or using ipython.
#
# Solves u_t = - 6 u u_x - u_xxx
#
# The main program was originally written by Sam Webster in AIMS South
# Africa in 2007; modifications by Patrick Dorey in 2010, 2013 and 2014.
#
# Try changing N in the "initial conditions for u" section a little
# way down; you can also experiment with completely different initial
# conditions to see what happens...

import numpy as np
import matplotlib.pyplot as plt
import time

plt.ion()  # turn ON interactive mode, so plt.pause() can animate without blocking

########################################################

# range of time, and time step:
tmax = 0.5
dt = 0.00005

# number of points: (should be a power of 2 for the FFT)
M = 512

# x period:
L = 20.0

# x step size:
h = L/M

# approx time step between plots on screen:
dtplot = 0.001

# x-axis points for plots: (note p.b.c. equates -L/2 to L/2)
x = np.linspace(-L/2, L/2-h, M)

# if |Uhat| at the highest (Nyquist) wavenumber exceeds this, the run is
# treated as numerically unstable and stopped:
blowup = 200

########################################################

# *** initial conditions for u *** :
N = 4
u = N*(N+1)/np.cosh(x)**2

########################################################

# initial y range for the plots:
ymin = -0.1
ymax = 1.1*u.max()

########################################################

# Note that Uhat(k,t) = exp(-i(2 pi k/L)^3 t) uhat(k,t) where
# uhat(k,t) is the FT of u(x,t). Hence Uhat(k,0)=uhat(k,0),
# and the initial Uhat is simply the Fourier Transform of
# the initial data:

Uhat = np.fft.rfft(u)

# The FT assumes the period to be 2 pi (hence the scaling)
# and so the values of k range from -M/2 to M/2. However
# since u is real we can use Python's "real" FFT, called
# rfft. This drops the negative-frequency terms (which are
# just the complex conjugates of the positive frequency ones)
# and so k needs only to run from 0 to M/2:

k = np.arange(M//2 + 1)

########################################################

# The function f calculates the right-hand-side of the Uhat
# ODE. The FFT routines need the function u to have period
# 2 pi; the factors of s=2 pi/L take this into account by
# rescaling to/from period L:

s = 2*np.pi/L
sk = s*k
A = 1j*sk**3
B = -3j*sk


def to_u(uu, tt):
    """Return the real-space field u(x, tt) for integrating-factor data uu.

    Undoes the integrating factor, uhat = exp(A*tt) * Uhat, and inverse
    transforms. n=M is passed explicitly so that M real points are always
    returned (irfft's default length is only correct for even M).
    """
    return np.fft.irfft(np.exp(A*tt)*uu, n=M)


def f(tt, uu):
    """Right-hand side of the Uhat ODE: d/dt Uhat = f(tt, Uhat).

    Since -6 u u_x = -3 (u^2)_x, in Fourier space
        dUhat/dt = exp(-A tt) * (-3 i s k) * FT[u^2],
    with u = to_u(Uhat, tt). The stiff dispersive term has been removed
    by the integrating factor exp(-A tt).
    """
    a1 = to_u(uu, tt)
    a2 = np.fft.rfft(a1**2)
    return B*np.exp(-A*tt)*a2


########################################################

# commands to initialise plot:
line, = plt.plot(x, u)
line.axes.set_ylim(ymin, ymax)

# number of RK4 steps to reach tmax (rounded, so float error in tmax/dt
# cannot add or drop a step; the run ends at the nearest multiple of dt):
nsteps = int(round(tmax/dt))

# only every Kth configuration is plotted; K >= 1 keeps the modulo defined
# even if dtplot < dt:
K = max(1, int(round(dtplot/dt)))

# default message that the program has finished successfully:
stop = 'run completed'

# time to which Uhat currently corresponds; advanced after each step:
t = 0.0

######################################
#    START OF INTEGRATION ROUTINE    #
######################################

start = time.process_time()
try:
    for n in range(nsteps):
        # Solve (d/dt)Uhat(k,t)=f(t,Uhat) using a 4th-order Runge-Kutta
        # method in time, where f(,) is given by the above routine:
        k1 = f(t, Uhat)
        k2 = f(t+0.5*dt, Uhat+0.5*dt*k1)
        k3 = f(t+0.5*dt, Uhat+0.5*dt*k2)
        k4 = f(t+dt, Uhat+dt*k3)
        Uhat += (dt/6)*(k1+2*k2+2*k3+k4)
        # Uhat is now the solution at the new time:
        t = (n+1)*dt

        # Every Kth configuration (and the last), compute u, check it and plot it:
        if (n+1) % K == 0 or n == nsteps-1:
            # Compute u, using the same t that Uhat now corresponds to:
            u = to_u(Uhat, t)

            # Is u blowing up? If so, interrupt:
            # (|exp(A t)| = 1, so |uhat| = |Uhat| mode by mode)
            if abs(Uhat[-1]) > blowup:
                stop = 'stopped early: unstable. Decrease the time step!'
                break

            # Has u gone off the screen? If so, increase y-range for plot:
            um = u.max()
            if um > ymax:
                ymax = ymax*1.5
                line.axes.set_ylim(ymin, ymax)

            # Plot u:
            plt.title('t='+'%.3f' % t)
            line.set_ydata(u)
            plt.draw()
            plt.pause(0.0001)

# Allow the user to interrupt the program:
except KeyboardInterrupt:
    stop = 'stopped early: keyboard interrupt'

runtime = time.process_time()-start

####################################
#    END OF INTEGRATION ROUTINE    #
####################################

# Final state, recomputed from Uhat so that u and um are always defined
# and consistent with t, however the loop ended:
u = to_u(Uhat, t)
um = u.max()

########################################################

# end by printing out some information about the run:

print(stop)
print('run time: '+'%.5f' % runtime)
print('t = '+'%.4f' % t)
print('M = '+'%i' % M)
print('max u = '+'%.4f' % um)
i = int(np.argmax(u))
print('index of max u = '+'%i' % i)
print('location of max u : '+'%.4f' % x[i])
# Refine the peak position with a parabola through u at i-1, i, i+1.
# Neighbours are taken periodically, so i = 0 or i = M-1 does not fail,
# and a flat (zero-curvature) peak falls back to the grid point:
ul, uc, ur = u[i-1], u[i], u[(i+1) % M]
curv = ul - 2*uc + ur
umaxpos = x[i] + (ul-ur)/curv*h/2 if curv != 0 else x[i]
# wrap back into the periodic window [-L/2, L/2):
umaxpos = (umaxpos + L/2) % L - L/2
print('corrected location of max u : '+'%.4f' % umaxpos)

########################################################