File size: 8,339 Bytes
d6889ed 889587a d6889ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from pymatgen.core import Structure
from pymatgen.analysis.diffraction.xrd import XRDCalculator
import tempfile # To create temporary files for download
import os
import traceback # For detailed error logging
# --- Core processing ---------------------------------------------------------


def _resolve_cif_path(cif_file):
    """Return a filesystem path for an uploaded CIF value.

    Plain ``str`` / ``os.PathLike`` paths are returned as-is (via ``os.fspath``);
    any other object is assumed to expose its path as ``.name`` (as the original
    implementation expected of Gradio upload values).
    """
    if isinstance(cif_file, (str, os.PathLike)):
        return os.fspath(cif_file)
    return cif_file.name


def _peak_table(pattern):
    """Build the rounded peak table (2θ, intensity, first hkl) from a pymatgen pattern."""
    # Only the first hkl of each (possibly degenerate) peak is shown.
    miller_indices = [
        str(tuple(hkl_list[0]["hkl"])) if hkl_list else "N/A"
        for hkl_list in pattern.hkls
    ]
    return pd.DataFrame({
        "2θ (°)": [round(x, 3) for x in pattern.x],
        # Assuming normalized intensity from pymatgen (strongest peak = 100).
        "Intensity (norm)": [round(y, 3) for y in pattern.y],
        "Miller Indices (hkl)": miller_indices,
    })


def _xrd_figure(data, formula, two_theta_range):
    """Render the peak table as a styled Plotly stick (bar) plot titled with *formula*."""
    two_theta = data["2θ (°)"]
    intensity = data["Intensity (norm)"]
    hkls = data["Miller Indices (hkl)"]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=two_theta,
        y=intensity,
        hovertext=[
            f"2θ: {t:.3f}<br>Intensity: {i:.1f}<br>hkl: {m}"
            for t, i, m in zip(two_theta, intensity, hkls)
        ],
        hoverinfo="text",  # Show only the custom hover text
        width=0.1,  # Thin bars approximate diffraction sticks
        marker_color="#4682B4",  # SteelBlue
        marker_line_width=0,
        name="Peaks",
    ))

    # With no peaks, fall back to the requested window and a 0-100 y-axis.
    if data.empty:
        max_intensity = 100
        min_2theta, max_2theta = two_theta_range
    else:
        max_intensity = intensity.max()
        min_2theta, max_2theta = two_theta.min(), two_theta.max()

    fig.update_layout(
        title=dict(text=f"Simulated XRD Pattern: {formula}", x=0.5, xanchor="center"),
        xaxis_title="2θ (°)",
        yaxis_title="Intensity (Arb. Unit)",
        xaxis_title_font_size=16,
        yaxis_title_font_size=16,
        xaxis=dict(
            range=[min_2theta - 2, max_2theta + 2],
            showline=True, linewidth=1.5, linecolor="black", mirror=True,
            ticks="outside", tickwidth=1.5, tickcolor="black",
            tickfont_size=12,
        ),
        yaxis=dict(
            range=[0, max_intensity * 1.05],
            showline=True, linewidth=1.5, linecolor="black", mirror=True,
            ticks="outside", tickwidth=1.5, tickcolor="black",
            tickfont_size=12,
        ),
        plot_bgcolor="white",
        paper_bgcolor="white",
        bargap=0.9,
        font=dict(family="Arial, sans-serif", size=12, color="black"),
        margin=dict(l=70, r=30, t=60, b=70),
        height=450,  # Width is left to the container for responsiveness
    )
    fig.update_xaxes(showgrid=False, zeroline=False)
    fig.update_yaxes(showgrid=False, zeroline=False)
    return fig


def _write_csv(data):
    """Write *data* to a new temporary ``.csv`` file and return its path.

    ``delete=False`` keeps the file after this call so it can be offered for download.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".csv", newline="", encoding="utf-8"
    ) as temp_csv:
        # Write through the handle that was opened with newline="" / utf-8,
        # instead of re-opening the same file by name.
        data.to_csv(temp_csv, index=False)
    return temp_csv.name


def generate_xrd_pattern(cif_file, *, two_theta_range=(10, 90)):
    """Simulate the XRD pattern of an uploaded CIF file.

    Args:
        cif_file: The ``gr.File`` value — a path string/``PathLike`` or an object
            exposing the path as ``.name``. ``None`` when nothing is uploaded.
        two_theta_range: ``(min, max)`` 2θ window in degrees passed to
            ``XRDCalculator.get_pattern``. Defaults to ``(10, 90)``.

    Returns:
        tuple: ``(plotly_fig, dataframe, csv_filepath)``; ``(None, None, None)``
        when ``cif_file`` is ``None``.
            plotly_fig: Plotly figure of the peaks.
            dataframe: DataFrame with 2θ, intensity and hkl columns.
            csv_filepath: Path of a temporary CSV copy of ``dataframe``.

    Raises:
        gr.Error: If loading, simulation, plotting or CSV export fails; the
            original exception is chained as the cause.
    """
    if cif_file is None:
        # Nothing uploaded (or upload cleared): blank every output.
        return None, None, None
    try:
        structure = Structure.from_file(_resolve_cif_path(cif_file))
        pattern = XRDCalculator().get_pattern(structure, two_theta_range=two_theta_range)
        data = _peak_table(pattern)
        fig = _xrd_figure(data, structure.formula, two_theta_range)
        csv_filepath_out = _write_csv(data)
        return fig, data, csv_filepath_out
    except Exception as e:
        print(f"Error processing file: {e}")  # Log error to console
        traceback.print_exc()  # Detailed traceback for debugging
        # Surface the failure in the Gradio UI, keeping the original cause.
        raise gr.Error(f"Failed to process CIF file. Please ensure it's a valid CIF. Error: {str(e)}") from e
# --- Build Gradio Interface ---
# Page header rendered above the layout.
_HEADER_MD = """
# XRD Pattern Simulator from CIF
Upload a Crystallographic Information File (.cif) to generate its simulated
X-ray Diffraction (XRD) pattern using [pymatgen](https://github.com/materialsproject/pymatgen).
"""
# Attribution shown under the upload widget.
_COD_NOTE = "*(Example source: [Crystallography Open Database](http://crystallography.net/cod/))*"
# Bundled example inputs, one CIF path per row.
_EXAMPLE_CIFS = [
    ["example_cif/NaCl_1000041.cif"],
    ["example_cif/Al2O3_1000017.cif"],
]

theme = gr.themes.Soft(neutral_hue="slate", primary_hue="sky", secondary_hue="blue")

with gr.Blocks(title="XRD Pattern Generator", theme=theme) as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Narrow left column: CIF upload plus a data-source hint.
        with gr.Column(scale=1):
            cif_input = gr.File(type="filepath", file_types=[".cif"], label="Upload CIF File")
            gr.Markdown(_COD_NOTE)

        # Wide right column: one tab per output.
        with gr.Column(scale=3), gr.Tabs():
            with gr.TabItem("📊 XRD Plot"):
                plot_output = gr.Plot(label="XRD Pattern")
            with gr.TabItem("📄 Peak Data Table"):
                dataframe_output = gr.DataFrame(
                    headers=["2θ (°)", "Intensity (norm)", "Miller Indices (hkl)"],
                    label="Calculated Peak Data",
                    wrap=True,  # Let long hkl strings wrap
                )
            with gr.TabItem("⬇️ Download Data"):
                csv_output = gr.File(label="Download Peak Data as CSV")
                gr.Markdown("Click the link above to download the full data.")

    # The three components every handler below writes to, in return order.
    peak_outputs = [plot_output, dataframe_output, csv_output]

    # Removing the upload blanks all outputs.
    cif_input.clear(lambda: (None, None, None), inputs=[], outputs=list(peak_outputs))
    # Any new upload (or picked example) triggers a fresh simulation.
    cif_input.change(fn=generate_xrd_pattern, inputs=cif_input, outputs=list(peak_outputs))

    examples = gr.Examples(examples=_EXAMPLE_CIFS, inputs=[cif_input])

# --- Launch the App ---
if __name__ == "__main__":
    # For a temporary public link use demo.launch(share=True).
    demo.launch()