Source code for cifkit.preprocessors.environment

import numpy as np

from cifkit.utils import distance, unit


[docs] def get_site_connections( parsed_data: list[str], unitcell_points, supercell_points, cutoff_radius: float, ) -> dict: """Compute all pair distances per site label.""" labels, lengths, angles = parsed_data all_labels_connections = {} for site_label in labels: filtered_unitcell_points = [ point for point in unitcell_points if point[3] == site_label ] dist_result = get_nearest_dists_per_site( filtered_unitcell_points, supercell_points, cutoff_radius, lengths, angles, ) dist_dict, dist_set = dist_result ( label, connections, ) = get_most_connected_point_per_site(site_label, dist_dict, dist_set) all_labels_connections[label] = connections return remove_duplicate_connections(all_labels_connections)
[docs] def get_nearest_dists_per_site( filtered_unitcell_points, supercell_points, cutoff_radius: float, lengths, angles_rad, ): # Initialize a dictionary to store the relationships dist_dict = {} dist_set = set() # Loop through each point in the filtered list for i, point_1 in enumerate(filtered_unitcell_points): point_2_info = [] for j, point_2 in enumerate(supercell_points): if point_1 == point_2: continue # Skip comparison with itself # Convert fractional to Cartesian coordinates cart_1 = unit.fractional_to_cartesian( [point_1[0], point_1[1], point_1[2]], lengths, angles_rad, ) cart_2 = unit.fractional_to_cartesian( [point_2[0], point_2[1], point_2[2]], lengths, angles_rad, ) # Calculate the dist between two points dist = distance.calc_dist_two_cart_points(cart_1, cart_2) dist = float(np.round(dist, 3)) # Check the dist if dist < cutoff_radius and dist > 0.1: point_2_info.append( ( point_2[3], # site label dist, [ float(np.round(cart_1[0], 3)), # x float(np.round(cart_1[1], 3)), # y float(np.round(cart_1[2], 3)), # z ], [ float(np.round(cart_2[0], 3)), # x float(np.round(cart_2[1], 3)), # y float(np.round(cart_2[2], 3)), # z ], ) ) dist_set.add(dist) # Store the list in the dictionary with `i` as the key if point_2_info: dist_dict[i] = point_2_info return dist_dict, dist_set
[docs] def get_most_connected_point_per_site( label: str, dist_dict: dict, dist_set: set ): """Identify the reference point with the highest number of connections within the 50 shortest distances from a set of distances.""" sorted_unique_dists = sorted(dist_set) # Get the 30 shortest distances shortest_dists = sorted_unique_dists[:50] # Variables to track the reference point with the highest count max_count = 0 max_ref_point = None max_connections = [] for ref_idx, connections in dist_dict.items(): # Initialize a dictionary to count occurrences of each shortest dist_counts = {dist: 0 for dist in shortest_dists} # Count the occurrences of the shortest distances for _, dist, _, _ in connections: if dist in dist_counts: dist_counts[dist] += 1 # Calculate the total count of occurrences for this reference point total_count = sum(dist_counts.values()) # Check if this is the maximum we've encountered so far if total_count > max_count: max_count = total_count max_ref_point = ref_idx max_connections = sorted(connections, key=lambda x: x[1]) # Return the max point if max_ref_point is not None: return label, [ (other_label, dist, cart_1, cart_2) for other_label, dist, cart_1, cart_2 in max_connections ]
[docs] def remove_duplicate_connections(connections): """Remove duplicate connections based on the last set of coordinates.""" unique_connections = {} for key, value in connections.items(): seen = set() unique_list = [] for item in value: # The tuple representing the endpoint coordinates is item[3] coords = tuple( item[3] ) # Need to convert list to tuple to use it in a set if coords not in seen: seen.add(coords) unique_list.append(item) unique_connections[key] = unique_list return unique_connections