File size: 8,339 Bytes
d6889ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
889587a
d6889ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from pymatgen.core import Structure
from pymatgen.analysis.diffraction.xrd import XRDCalculator
import tempfile # To create temporary files for download
import os
import traceback # For detailed error logging

# --- Core processing: CIF upload -> XRD plot, peak table and CSV ---

def _resolve_cif_path(cif_file):
    """Return the filesystem path of the uploaded CIF as a string.

    Accepts a plain ``str``/``os.PathLike`` path as well as any object that
    exposes the path through a ``.name`` attribute (file-like upload wrappers).
    """
    if isinstance(cif_file, (str, os.PathLike)):
        return os.fspath(cif_file)
    return cif_file.name


def _build_peak_table(pattern):
    """Turn a calculated diffraction pattern into the peak DataFrame shown in the UI."""
    # A peak may carry several hkl entries; only the first one is displayed,
    # formatted as a standard tuple string such as "(1, 1, 1)".
    miller_indices = [
        str(tuple(hkl_list[0]['hkl'])) if hkl_list else "N/A"
        for hkl_list in pattern.hkls
    ]

    # Values are rounded to 3 decimals purely for cleaner display/export.
    return pd.DataFrame({
        "2θ (°)": [round(x, 3) for x in pattern.x],
        "Intensity (norm)": [round(y, 3) for y in pattern.y],  # Assuming normalized intensity from pymatgen
        "Miller Indices (hkl)": miller_indices,
    })


def _build_xrd_figure(data, formula):
    """Draw the peak table as a stick-style Plotly bar chart titled with *formula*."""
    two_theta = data["2θ (°)"]
    intensity = data["Intensity (norm)"]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=two_theta,
        y=intensity,
        hovertext=[f"2θ: {t:.3f}<br>Intensity: {i:.1f}<br>hkl: {m}"
                   for t, i, m in zip(two_theta, intensity, data["Miller Indices (hkl)"])],
        hoverinfo="text",  # Show only the custom hover text
        width=0.1,
        marker_color="#4682B4",  # SteelBlue
        marker_line_width=0,
        name='Peaks'
    ))

    # Fall back to fixed axis limits when the pattern contains no peaks,
    # because min()/max() of an empty column would not give usable ranges.
    if data.empty:
        max_intensity, min_2theta, max_2theta = 100, 10, 90
    else:
        max_intensity = intensity.max()
        min_2theta = two_theta.min()
        max_2theta = two_theta.max()

    fig.update_layout(
        title=dict(text=f"Simulated XRD Pattern: {formula}", x=0.5, xanchor='center'),  # Centered title
        xaxis_title="2θ (°)",
        yaxis_title="Intensity (Arb. Unit)",
        xaxis_title_font_size=16,
        yaxis_title_font_size=16,
        xaxis=dict(
            range=[min_2theta - 2, max_2theta + 2],  # 2° padding either side of the peaks
            showline=True, linewidth=1.5, linecolor='black', mirror=True,
            ticks='outside', tickwidth=1.5, tickcolor='black',
            tickfont_size=12
        ),
        yaxis=dict(
            range=[0, max_intensity * 1.05],  # 5% headroom above the tallest peak
            showline=True, linewidth=1.5, linecolor='black', mirror=True,
            ticks='outside', tickwidth=1.5, tickcolor='black',
            tickfont_size=12
        ),
        plot_bgcolor='white',
        paper_bgcolor='white',  # Background outside the plot area is white too
        bargap=0.9,
        font=dict(family="Arial, sans-serif", size=12, color="black"),
        margin=dict(l=70, r=30, t=60, b=70),
        height=450,
        # width is left unset so the hosting container decides it
    )
    fig.update_xaxes(showgrid=False, zeroline=False)
    fig.update_yaxes(showgrid=False, zeroline=False)
    return fig


def _write_peak_csv(data):
    """Write the peak table to a new temporary ``.csv`` file and return its path."""
    # delete=False: the file must still exist after this function returns so
    # it can be offered for download; nothing in this function removes it.
    # The data is written through the already-open handle instead of
    # reopening the file by name while it is still open.
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv', newline='', encoding='utf-8') as temp_csv:
        data.to_csv(temp_csv, index=False)
        return temp_csv.name


def generate_xrd_pattern(cif_file):
    """
    Processes an uploaded CIF file, calculates the XRD pattern,
    and returns a Plotly figure, a Pandas DataFrame, and the path to a CSV file.

    Args:
        cif_file: The uploaded CIF, given either as a path (str / os.PathLike)
                  or as an object exposing the path via ``.name``.
                  ``None`` means no file is selected.

    Returns:
        tuple: (plotly_fig, dataframe, csv_filepath), or (None, None, None)
               when ``cif_file`` is None.
               plotly_fig: A Plotly figure object.
               dataframe: A Pandas DataFrame containing the peak data.
               csv_filepath: Path to the generated temporary CSV file.

    Raises:
        gr.Error: If anything fails while loading the structure, calculating
                  the pattern, or building the outputs. The original exception
                  is attached as the cause and its traceback is printed.
    """
    if cif_file is None:
        # Nothing uploaded: clear all three outputs.
        return None, None, None

    try:
        cif_filepath = _resolve_cif_path(cif_file)

        # 1. Load structure from CIF
        structure = Structure.from_file(cif_filepath)

        # 2. Calculate XRD pattern over 10°-90° 2θ (adjust range if needed)
        calculator = XRDCalculator()
        pattern = calculator.get_pattern(structure, two_theta_range=(10, 90))

        # 3. Build the three outputs from the calculated peaks
        data = _build_peak_table(pattern)
        fig = _build_xrd_figure(data, structure.formula)
        csv_filepath_out = _write_peak_csv(data)

        return fig, data, csv_filepath_out

    except Exception as e:
        # Top-level boundary for the UI callback: log everything to the
        # console, then surface a readable message in the Gradio UI.
        print(f"Error processing file: {e}")
        traceback.print_exc()
        raise gr.Error(f"Failed to process CIF file. Please ensure it's a valid CIF. Error: {str(e)}") from e


# --- Build Gradio Interface ---
# Soft theme with a blue-leaning palette; change the hue names to restyle the app.
theme = gr.themes.Soft(primary_hue="sky", secondary_hue="blue", neutral_hue="slate")

# Page layout: a header, then one row with the upload column (narrow) on the
# left and a tabbed output column (wide) on the right. Components appear in
# the order they are created inside these nested context managers.
with gr.Blocks(theme=theme, title="XRD Pattern Generator") as demo:
    # Page header / short usage instructions
    gr.Markdown(
        """
        #  XRD Pattern Simulator from CIF
        Upload a Crystallographic Information File (.cif) to generate its simulated
        X-ray Diffraction (XRD) pattern using [pymatgen](https://github.com/materialsproject/pymatgen).
        """
    )

    with gr.Row():
        with gr.Column(scale=1): # Input column (1/4 of the row given the 1:3 scales)
            cif_input = gr.File(
                label="Upload CIF File",
                file_types=[".cif"],
                type="filepath" # Ask Gradio to hand the upload over as a file path
            )
            gr.Markdown("*(Example source: [Crystallography Open Database](http://crystallography.net/cod/))*")

        with gr.Column(scale=3): # Output column, three times as wide as the input column
             with gr.Tabs():
                # One tab per output of generate_xrd_pattern: plot, table, CSV
                with gr.TabItem("📊 XRD Plot"):
                    # The figure title is centered by the Plotly layout itself
                    # (title x=0.5), not by any Gradio container.
                    plot_output = gr.Plot(label="XRD Pattern") # Label might be redundant with Tab title

                with gr.TabItem("📄 Peak Data Table"):
                    # Headers mirror the column names of the DataFrame built
                    # in generate_xrd_pattern.
                    dataframe_output = gr.DataFrame(
                        label="Calculated Peak Data",
                        headers=["2θ (°)", "Intensity (norm)", "Miller Indices (hkl)"],
                        wrap=True, # Allow text wrapping for long indices
                        #max_rows=15, # Limit initial display height
                        #overflow_row_behaviour='paginate' # Add pagination if many rows
                        )

                with gr.TabItem("⬇️ Download Data"):
                    csv_output = gr.File(label="Download Peak Data as CSV")
                    gr.Markdown("Click the link above to download the full data.")


    # Reset all three outputs when the uploaded file is cleared.
    # NOTE(review): clearing presumably also fires `change` with None, which
    # generate_xrd_pattern already maps to three Nones - confirm whether this
    # handler is redundant.
    cif_input.clear(
        lambda: (None, None, None),
        inputs=[],
        outputs=[plot_output, dataframe_output, csv_output]
    )

    # Recompute plot, table and CSV whenever the selected file changes
    cif_input.change(
        fn=generate_xrd_pattern,
        inputs=cif_input,
        outputs=[plot_output, dataframe_output, csv_output],
        # show_progress="full" # Show progress indicator during calculation
    )
    # Clickable sample inputs that populate the upload component.
    # The paths are relative, so the example_cif/ files must ship with the
    # app (presumably resolved against the working directory - verify).
    examples = gr.Examples(
        examples=[
            ["example_cif/NaCl_1000041.cif"],
            ["example_cif/Al2O3_1000017.cif"],
        ],
        inputs=[cif_input],
    )

# --- Launch the App ---
# Start the server only when this file is run as a script, so importing the
# module (e.g. to reuse `demo` or `generate_xrd_pattern`) does not launch it.
if __name__ == "__main__":
    demo.launch()
    # Add share=True for a public link: demo.launch(share=True)