Spaces:
Running
Running
| from os import path | |
| from PIL import Image | |
| from PyQt5.QtCore import Qt | |
| from PyQt5.QtWidgets import ( | |
| QApplication, | |
| QHBoxLayout, | |
| QVBoxLayout, | |
| QGridLayout, | |
| QPushButton, | |
| QSlider, | |
| QLabel, | |
| QFrame, | |
| QComboBox, | |
| QWidget, | |
| QSizePolicy, | |
| QMessageBox, | |
| ) | |
| from backend.lora import ( | |
| get_lora_models, | |
| get_active_lora_weights, | |
| update_lora_weights, | |
| load_lora_weight, | |
| ) | |
| from frontend.gui.common_widgets import LabeledSlider | |
| from app_settings import AppSettings | |
| from paths import FastStableDiffusionPaths | |
| if __name__ != "__main__": | |
| from state import get_settings, get_context | |
| from models.interface_types import InterfaceType | |
| # app_settings = get_settings() | |
| _MAX_LORA_WEIGHTS = 5 | |
| _current_lora_count = 0 | |
| _active_lora_widgets = [] | |
| # This is a simple widget for displaying the loaded LoRAs name and weight | |
| class _LoraWidget(QWidget): | |
| def __init__(self): | |
| super().__init__() | |
| self.name_label = QLabel() | |
| self.strength_slider = LabeledSlider(True) | |
| hlayout = QHBoxLayout() | |
| hlayout.addWidget(self.name_label) | |
| hlayout.addWidget(self.strength_slider) | |
| self.setLayout(hlayout) | |
| def setValues(self, name: str, weight: float): | |
| self.name_label.setText(name) | |
| self.strength_slider.setValue(weight) | |
| def getValues(self): | |
| return (self.name_label.text(), self.strength_slider.getValue()) | |
| class LoraModelsWidget(QWidget): | |
| def __init__(self, config: AppSettings, parent): | |
| super().__init__() | |
| self.parent = parent | |
| self.config = config | |
| lora_models_map = {} | |
| if config != None: | |
| lora_models_map = get_lora_models( | |
| config.settings.lcm_diffusion_setting.lora.models_dir | |
| ) | |
| self.models_combobox = QComboBox() | |
| self.models_combobox.addItems(lora_models_map.keys()) | |
| self.models_combobox.setToolTip( | |
| "<p style='white-space:pre'>Place LoRA models in the <b>lora_models</b> folder</p>" | |
| ) | |
| self.weight_slider = LabeledSlider(True) | |
| self.load_button = QPushButton("Load selected LoRA") | |
| self.load_button.setEnabled(False) | |
| self.load_button.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding) | |
| self.load_button.setStyleSheet("padding: 10px") | |
| self.load_button.clicked.connect(self.on_load_lora) | |
| if len(lora_models_map) > 0: | |
| self.load_button.setEnabled(True) | |
| self.loaded_label = QLabel("Loaded LoRA models:") | |
| self.update_button = QPushButton("Update LoRA weights") | |
| self.update_button.setEnabled(False) | |
| self.update_button.clicked.connect(self.on_update_weights) | |
| self.separator = QLabel() | |
| self.separator.setFrameShape(QFrame.HLine) | |
| glayout = QGridLayout() | |
| glayout.setVerticalSpacing(0) | |
| glayout.addWidget(QLabel("LoRA model:"), 0, 0) | |
| glayout.addWidget( | |
| QLabel( | |
| "Initial LoRA weight:", | |
| ), | |
| 0, | |
| 1, | |
| ) | |
| glayout.addWidget(self.models_combobox, 1, 0) | |
| glayout.addWidget(self.weight_slider, 1, 1) | |
| glayout.addWidget(self.load_button, 0, 2, 2, 1) | |
| hlayout = QHBoxLayout() | |
| hlayout.addWidget(self.loaded_label) | |
| hlayout.addWidget(self.update_button) | |
| vlayout = QVBoxLayout() | |
| vlayout.addLayout(glayout, 10) | |
| vlayout.addWidget(self.separator, 1) | |
| vlayout.addLayout(hlayout, 10) | |
| vlayout.addStretch(80) | |
| self.setLayout(vlayout) | |
| def on_load_lora(self): | |
| # Code for testing the GUI; ignore when running FastSD CPU | |
| if __name__ == "__main__": | |
| self.layout().insertWidget(3, _LoraWidget()) | |
| return | |
| # End of code for testing the GUI | |
| global _current_lora_count | |
| global _active_lora_widgets | |
| if ( | |
| self.config == None | |
| or self.config.settings == None | |
| or _current_lora_count >= _MAX_LORA_WEIGHTS | |
| ): | |
| return | |
| if self.config.settings.lcm_diffusion_setting.use_openvino: | |
| QMessageBox().information( | |
| self.parent, | |
| "Error", | |
| "LoRA suppport is currently not implemented for OpenVINO.", | |
| ) | |
| return | |
| lora_models_map = get_lora_models( | |
| self.config.settings.lcm_diffusion_setting.lora.models_dir | |
| ) | |
| # Load a new LoRA | |
| settings = self.config.settings.lcm_diffusion_setting | |
| settings.lora.fuse = False | |
| settings.lora.enabled = False | |
| current_lora = self.models_combobox.currentText() | |
| current_weight = self.weight_slider.getValue() | |
| print(f"Selected Lora Model :{current_lora}") | |
| print(f"Lora weight :{current_weight}") | |
| settings.lora.path = lora_models_map[current_lora] | |
| settings.lora.weight = current_weight | |
| if not path.exists(settings.lora.path): | |
| QMessageBox.information(self.parent, "Error", "Invalid LoRA model path!") | |
| return | |
| if not self.parent.context.lcm_text_to_image.pipeline: | |
| QMessageBox.information( | |
| self.parent, | |
| "Error", | |
| "Pipeline not initialized. Please generate an image first.", | |
| ) | |
| return | |
| settings.lora.enabled = True | |
| load_lora_weight( | |
| self.parent.context.lcm_text_to_image.pipeline, | |
| settings, | |
| ) | |
| lora_widget = _LoraWidget() | |
| lora_widget.setValues(current_lora, current_weight) | |
| self.layout().insertWidget(3, lora_widget) | |
| self.update_button.setEnabled(True) | |
| _active_lora_widgets.append(lora_widget) | |
| _current_lora_count += 1 | |
| def on_update_weights(self): | |
| update_weights = [] | |
| active_weights = get_active_lora_weights() | |
| if not len(active_weights): | |
| return | |
| global _active_lora_widgets | |
| for idx, lora in enumerate(active_weights): | |
| update_weights.append( | |
| ( | |
| lora[0], | |
| _active_lora_widgets[idx].getValues()[1], | |
| ) | |
| ) | |
| if len(update_weights) > 0: | |
| update_lora_weights( | |
| self.parent.context.lcm_text_to_image.pipeline, | |
| self.config.settings.lcm_diffusion_setting, | |
| update_weights, | |
| ) | |
| def reset_active_lora_widgets(self): | |
| # This code assumes that the only time when the active LoRA weights count | |
| # is different from the current LoRA GUI widgets count is after a pipeline | |
| # rebuild, when the active LoRA widgets count will be zero, so all LoRA GUI | |
| # widgets are simply removed with no further action | |
| global _current_lora_count | |
| global _active_lora_widgets | |
| if len(get_active_lora_weights()) != _current_lora_count: | |
| for lora_widget in _active_lora_widgets: | |
| self.layout().removeWidget(lora_widget) | |
| _current_lora_count = 0 | |
| _active_lora_widgets = [] | |
| # Test the widget | |
| if __name__ == "__main__": | |
| import sys | |
| app = QApplication(sys.argv) | |
| widget = LoraModelsWidget(None, None) | |
| widget.show() | |
| app.exec() | |