leavoigt committed
Commit 287959e · 1 Parent(s): 79ad53d
Files changed (8)
  1. Dockerfile +23 -0
  2. README.md +56 -9
  3. app.py +38 -4
  4. gitignore +1 -0
  5. params.cfg +35 -0
  6. requirements.txt +19 -0
  7. utils/generator.py +224 -0
  8. utils/utils.py +46 -0
Dockerfile ADDED
@@ -0,0 +1,23 @@
+ # -------- base image --------
+ FROM python:3.11-slim
+
+ ENV PYTHONUNBUFFERED=1 \
+     OMP_NUM_THREADS=1 \
+     TOKENIZERS_PARALLELISM=false
+ #GRADIO_MCP_SERVER=True
+
+ # -------- install deps --------
+ WORKDIR /app
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # -------- copy source --------
+ COPY app.py .
+ COPY params.cfg .
+ COPY utils/ ./utils/
+ COPY .env* ./
+
+ # Ports:
+ # • 7860 → Gradio UI (HF Spaces standard)
+ EXPOSE 7860
+
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,13 +1,60 @@
  ---
- title: Chatfed Generator
- emoji: 👁
- colorFrom: pink
- colorTo: red
- sdk: gradio
- sdk_version: 5.38.0
- app_file: app.py
+ title: ChatFed Generator
+ emoji: 🤖
+ colorFrom: blue
+ colorTo: purple
+ sdk: docker
  pinned: false
- license: apache-2.0
+ license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # ChatFed Generator - MCP Server
+
+ A language model-based generation service designed for ChatFed RAG (Retrieval-Augmented Generation) pipelines. This module serves as an **MCP (Model Context Protocol) server** that generates contextual responses using configurable LLM providers, with support for processing retrieval results.
+
+ ## MCP Endpoint
+
+ The main MCP function is `generate`, which provides context-aware text generation through the configured LLM provider once API credentials are set.
+
+ **Parameters**:
+ - `query` (str, required): The question or query to be answered
+ - `context` (str|list, required): Context for answering; either plain text or a list of retrieval result dictionaries
+
+ **Returns**: A string containing the generated answer based on the provided context and query.
+
+ **Example usage**:
+ ```python
+ from gradio_client import Client
+
+ client = Client("ENTER CONTAINER URL / SPACE ID")
+ result = client.predict(
+     query="What are the key findings?",
+     context="Your relevant documents or context here...",
+     api_name="/generate"
+ )
+ print(result)
+ ```
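+
+ For list-shaped context (as produced by a ChatFed retriever), `utils/generator.py` expects each item to carry an `answer` field plus an `answer_metadata` dictionary; a sketch of that shape (all field values illustrative):
+
+ ```python
+ context = [
+     {
+         "answer": "Relevant passage text...",
+         "answer_metadata": {
+             "filename": "report.pdf",
+             "page": 12,
+             "year": 2023,
+             "source": "corpus",
+             "_id": "doc-001",
+         },
+     },
+ ]
+ ```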
+
+ ## Configuration
+
+ ### LLM Provider Configuration
+ 1. Set your preferred inference provider in `params.cfg`
+ 2. Configure the model and generation parameters
+ 3. Set the required API key environment variable
+ 4. [Optional] Adjust the `TEMPERATURE` and `MAX_TOKENS` settings
+ 5. Build and run the app:
+
+ ```bash
+ docker build -t chatfed-generator .
+ docker run -p 7860:7860 chatfed-generator
+ ```
+
+ ## Environment Variables Required
+
+ Make sure to set the appropriate environment variable for your provider:
+ - OpenAI: `OPENAI_API_KEY`
+ - Anthropic: `ANTHROPIC_API_KEY`
+ - Cohere: `COHERE_API_KEY`
+ - HuggingFace: `HF_TOKEN`
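+
+ With Docker, the key can also be supplied at run time instead of baking a `.env` file into the image; a sketch (token value hypothetical):
+
+ ```bash
+ docker run -p 7860:7860 -e HF_TOKEN=hf_xxx chatfed-generator
+ ```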
app.py CHANGED
@@ -1,7 +1,41 @@
  import gradio as gr
-
-
- def greet(name):
-     return "Hello " + name + "!!"
-
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ from utils.generator import generate
+
+ # ---------------------------------------------------------------------
+ # Gradio Interface with MCP support
+ # ---------------------------------------------------------------------
+ ui = gr.Interface(
+     fn=generate,
+     inputs=[
+         gr.Textbox(
+             label="Query",
+             lines=2,
+             placeholder="Enter query here",
+             info="The question to answer using the provided context"
+         ),
+         gr.Textbox(
+             label="Context",
+             lines=8,
+             placeholder="Paste relevant context here",
+             info="Provide the context/documents to use for answering. The API expects a list of dictionaries, but the UI accepts plain text as well."
+         ),
+     ],
+     outputs=gr.Textbox(
+         label="Generated Answer",
+         lines=6,
+         show_copy_button=True
+     ),
+     title="ChatFed Generation Module",
+     description="Ask questions based on provided context. Intended for use in RAG pipelines as an MCP server with other ChatFed modules (e.g. context supplied by a semantic retriever service).",
+     api_name="generate"
+ )
+
+ # Launch with MCP server enabled
+ if __name__ == "__main__":
+     ui.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         mcp_server=True,
+         show_error=True
+     )
gitignore ADDED
@@ -0,0 +1 @@
+ .DS_Store
params.cfg ADDED
@@ -0,0 +1,35 @@
+ [generator]
+ PROVIDER = huggingface
+ MODEL = meta-llama/Meta-Llama-3-8B-Instruct
+ MAX_TOKENS = 512
+ TEMPERATURE = 0.2
+
+ ## OpenAI
+ # [generator]
+ # PROVIDER = openai
+ # MODEL = gpt-4o
+ # MAX_TOKENS = 512
+ # TEMPERATURE = 0.2
+
+ ## Anthropic
+ # [generator]
+ # PROVIDER = anthropic
+ # MODEL = claude-3-haiku-20240307
+ # MAX_TOKENS = 512
+ # TEMPERATURE = 0.2
+
+ ## Cohere
+ # [generator]
+ # PROVIDER = cohere
+ # MODEL = command
+ # MAX_TOKENS = 512
+ # TEMPERATURE = 0.2
+
+
+ ## Environment Variables Required
+
+ # Make sure to set the appropriate environment variables:
+ # - OpenAI: `OPENAI_API_KEY`
+ # - Anthropic: `ANTHROPIC_API_KEY`
+ # - Cohere: `COHERE_API_KEY`
+ # - HuggingFace: `HF_TOKEN`
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ # Core dependencies
+ gradio>=4.0.0
+ gradio[mcp]
+ python-dotenv>=1.0.0
+
+ # LangChain core
+ langchain-core>=0.1.0
+ langchain-community>=0.0.1
+
+ # Provider-specific LangChain packages
+ langchain-openai>=0.1.0
+ langchain-anthropic>=0.1.0
+ langchain-cohere>=0.1.0
+ langchain-together>=0.1.0
+ langchain-huggingface>=0.0.1
+
+ # Additional dependencies that might be needed
+ requests>=2.31.0
+ pydantic>=2.0.0
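
For local (non-Docker) development, these pins should install into a plain virtualenv; a sketch:

```bash
pip install -r requirements.txt
python app.py  # serves the Gradio UI and MCP server on port 7860
```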
utils/generator.py ADDED
@@ -0,0 +1,224 @@
+ import logging
+ import ast
+ from typing import List, Dict, Any, Union
+
+ # LangChain imports
+ from langchain_openai import ChatOpenAI
+ from langchain_anthropic import ChatAnthropic
+ from langchain_cohere import ChatCohere
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+ from langchain_core.messages import SystemMessage, HumanMessage
+
+ # Local imports
+ from .utils import getconfig, get_auth
+
+ # ---------------------------------------------------------------------
+ # Model / client initialization (non-exhaustive list of providers)
+ # ---------------------------------------------------------------------
+ config = getconfig("params.cfg")
+
+ PROVIDER = config.get("generator", "PROVIDER")
+ MODEL = config.get("generator", "MODEL")
+ MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
+ TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
+
+ # Set up authentication for the selected provider
+ auth_config = get_auth(PROVIDER)
+
+ def get_chat_model():
+     """Initialize the appropriate LangChain chat model for the configured provider."""
+     common_params = {
+         "temperature": TEMPERATURE,
+         "max_tokens": MAX_TOKENS,
+     }
+
+     if PROVIDER == "openai":
+         return ChatOpenAI(
+             model=MODEL,
+             openai_api_key=auth_config["api_key"],
+             **common_params
+         )
+     elif PROVIDER == "anthropic":
+         return ChatAnthropic(
+             model=MODEL,
+             anthropic_api_key=auth_config["api_key"],
+             **common_params
+         )
+     elif PROVIDER == "cohere":
+         return ChatCohere(
+             model=MODEL,
+             cohere_api_key=auth_config["api_key"],
+             **common_params
+         )
+     elif PROVIDER == "huggingface":
+         # Initialize HuggingFaceEndpoint with explicit parameters
+         llm = HuggingFaceEndpoint(
+             repo_id=MODEL,
+             huggingfacehub_api_token=auth_config["api_key"],
+             task="text-generation",
+             temperature=TEMPERATURE,
+             max_new_tokens=MAX_TOKENS
+         )
+         return ChatHuggingFace(llm=llm)
+     else:
+         raise ValueError(f"Unsupported provider: {PROVIDER}")
+
+ # Initialize provider-agnostic chat model
+ chat_model = get_chat_model()
+
+ # ---------------------------------------------------------------------
+ # Context processing - may need further refinement (e.g. to handle other data sources)
+ # ---------------------------------------------------------------------
+ def extract_relevant_fields(retrieval_results: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
+     """
+     Extract only the relevant fields from retrieval results.
+
+     Args:
+         retrieval_results: List of retrieval result dictionaries, or its string representation
+
+     Returns:
+         List of processed objects containing only the relevant fields
+     """
+     # The MCP/UI layer may deliver the list as its string representation
+     if isinstance(retrieval_results, str):
+         retrieval_results = ast.literal_eval(retrieval_results)
+
+     processed_results = []
+
+     for result in retrieval_results:
+         # Extract the answer content
+         answer = result.get('answer', '')
+
+         # Extract document identification from metadata
+         metadata = result.get('answer_metadata', {})
+         doc_info = {
+             'answer': answer,
+             'filename': metadata.get('filename', 'Unknown'),
+             'page': metadata.get('page', 'Unknown'),
+             'year': metadata.get('year', 'Unknown'),
+             'source': metadata.get('source', 'Unknown'),
+             'document_id': metadata.get('_id', 'Unknown')
+         }
+
+         processed_results.append(doc_info)
+
+     return processed_results
+
+ def format_context_from_results(processed_results: List[Dict[str, Any]]) -> str:
+     """
+     Format processed retrieval results into a context string for the LLM.
+
+     Args:
+         processed_results: List of processed objects with relevant fields
+
+     Returns:
+         Formatted context string
+     """
+     if not processed_results:
+         return ""
+
+     context_parts = []
+
+     for i, result in enumerate(processed_results, 1):
+         doc_reference = f"[Document {i}: {result['filename']}"
+         if result['page'] != 'Unknown':
+             doc_reference += f", Page {result['page']}"
+         if result['year'] != 'Unknown':
+             doc_reference += f", Year {result['year']}"
+         doc_reference += "]"
+
+         context_part = f"{doc_reference}\n{result['answer']}\n"
+         context_parts.append(context_part)
+
+     return "\n".join(context_parts)
+
+ # ---------------------------------------------------------------------
+ # Core generation function for both the Gradio UI and MCP
+ # ---------------------------------------------------------------------
+ async def _call_llm(messages: list) -> str:
+     """
+     Provider-agnostic LLM call using LangChain.
+
+     Args:
+         messages: List of LangChain message objects
+
+     Returns:
+         Generated response content as a string
+     """
+     try:
+         # Use async invoke for better performance
+         response = await chat_model.ainvoke(messages)
+         return response.content.strip()
+     except Exception as e:
+         logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
+         raise
+
+ def build_messages(question: str, context: str) -> list:
+     """
+     Build messages in LangChain format.
+
+     Args:
+         question: The user's question
+         context: The relevant context for answering
+
+     Returns:
+         List of LangChain message objects
+     """
+     system_content = (
+         "You are an expert assistant. Answer the USER question using only the "
+         "CONTEXT provided. If the context is insufficient, say 'I don't know.'"
+     )
+
+     user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
+
+     return [
+         SystemMessage(content=system_content),
+         HumanMessage(content=user_content)
+     ]
+
+
+ async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str:
+     """
+     Generate an answer to a query using the provided context (RAG).
+
+     This function takes a user query and relevant context, then uses a language model
+     to generate a comprehensive answer based on the provided information.
+
+     Args:
+         query (str): User query
+         context (str|list): Plain-text context or a list of retrieval result dictionaries
+     Returns:
+         str: The generated answer based on the query and context
+     """
+     if not query.strip():
+         return "Error: Query cannot be empty"
+
+     # Handle both string context (from the Gradio UI) and list context (from a retriever)
+     if isinstance(context, list):
+         if not context:
+             return "Error: No retrieval results provided"
+
+         # Process the retrieval results
+         processed_results = extract_relevant_fields(context)
+         formatted_context = format_context_from_results(processed_results)
+
+         if not formatted_context.strip():
+             return "Error: No valid content found in retrieval results"
+
+     elif isinstance(context, str):
+         if not context.strip():
+             return "Error: Context cannot be empty"
+         formatted_context = context
+
+     else:
+         return "Error: Context must be either a string or a list of retrieval results"
+
+     try:
+         messages = build_messages(query, formatted_context)
+         answer = await _call_llm(messages)
+         return answer
+     except Exception as e:
+         logging.exception("Generation failed")
+         return f"Error: {str(e)}"
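
`generate` is async, so callers outside Gradio need an event loop. A minimal direct-call sketch, assuming `params.cfg` and the provider's API key are already set up:

```python
import asyncio
from utils.generator import generate

answer = asyncio.run(generate(
    query="What are the key findings?",
    context="Your relevant documents or context here...",
))
print(answer)
```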
utils/utils.py ADDED
@@ -0,0 +1,46 @@
+ import os
+ import configparser
+ import logging
+ from dotenv import load_dotenv
+
+ # Local .env file
+ load_dotenv()
+
+ def getconfig(configfile_path: str):
+     """
+     Read the config file.
+
+     Params
+     ----------------
+     configfile_path: file path of the .cfg file
+     """
+     config = configparser.ConfigParser()
+     try:
+         with open(configfile_path) as f:
+             config.read_file(f)
+         return config
+     except FileNotFoundError:
+         logging.warning("config file not found: %s", configfile_path)
+         raise
+
+ # ---------------------------------------------------------------------
+ # Provider-agnostic authentication and configuration
+ # ---------------------------------------------------------------------
+ def get_auth(provider: str) -> dict:
+     """Get authentication configuration for different providers"""
+     auth_configs = {
+         "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
+         "huggingface": {"api_key": os.getenv("HF_TOKEN")},
+         "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
+         "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
+     }
+
+     if provider not in auth_configs:
+         raise ValueError(f"Unsupported provider: {provider}")
+
+     auth_config = auth_configs[provider]
+     api_key = auth_config.get("api_key")
+
+     if not api_key:
+         raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
+
+     return auth_config
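
Since `load_dotenv()` runs at import time, a local `.env` file is picked up automatically. A sketch of how the two helpers compose, assuming the `params.cfg` shown above:

```python
from utils.utils import getconfig, get_auth

config = getconfig("params.cfg")
provider = config.get("generator", "PROVIDER")  # e.g. "huggingface"
auth = get_auth(provider)  # raises RuntimeError if the matching env var is unset
```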