lakhera2023 committed
Commit cdcb203 · verified · 1 Parent(s): 48c756e

Create predict.py

Files changed (1): predict.py (+101, -0)
predict.py ADDED
@@ -0,0 +1,101 @@
+ # Optimized prediction script for Hugging Face Inference Endpoints
+ # This version uses less memory and is optimized for smaller instances
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from cog import BasePredictor, Input
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class Predictor(BasePredictor):
+     def setup(self) -> None:
+         """Load the DevOps SLM model into memory with optimizations"""
+         logger.info("Loading DevOps SLM model with memory optimizations...")
+
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Using device: {self.device}")
+
+         # Load model with memory optimizations
+         self.model = AutoModelForCausalLM.from_pretrained(
+             "lakhera2023/devops-slm",
+             torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+             device_map="auto" if self.device == "cuda" else None,
+             low_cpu_mem_usage=True,
+             trust_remote_code=True,
+             # Memory optimizations
+             use_cache=False,  # Disable KV cache to save memory
+             attn_implementation="eager"  # Use eager attention (less memory)
+         )
+
+         # Load tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained("lakhera2023/devops-slm")
+
+         # Set pad token
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         # Clear cache
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         logger.info("DevOps SLM model loaded successfully with optimizations!")
+
+     def predict(
+         self,
+         prompt: str = Input(description="DevOps question or task prompt"),
+         max_tokens: int = Input(description="Maximum number of tokens to generate", default=150, ge=1, le=500),
+         temperature: float = Input(description="Sampling temperature", default=0.7, ge=0.1, le=2.0),
+         top_p: float = Input(description="Top-p sampling parameter", default=0.9, ge=0.1, le=1.0),
+         top_k: int = Input(description="Top-k sampling parameter", default=50, ge=1, le=100),
+     ) -> str:
+         """Generate a DevOps response using the specialized model"""
+         try:
+             logger.info(f"Generating response for prompt: {prompt[:100]}...")
+
+             # Tokenize input with truncation to save memory
+             inputs = self.tokenizer([prompt], return_tensors="pt", truncation=True, max_length=256).to(self.device)
+
+             # Generate response with memory optimizations
+             with torch.no_grad():
+                 outputs = self.model.generate(
+                     **inputs,
+                     max_new_tokens=max_tokens,
+                     temperature=temperature,
+                     do_sample=True,
+                     top_p=top_p,
+                     top_k=top_k,
+                     pad_token_id=self.tokenizer.eos_token_id,
+                     eos_token_id=self.tokenizer.eos_token_id,
+                     repetition_penalty=1.1,
+                     no_repeat_ngram_size=2,
+                     early_stopping=True,  # Only used by beam search; ignored when sampling
+                     use_cache=False,  # Don't use KV cache
+                     output_attentions=False,  # Don't output attention weights
+                     output_hidden_states=False  # Don't output hidden states
+                 )
+
+             # Decode response
+             full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+             # Extract only the generated part
+             if prompt in full_response:
+                 response = full_response.split(prompt)[-1].strip()
+             else:
+                 response = full_response.strip()
+
+             # Clean up chat-template artifacts
+             response = response.replace("<|im_start|>", "").replace("<|im_end|>", "").strip()
+
+             # Clear cache after generation
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+             logger.info(f"Generated response length: {len(response)}")
+             return response
+
+         except Exception as e:
+             logger.error(f"Error generating response: {e}")
+             return f"Error: {str(e)}"
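For a quick local smoke test, the predictor added in this commit can be exercised directly as a Python class, outside of Cog or an Inference Endpoint. The sketch below is illustrative only: it assumes torch, transformers, and cog are installed, that the lakhera2023/devops-slm weights are reachable, and the example prompt is made up. Passing every argument explicitly sidesteps the cog.Input defaults.

# Minimal local sanity check (illustrative sketch, not part of this commit)
from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads lakhera2023/devops-slm onto GPU if available, else CPU

answer = predictor.predict(
    prompt="How do I roll back a failed Kubernetes deployment?",  # example prompt (assumption)
    max_tokens=150,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
)
print(answer)

When packaged with Cog, the same class would normally be invoked through the CLI instead, e.g. cog predict -i prompt="...", provided the repository's cog.yaml (not shown in this commit) points its predict entry at predict.py:Predictor.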