|
|
| """CUDA-optimized chat interface for Ursa Minor Smashed model"""
|
|
|
| import torch
|
| from inference_cuda import generate_direct, load_model_direct
|
|
|
def main():
    """Interactive CUDA chat loop for the Ursa Minor Smashed model.

    Maintains a rolling ``Human:`` / ``Assistant:`` transcript as the
    generation context, trimming the oldest complete exchanges once the
    transcript exceeds the word budget. Type 'quit' to exit, 'reset' to
    clear the context. Requires a CUDA device; bails out otherwise.
    """
    print("Ursa Minor Smashed Chat (CUDA)")
    print("Type 'quit' to exit, 'reset' to clear context")
    print("-" * 50)

    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. Use chat_cpu.py for CPU inference.")
        return

    print("Loading model on CUDA...")
    model = load_model_direct("model_optimized.pt")
    print("Model loaded! Ready to chat.\n")

    context = ""
    max_context_length = 800  # rough budget in whitespace-separated words

    while True:
        # Ctrl-D / Ctrl-C at the prompt should exit cleanly, not traceback.
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break

        if user_input.lower() == 'quit':
            print("Goodbye!")
            break
        elif user_input.lower() == 'reset':
            context = ""
            print("Context cleared!")
            continue
        elif user_input == "":
            continue

        # Append the new turn in the Human/Assistant prompt format.
        if context:
            context += f"\nHuman: {user_input}\nAssistant:"
        else:
            context = f"Human: {user_input}\nAssistant:"

        # Trim overflow by dropping whole exchanges from the front so the
        # "Human:"/"Assistant:" newline structure survives. (A plain
        # split/join word trim would flatten the newlines the model is
        # prompted with, corrupting every turn after the first trim.)
        while len(context.split()) > max_context_length:
            cut = context.find("Human:", len("Human:"))
            if cut == -1:
                # Single oversized exchange: fall back to a hard word trim.
                context = " ".join(context.split()[-max_context_length:])
                break
            context = context[cut:]

        try:
            full_response = generate_direct(
                model,
                context,
                max_new_tokens=100,
                temperature=0.8,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1
            )

            # generate_direct returns prompt + completion; keep the new text.
            response = full_response[len(context):].strip()

            # Stop at a hallucinated next "Human:" turn, if any.
            if "Human:" in response:
                response = response.split("Human:")[0].strip()

            print(f"Assistant: {response}")

            # Carry only the shown (truncated) response forward, so a
            # hallucinated "Human:" continuation never pollutes the context.
            context = f"{context} {response}"
        except Exception as e:
            print(f"Error generating response: {e}")
            print("Try typing 'reset' to clear context and continue.")
|
# Standard script entry guard: run the chat loop only when executed
# directly, not when this module is imported.
if __name__ == "__main__":

    main()