ESSmith.tech committed on
Commit
c121d7b
·
unverified ·
1 Parent(s): a462f84

Update model_manager.py

Browse files
Files changed (1) hide show
  1. src/model_manager.py +17 -7
src/model_manager.py CHANGED
@@ -50,13 +50,21 @@ class LocalModel(ModelInterface):
50
 
51
  def generate(self, messages: List[Dict[str, str]], max_tokens: int = 512,
52
  temperature: float = 0.7, top_p: float = 0.9, **kwargs) -> Generator[str, None, None]:
53
- """Generate response from local model"""
54
  if not self._ready:
55
  raise RuntimeError("Model not ready")
56
-
57
- # Build prompt as plain text
58
- prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
59
-
 
 
 
 
 
 
 
 
60
  outputs = self.pipe(
61
  prompt,
62
  max_new_tokens=max_tokens,
@@ -65,9 +73,11 @@ class LocalModel(ModelInterface):
65
  top_p=top_p,
66
  return_full_text=False
67
  )
68
-
69
  response = outputs[0]["generated_text"]
70
- yield response.strip()
 
 
71
 
72
  class APIModel(ModelInterface):
73
  """API model implementation using HuggingFace Inference Client"""
 
50
 
51
  def generate(self, messages: List[Dict[str, str]], max_tokens: int = 512,
52
  temperature: float = 0.7, top_p: float = 0.9, **kwargs) -> Generator[str, None, None]:
53
+ """Generate response from local model (single-turn, avoids self-conversation)"""
54
  if not self._ready:
55
  raise RuntimeError("Model not ready")
56
+
57
+ # Use only system prompt and latest user message for local model
58
+ system_msg = next((m['content'] for m in messages if m['role'] == 'system'), "")
59
+ user_msg = next((m['content'] for m in reversed(messages) if m['role'] == 'user'), "")
60
+ # Use the selected philosopher's name if present, else 'assistant'
61
+ assistant_name = "assistant"
62
+ for m in messages:
63
+ if m['role'] not in ('system', 'user'):
64
+ assistant_name = m['role']
65
+ break
66
+ prompt = f"system: {system_msg}\nuser: {user_msg}\n{assistant_name}:"
67
+
68
  outputs = self.pipe(
69
  prompt,
70
  max_new_tokens=max_tokens,
 
73
  top_p=top_p,
74
  return_full_text=False
75
  )
76
+
77
  response = outputs[0]["generated_text"]
78
+ # Only return the first line (up to next newline or end)
79
+ first_line = response.strip().split("\n")[0]
80
+ yield first_line
81
 
82
  class APIModel(ModelInterface):
83
  """API model implementation using HuggingFace Inference Client"""