Spaces:
Runtime error
Runtime error
import numpy as np | |
import periodictable | |
from copy import deepcopy, copy | |
from .latmatcher import LatMatch, SR | |
class PipelineLatMatch:
    """
    Build a stacked host (A) / guest (B) supercell from two 3D lattices.

    The in-plane 2x2 blocks of the two lattices are matched with ``LatMatch``
    (unless an explicit ``sc_vec`` is given), the two layers are stacked along
    z by :meth:`atom_shift`, and :meth:`get_new_structure` returns the combined
    supercell as a plain dict.
    """

    def __init__(self, Alat3D, Blat3D, Aatoms3D=None, Batoms3D=None, dim=10, sc_vec=None, optimize_angle=True,
                 optimize_strain=True, dz=4, grid_dim=(10, 10, 0)):
        """
        :param Alat3D: host lattice as a 3x3 numpy array (indexed with ``[:2, :2]``).
        :param Blat3D: guest lattice as a 3x3 numpy array.
        :param Aatoms3D: host atoms as ``[symbol, position]`` pairs. Although it
            defaults to None, :meth:`atom_shift` iterates it, so a list is required.
        :param Batoms3D: guest atoms, same layout as ``Aatoms3D``.
        :param dim: forwarded to ``LatMatch`` as ``scdim``.
        :param sc_vec: explicit 2x2 supercell matrix. When given, ``LatMatch`` is
            not run and ``self.rez`` stays None, so :meth:`compute_super_atoms`
            has no matcher result to hand to ``rotate_guest``.
        :param optimize_angle: forwarded to ``LatMatch``.
        :param optimize_strain: forwarded to ``LatMatch``.
        :param dz: z gap (plus 0.001) between the highest host atom and the
            lowest guest atom; also added to the out-of-plane cell height.
        :param grid_dim: number of unit-cell translations per side and per axis
            used to tile each layer before the supercell is cut out
            (was hard-coded to ``(10, 10, 0)``, which stays the default).
        """
        self.Alat3D = Alat3D
        self.Blat3D = Blat3D
        self.Alat = self.Alat3D[:2, :2]
        self.Blat = self.Blat3D[:2, :2]
        self.dz = dz
        self.grid_dim = tuple(grid_dim)
        if sc_vec is not None:
            self.sc_vec = sc_vec
            # TODO: This should be computed
            self.rez = None
        else:
            self.matcher4 = LatMatch(scdim=dim, reference=self.Alat, target=self.Blat,
                                     optimize_angle=optimize_angle, optimize_strain=optimize_strain)
            self.sc_vec = self.matcher4.supercell()
            self.rez = self.matcher4.result
        # In-plane rows come from sc_vec; the out-of-plane height is the gap
        # plus both layers' c_z components, with 0.001 of padding.
        self.sc_vec3 = np.array([[self.sc_vec[0][0], self.sc_vec[0][1], 0],
                                 [self.sc_vec[1][0], self.sc_vec[1][1], 0],
                                 [0, 0, self.dz + self.Alat3D[2][2] + self.Blat3D[2][2] + 0.001]])
        self.Aatoms3D, self.Batoms3D = self.atom_shift(Aatoms3D, Batoms3D)
        self.superA_xyz = []
        self.superB_xyz = []

    def atom_shift(self, Aatoms3D, Batoms3D):
        """
        Translate both layers along z so that they stack without overlap.

        The lowest host atom is moved to z = 0.001, and the lowest guest atom is
        placed ``self.dz + 0.001`` above the highest (shifted) host atom. Only
        the last coordinate of each position changes.

        :param Aatoms3D: host atoms as ``[symbol, position]`` pairs (not modified).
        :param Batoms3D: guest atoms as ``[symbol, position]`` pairs (not modified).
        :return: ``(shifted host atoms, shifted guest atoms)``; also stored on
            ``self.Aatoms3D`` / ``self.Batoms3D``.
        :raises ValueError: if either list is empty (from ``min``).
        """
        shift_Aatoms3D = deepcopy(Aatoms3D)
        shift_Batoms3D = deepcopy(Batoms3D)
        za = [atom[1][-1] for atom in Aatoms3D]
        zb = [atom[1][-1] for atom in Batoms3D]
        min_za = min(za)
        za = [z - min_za + 0.001 for z in za]
        min_zb = min(zb)
        max_za = max(za)
        # Same evaluation order as before: ((z - min) + top_of_A + dz) + pad.
        zb = [z - min_zb + max_za + self.dz + 0.001 for z in zb]
        for atom, z in zip(shift_Aatoms3D, za):
            atom[1][-1] = z
        for atom, z in zip(shift_Batoms3D, zb):
            atom[1][-1] = z
        self.Aatoms3D = shift_Aatoms3D
        self.Batoms3D = shift_Batoms3D
        return self.Aatoms3D, self.Batoms3D

    def compute_super_atoms(self):
        """
        Cut the host and guest atoms that fall inside the supercell.

        Each layer is tiled with ``atoms_to_greed`` (``self.grid_dim``
        translations). Host positions are converted Cartesian -> A-lattice
        coefficients -> supercell coefficients. Guest positions are first
        transformed by ``rotate_guest`` (using ``self.rez``) and then converted
        Cartesian -> supercell coefficients. Atoms kept by ``supar_atoms`` are
        converted back to Cartesian and de-duplicated with ``uniq_list``.

        :return: ``(host atoms, guest atoms)``; also stored on
            ``self.superA_xyz`` / ``self.superB_xyz``.
        """
        # NOTE(review): atom_change_basis2D(new_basis=sc_vec) computes
        # inv(sc_vec) @ xy, i.e. treats the COLUMNS of sc_vec as the supercell
        # vectors, whereas sc_vec3 in __init__ is built from the ROWS of sc_vec.
        # These agree only for symmetric sc_vec — confirm LatMatch's convention.
        atoms = atoms_to_greed(self.Aatoms3D, lat_v=self.Alat3D, dim=self.grid_dim)
        atoms_a = atom_change_basis2D(atoms, new_basis=self.Alat, old_basis=np.identity(2))
        atoms_A = atom_change_basis2D(atoms_a, new_basis=self.sc_vec, old_basis=self.Alat)
        superA = supar_atoms(atoms_A)
        superA_xyz = atom_change_basis2D(superA, new_basis=np.identity(2), old_basis=self.sc_vec)

        atoms = atoms_to_greed(self.Batoms3D, lat_v=self.Blat3D, dim=self.grid_dim)
        _, atoms_b = rotate_guest(self.rez, self.Blat3D, atoms)
        atoms_B = atom_change_basis2D(atoms_b, new_basis=self.sc_vec, old_basis=np.identity(2))
        superB = supar_atoms(atoms_B)
        superB_xyz = atom_change_basis2D(superB, new_basis=np.identity(2), old_basis=self.sc_vec)

        self.superA_xyz = uniq_list(superA_xyz)
        self.superB_xyz = uniq_list(superB_xyz)
        return self.superA_xyz, self.superB_xyz

    def get_new_structure(self):
        """
        Assemble the bilayer supercell as a dict.

        :return: dict with keys ``"atoms"`` (atomic numbers, None when a symbol
            is not found in ``periodictable``), ``"lattice_vectors"``
            (``self.sc_vec3``), ``"pbc"``, ``"positions"`` (lists) and
            ``"host_guest"`` (``"host"``/``"guest"`` per atom, host atoms first).
        """
        structure = {"atoms": [],
                     "lattice_vectors": self.sc_vec3,
                     "pbc": [True, True, False],
                     "positions": [],
                     "host_guest": [], }
        superA_xyz, superB_xyz = self.compute_super_atoms()
        atomic_numbers = []
        positions = []
        for role, layer in (("host", superA_xyz), ("guest", superB_xyz)):
            for element in layer:
                structure["host_guest"].append(role)
                positions.append(element[1].tolist())
                atomic_numbers.append(self._atomic_number(element[0]))
        structure["atoms"] = atomic_numbers
        structure["positions"] = positions
        return structure

    @staticmethod
    def _atomic_number(symbol):
        """Return the atomic number for ``symbol``, or None (with a warning) if unknown."""
        # getattr needs a default: without it an unknown symbol raised
        # AttributeError and the warning branch below could never run.
        element = getattr(periodictable, symbol, None)
        number = getattr(element, "number", None)
        if number is None:
            print(f"Warning: Atomic number for element '{symbol}' not found.")
        return number
def is_element_in_list(element, lst):
    """
    Tell whether ``lst`` already holds an atom equivalent to ``element``.

    Two entries match when their symbols (index 0) compare equal and their
    positions (index 1) agree under ``np.allclose`` default tolerances.

    :param element: ``[symbol, position]`` pair to look for.
    :param lst: iterable of ``[symbol, position]`` pairs.
    :return: True as soon as a match is found, otherwise False.
    """
    return any(
        candidate[0] == element[0] and np.allclose(candidate[1], element[1])
        for candidate in lst
    )
def uniq_list(super_brut):
    """
    Drop repeated atoms, keeping the first occurrence of each.

    Equivalence is decided by ``is_element_in_list`` (same symbol, positions
    within ``np.allclose`` tolerance). This is a pairwise O(n^2) scan.

    :param super_brut: list of ``[symbol, position]`` pairs.
    :return: new list with duplicates removed, original order preserved.
    """
    kept = []
    for candidate in super_brut:
        if not is_element_in_list(candidate, kept):
            kept.append(candidate)
    return kept
def atoms_to_greed(atoms, lat_v, dim):
    """
    Construct a grid of atoms by replicating them along the lattice vectors.

    Directions are processed in order k = 0, 1, 2. For each k, every atom
    gathered so far is copied and shifted by ``+i * v_k`` and ``-i * v_k`` for
    ``i = 1 .. dim[k]``, where ``v_k`` is the k-th COLUMN of ``lat_v``
    (``lat_v[0][k], lat_v[1][k], lat_v[2][k]``). The result therefore holds
    ``len(atoms) * prod(2 * dim[k] + 1)`` atoms.

    Previously the per-axis buffer was not reset, so each later pass appended
    the earlier passes' images again (duplicates). The set and first-occurrence
    order of positions are unchanged by the fix.

    :param atoms: list of ``[symbol, position]`` pairs; ``position`` must be a
        mutable, indexable 3-component sequence (list or numpy array).
    :param lat_v: 3x3 lattice matrix; column k is the k-th translation vector.
    :param dim: ``(nx, ny, nz)`` translations per side; 0 disables an axis.
    :return: new list: the original atoms followed by their translated copies.
        The input list is not modified.
    """
    atom_list = deepcopy(atoms)
    for axis in range(3):
        if dim[axis] < 1:
            continue
        shift = [lat_v[0][axis], lat_v[1][axis], lat_v[2][axis]]
        # Fresh buffer per axis: reusing the previous axis' list would append
        # its images to atom_list a second time.
        new_atoms = []
        for i in range(1, dim[axis] + 1):
            for atom in atom_list:
                for sign in (1, -1):
                    new_atom = deepcopy(atom)
                    for c in range(3):
                        new_atom[1][c] += sign * i * shift[c]
                    new_atoms.append(new_atom)
        atom_list.extend(new_atoms)
    return atom_list
# atoms in new basis: | |
def atom_change_basis2D(atoms, new_basis, old_basis=None):
    """
    Re-express the in-plane (x, y) coordinates of atoms in another 2D basis.

    Basis vectors are the columns: a point with coefficients ``c`` in
    ``old_basis`` lies at ``old_basis @ c``, so its coefficients in
    ``new_basis`` are ``inv(new_basis) @ old_basis @ c``. Only the top-left 2x2
    block of each basis is used. The third coordinate and any other fields of
    each atom are copied unchanged.

    :param atoms: list of ``[symbol, position]`` pairs (not modified).
    :param new_basis: target basis (2x2, or larger; only ``[:2][:2]`` is read).
    :param old_basis: basis the input coordinates are expressed in; defaults
        to the 2x2 identity, i.e. Cartesian x, y.
    :return: new list of deep-copied atoms with converted x, y components.
    :raises numpy.linalg.LinAlgError: if ``new_basis`` is singular.
    """
    if old_basis is None:
        old_basis = np.identity(2)
    new_2d = [[new_basis[0][0], new_basis[0][1]],
              [new_basis[1][0], new_basis[1][1]]]
    old_2d = [[old_basis[0][0], old_basis[0][1]],
              [old_basis[1][0], old_basis[1][1]]]
    change_base = np.linalg.inv(new_2d) @ old_2d
    converted = []
    for atom in atoms:
        # One deep copy per atom is enough to leave the caller's data intact.
        new_atom = deepcopy(atom)
        nd = change_base @ np.array(new_atom[1][:2])
        new_atom[1][0] = nd[0]
        new_atom[1][1] = nd[1]
        converted.append(new_atom)
    return converted
def supar_atoms(atoms, eps=0.01):
    """
    Keep the atoms whose fractional in-plane coordinates lie in the cell.

    Only the first two position components are tested. Each must lie in the
    half-open window ``[-eps, 1 - eps)``. That window has width exactly 1, so
    every periodic image class contributes exactly one atom. Atoms that are
    numerically just below 0 (e.g. ``-1e-12``) are still kept. The old window
    ``[-eps, 1)`` was 1 + eps wide: an atom at ~0.9999 and its periodic image
    at ~-0.0001 were both kept, duplicating sites across the cell boundary.

    :param atoms: list of ``[symbol, position]`` pairs in fractional coordinates.
    :param eps: amount by which the unit window is shifted below 0.
    :return: deep copies of the selected atoms, in input order.
    """
    upper = 1 - eps

    def _inside(f):
        return -eps <= f < upper

    return [deepcopy(atom) for atom in atoms
            if _inside(atom[1][0]) and _inside(atom[1][1])]
def rotate_guest(rez, Blat3D, atoms):
    """
    Apply the in-plane transform found by the lattice matcher to the guest.

    A 3x3 matrix ``Tr`` is built as the identity with its top-left 2x2 block
    replaced by ``SR(s1, s2, theta)``, so z components are left unchanged.

    :param rez: ``(s1, s2, theta)`` triple passed to ``SR``. It is None when
        ``PipelineLatMatch`` was given an explicit ``sc_vec``.
    :param Blat3D: 3x3 guest lattice; each ROW is transformed
        (``(Tr @ Blat3D.T).T``).
    :param atoms: list of ``[symbol, position]`` pairs (not modified).
    :return: ``(transformed lattice, new list of [symbol, Tr @ position])``.
        Fields beyond index 1 of each atom are not carried over.
    :raises TypeError: if ``rez`` is None (previously an opaque unpacking error).
    """
    if rez is None:
        raise TypeError(
            "rotate_guest needs the matcher result rez=(s1, s2, theta), got None; "
            "an explicit sc_vec does not provide one"
        )
    s1, s2, theta = rez
    Tr = np.eye(3)
    Tr[:2, :2] = SR(s1, s2, theta)
    oBlat3D = (Tr @ np.asarray(Blat3D).T).T
    new_atoms = [[atom[0], Tr @ atom[1]] for atom in atoms]
    return oBlat3D, new_atoms