Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,58 +1,74 @@ | |
| 1 | 
            -
            import torch
         | 
| 2 | 
            -
            from transformers import AutoTokenizer, AutoModelForCausalLM
         | 
| 3 | 
            -
            from peft import PeftModel, PeftConfig
         | 
| 4 | 
             
            import gradio as gr
         | 
|  | |
|  | |
|  | |
| 5 |  | 
| 6 | 
            -
             | 
| 7 | 
            -
             | 
| 8 | 
            -
            peft_config = PeftConfig.from_pretrained(adapter_id)
         | 
| 9 |  | 
| 10 | 
            -
             | 
| 11 | 
            -
             | 
|  | |
| 12 |  | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
                torch_dtype=torch.float16,
         | 
| 17 | 
            -
                device_map="auto"
         | 
| 18 | 
            -
            )
         | 
| 19 |  | 
| 20 | 
            -
             | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 23 |  | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
                 | 
| 27 | 
            -
                 | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
             | 
| 37 | 
            -
                         | 
| 38 | 
            -
                         | 
| 39 | 
            -
                         | 
| 40 | 
            -
                        top_p=0.9,
         | 
| 41 | 
            -
                        do_sample=True,
         | 
| 42 | 
            -
                        pad_token_id=tokenizer.eos_token_id
         | 
| 43 | 
             
                    )
         | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 | 
            -
             | 
| 48 | 
            -
             | 
| 49 | 
            -
             | 
| 50 | 
            -
             | 
| 51 | 
            -
             | 
| 52 | 
            -
             | 
| 53 | 
            -
                 | 
| 54 | 
            -
             | 
| 55 | 
            -
             | 
| 56 | 
            -
                 | 
| 57 | 
            -
                 | 
| 58 | 
            -
            ). | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
|  | |
|  | |
|  | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
            +
            from transformers import AutoModelForCausalLM, AutoTokenizer
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import re
         | 
| 5 |  | 
| 6 | 
            +
            model_path = "./depression_model_part1"
         | 
| 7 | 
            +
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
|  | |
| 8 |  | 
| 9 | 
            +
            tokenizer = AutoTokenizer.from_pretrained(model_path)
         | 
| 10 | 
            +
            model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
         | 
| 11 | 
            +
            model.eval()
         | 
| 12 |  | 
| 13 | 
            +
            user_history = []
         | 
| 14 | 
            +
            turn_counter = 0
         | 
| 15 | 
            +
            MAX_TURNS_FOR_PREDICTION = 8
         | 
|  | |
|  | |
|  | |
| 16 |  | 
| 17 | 
            +
            def chat(user_input):
         | 
| 18 | 
            +
                global user_history, turn_counter
         | 
| 19 | 
            +
                turn_counter += 1
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                user_history.append(f"Human: {user_input}")
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                last_turns = user_history[-4:]
         | 
| 24 | 
            +
                prompt = "\n".join(last_turns) + "\nAI:"
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                inputs = tokenizer(prompt, return_tensors="pt").to(device)
         | 
| 27 | 
            +
                output_ids = model.generate(
         | 
| 28 | 
            +
                    **inputs,
         | 
| 29 | 
            +
                    max_new_tokens=50,
         | 
| 30 | 
            +
                    do_sample=False,
         | 
| 31 | 
            +
                    pad_token_id=tokenizer.eos_token_id,
         | 
| 32 | 
            +
                    eos_token_id=tokenizer.eos_token_id,
         | 
| 33 | 
            +
                )
         | 
| 34 | 
            +
                response_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
         | 
| 35 | 
            +
                response = response_text.split("AI:")[-1].strip()
         | 
| 36 |  | 
| 37 | 
            +
                user_history.append(f"AI: {response}")
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                depression_prob = None
         | 
| 40 | 
            +
                if turn_counter == MAX_TURNS_FOR_PREDICTION:
         | 
| 41 | 
            +
                    prediction_prompt = (
         | 
| 42 | 
            +
                        "\n".join(user_history[-8:]) +
         | 
| 43 | 
            +
                        "\nAI: Based on this conversation, what is the probability that the human has depression? "
         | 
| 44 | 
            +
                        "Please answer with a number between 0 and 1."
         | 
| 45 | 
            +
                    )
         | 
| 46 | 
            +
                    inputs_pred = tokenizer(prediction_prompt, return_tensors="pt").to(device)
         | 
| 47 | 
            +
                    output_pred_ids = model.generate(
         | 
| 48 | 
            +
                        **inputs_pred,
         | 
| 49 | 
            +
                        max_new_tokens=10,
         | 
| 50 | 
            +
                        do_sample=False,
         | 
| 51 | 
            +
                        pad_token_id=tokenizer.eos_token_id,
         | 
| 52 | 
            +
                        eos_token_id=tokenizer.eos_token_id,
         | 
|  | |
|  | |
|  | |
| 53 | 
             
                    )
         | 
| 54 | 
            +
                    pred_text = tokenizer.decode(output_pred_ids[0], skip_special_tokens=True)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                    match = re.search(r"0?\.\d+", pred_text)
         | 
| 57 | 
            +
                    if match:
         | 
| 58 | 
            +
                        try:
         | 
| 59 | 
            +
                            depression_prob = float(match.group(0))
         | 
| 60 | 
            +
                        except:
         | 
| 61 | 
            +
                            depression_prob = None
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                return response, depression_prob if depression_prob is not None else "Prediction after 8 turns"
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            iface = gr.Interface(
         | 
| 66 | 
            +
                fn=chat,
         | 
| 67 | 
            +
                inputs=gr.Textbox(lines=2, label="Your Message"),
         | 
| 68 | 
            +
                outputs=[gr.Textbox(label="AI Response"), gr.Textbox(label="Depression Probability")],
         | 
| 69 | 
            +
                title="Depression Detection Chatbot",
         | 
| 70 | 
            +
                description="Chat with the AI. After 8 turns it predicts depression probability."
         | 
| 71 | 
            +
            )
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            if __name__ == "__main__":
         | 
| 74 | 
            +
                iface.launch()
         |