# Hugging Face Spaces status: Running
# Module-level setup: standard-library imports and logging configuration.
import os
import logging

# Configure the root logger at INFO level (a no-op if handlers already exist).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Summarizer:
    """Load the BioBART base model, its tokenizer, and a LoRA fine-tuned copy.

    All loading happens in ``__init__``. Progress, success and failure are
    reported in the Streamlit UI via ``st``. On any error, diagnostics are
    written, the traceback is logged, and the original exception is re-raised.

    Attributes:
        base_model: ``GanjinZero/biobart-base`` loaded without any adapter.
        tokenizer: Tokenizer for ``GanjinZero/biobart-base``.
        finetuned_model: A separately loaded copy of the base model wrapped by
            ``PeftModel`` with the adapter stored in the current directory,
            switched to eval mode.
    """

    def __init__(self):
        try:
            # Debug aid: show what files are actually present at runtime.
            st.write("Current directory contents:")
            st.write(os.listdir('.'))

            # Base model, kept without an adapter.
            self.base_model = AutoModelForSeq2SeqLM.from_pretrained(
                "GanjinZero/biobart-base",
                local_files_only=False,  # Allow downloading base model
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                "GanjinZero/biobart-base",
                local_files_only=False,
            )

            # Fail early with a clear message if the adapter config is missing.
            adapter_config_path = "./adapter_config.json"
            if not os.path.exists(adapter_config_path):
                st.error(f"adapter_config.json not found in {os.getcwd()}")
                raise FileNotFoundError("adapter_config.json not found")
            st.write(f"Loading adapter config from {adapter_config_path}")
            # Use the hyperparameters the adapter was actually saved with.
            # A hard-coded config passed via `config=` would override the file.
            lora_config = self._load_lora_config(".")

            # Separate base-model instance for the adapter: PEFT modifies the
            # model it wraps, so self.base_model is left untouched.
            base_model_for_finetuned = AutoModelForSeq2SeqLM.from_pretrained(
                "GanjinZero/biobart-base",
                local_files_only=False,
            )
            st.write("Loading fine-tuned model...")
            self.finetuned_model = PeftModel.from_pretrained(
                base_model_for_finetuned,
                ".",  # Adapter weights live in the current directory
                config=lora_config,
                torch_dtype=torch.float32,
                is_trainable=False,
                local_files_only=True,
            )
            self.finetuned_model.eval()
            st.success("Models loaded successfully!")
        except Exception as e:
            # Record the full traceback in the server logs, not just the UI.
            logger.exception("Error loading models")
            st.error(f"Error loading models: {str(e)}")
            self._report_debug_info()
            raise

    @staticmethod
    def _load_lora_config(adapter_dir):
        """Return a LoraConfig (inference mode) read from ``adapter_dir``.

        Reads ``adapter_config.json`` via ``LoraConfig.from_pretrained``. If the
        file cannot be parsed into a LoraConfig (TypeError/ValueError, e.g.
        malformed JSON or fields this PEFT version does not accept), a warning
        is shown and the previously hard-coded hyperparameters are used.
        """
        try:
            lora_config = LoraConfig.from_pretrained(adapter_dir)
        except (TypeError, ValueError) as err:
            logger.warning(
                "Could not parse adapter_config.json (%s); using built-in LoRA defaults",
                err,
            )
            st.warning(
                f"Could not parse adapter_config.json ({err}); "
                "using built-in LoRA defaults"
            )
            return LoraConfig(
                r=8,
                lora_alpha=16,
                lora_dropout=0.1,
                bias="none",
                task_type="SEQ_2_SEQ_LM",
                target_modules=["q_proj", "v_proj"],
                inference_mode=True,
            )
        lora_config.inference_mode = True
        return lora_config

    @staticmethod
    def _report_debug_info():
        """Write best-effort loading diagnostics to the UI; never raises."""
        try:
            st.write("Debug info:")
            st.write(f"Current working directory: {os.getcwd()}")
            st.write(f"Directory contents: {os.listdir('.')}")
            if os.path.exists('adapter_config.json'):
                with open('adapter_config.json', 'r', encoding='utf-8') as f:
                    st.write("adapter_config.json contents:", f.read())
        except Exception:
            # Diagnostics must never mask the original loading error.
            logger.exception("Failed to collect debug info")