zhu-mingye commited on
Commit
71956f9
·
1 Parent(s): 71cbbc3
Files changed (2) hide show
  1. app.py +204 -4
  2. requirements.txt +4 -0
app.py CHANGED
@@ -1,7 +1,207 @@
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ import re
2
+ from typing import Dict, Any, List
3
+
4
  import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModel
8
+
9
# Hugging Face Hub identifier of the encoder checkpoint loaded below.
MODEL_ID = "microsoft/unixcoder-base-nine"
# Maximum token length passed to the tokenizer (inputs are truncated to this).
MAX_TOKENS = 512
11
+
12
+
13
+ def _safe_float(v: float, ndigits: int = 4) -> float:
14
+ return float(round(float(v), ndigits))
15
+
16
+
17
class UniXcoderAnalyzer:
    """Scores a code snippet against a natural-language prompt with UniXcoder.

    Embeds both texts with a mean-pooled transformer encoder, combines the
    cosine similarity with simple regex-based heuristics (line counts,
    control-flow keywords, comment density), and returns a structured dict
    consumed by the Gradio UI / API.
    """

    def __init__(self, model_id: str = MODEL_ID):
        self.model_id = model_id
        # Device is chosen once here and reused everywhere (including `meta`).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(model_id).to(self.device)
        self.model.eval()

    @torch.no_grad()
    def _embed(self, text: str) -> np.ndarray:
        """Return an L2-normalized, mean-pooled embedding vector for *text*."""
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_TOKENS,
            padding=True,
        )
        encoded = {k: v.to(self.device) for k, v in encoded.items()}
        outputs = self.model(**encoded)

        # Mean-pool token embeddings, masking out padding positions.
        token_embeddings = outputs.last_hidden_state
        attention_mask = encoded["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
        masked = token_embeddings * attention_mask
        pooled = masked.sum(dim=1) / torch.clamp(attention_mask.sum(dim=1), min=1e-9)
        vec = pooled[0].detach().cpu().numpy()
        # Normalize so that dot products between embeddings are cosine similarities.
        norm = np.linalg.norm(vec) + 1e-9
        return vec / norm

    def analyze(self, prompt: str, language: str, code: str, analysis_type: str) -> Dict[str, Any]:
        """Analyze *code* against *prompt* and return a structured result dict.

        Args:
            prompt: Requirement description; may be empty (a generic prompt
                mentioning *language* is substituted).
            language: Language label; only echoed into the result metadata.
            code: Source code to analyze; an error payload is returned if empty.
            analysis_type: "risk", "quality", or anything else for the default
                semantic summary.

        Returns:
            Dict with status/message fields, summary, keyPoints, risks,
            suggestions, heuristic scores, and metadata.
        """
        prompt = (prompt or "").strip()
        code = (code or "").strip()

        # Fail fast with a structured error payload instead of raising,
        # so API callers always receive the same response shape.
        if not code:
            return {
                "modelStrategy": "unixcoder-hf-space",
                "enabled": True,
                "model": self.model_id,
                "status": "error",
                "message": "code 不能为空",
                "analysisError": "EMPTY_CODE",
                "summary": "未提供待分析代码,无法执行语义分析。",
                "keyPoints": [],
                "risks": ["输入代码为空"],
                "suggestions": ["请传入完整代码片段后重试"]
            }

        prompt_vec = self._embed(prompt if prompt else f"Analyze {language} code")
        code_vec = self._embed(code)

        # Vectors are unit-normalized, so the dot product is cosine similarity
        # in [-1, 1]; rescale to [0, 1] for reporting.
        semantic_alignment = float(np.dot(prompt_vec, code_vec))
        semantic_alignment = (semantic_alignment + 1.0) / 2.0

        # Cheap regex heuristics over the raw text (language-agnostic,
        # so keyword counts are approximate — e.g. '#' also matches inside strings).
        lines = [ln for ln in code.splitlines() if ln.strip()]
        line_count = len(lines)
        char_count = len(code)
        function_like = len(re.findall(r"\b(def|function|public|private|protected|class)\b", code))
        control_flow = len(re.findall(r"\b(if|else|for|while|switch|try|catch)\b", code))
        long_lines = sum(1 for ln in lines if len(ln) > 120)
        comments = len(re.findall(r"//|/\*|\*/|#", code))

        # Both scores are clamped to [0, 1]; the weights are ad-hoc heuristics.
        complexity_score = min(1.0, (control_flow * 0.08) + (function_like * 0.05) + (line_count / 300.0))
        maintainability = max(0.0, min(1.0, 1.0 - (long_lines / max(1, line_count)) * 0.7 + min(comments / max(1, line_count), 0.2)))

        key_points: List[str] = [
            f"检测到约 {line_count} 行有效代码,{function_like} 个函数/类相关声明。",
            f"语义相关性得分 {semantic_alignment:.2f}(0-1 越高越贴合需求)。",
            f"控制流关键字出现 {control_flow} 次,复杂度评分 {complexity_score:.2f}。",
        ]

        risks: List[str] = []
        if semantic_alignment < 0.55:
            risks.append("代码与需求语义相似度偏低,可能存在功能偏移。")
        if long_lines > 0:
            risks.append(f"存在 {long_lines} 行超长代码行,可读性和可维护性风险较高。")
        if comments == 0:
            risks.append("未检测到注释,后续维护和协作成本可能上升。")
        if complexity_score > 0.7:
            risks.append("控制流较复杂,建议补充单元测试覆盖核心分支。")

        if not risks:
            risks.append("未发现明显高风险项,建议结合业务规则进行人工复核。")

        suggestions: List[str] = [
            "对关键逻辑分支补充单元测试,优先覆盖边界输入。",
            "将超过 120 字符的长行拆分,提升可读性。",
            "为核心函数补充文档注释,标明输入、输出和异常行为。",
        ]

        if analysis_type == "risk":
            summary = (
                f"风险导向分析完成:复杂度 {complexity_score:.2f},可维护性 {maintainability:.2f},"
                f"语义相关性 {semantic_alignment:.2f}。"
            )
        elif analysis_type == "quality":
            summary = (
                f"质量导向分析完成:代码规模 {line_count} 行,复杂度 {complexity_score:.2f},"
                f"可维护性 {maintainability:.2f}。"
            )
        else:
            summary = (
                f"语义分析完成:代码与需求相关性 {semantic_alignment:.2f},"
                f"复杂度 {complexity_score:.2f},可维护性 {maintainability:.2f}。"
            )

        return {
            "modelStrategy": "unixcoder-hf-space",
            "enabled": True,
            "model": self.model_id,
            "status": "ok",
            "message": "analysis success",
            "analysisError": None,
            "summary": summary,
            "keyPoints": key_points,
            "risks": risks,
            "suggestions": suggestions,
            "scores": {
                "semanticAlignment": _safe_float(semantic_alignment),
                "complexity": _safe_float(complexity_score),
                "maintainability": _safe_float(maintainability),
                "lineCount": line_count,
                "charCount": char_count,
            },
            "meta": {
                "language": language,
                "analysisType": analysis_type,
                # Fix: reuse the device selected in __init__ instead of
                # re-querying torch.cuda.is_available() on every call.
                "device": self.device,
            },
        }
145
+
146
+
147
# Single shared analyzer instance; the model is loaded once at import time.
analyzer = UniXcoderAnalyzer()
148
+
149
+
150
def analyze_for_ui(prompt: str, language: str, code: str, analysis_type: str):
    """Run the shared analyzer and return (structured dict, markdown report)."""
    report = analyzer.analyze(prompt=prompt, language=language, code=code, analysis_type=analysis_type)

    def bullets(key: str) -> str:
        # Render a list field as markdown bullets; placeholder when empty.
        rendered = "\n".join(f"- {item}" for item in report.get(key, []))
        return rendered or "- 无"

    sections = [
        f"### 分析摘要\n{report.get('summary', '')}",
        "### Key Points",
        bullets("keyPoints"),
        "### Risks",
        bullets("risks"),
        "### Suggestions",
        bullets("suggestions"),
    ]
    return report, "\n".join(sections)
164
+
165
+
166
# Declarative Gradio UI. The `demo` object created here is launched by the
# __main__ guard at the bottom of the file; the click handler is also exposed
# as a Gradio API endpoint via api_name="analyze".
with gr.Blocks(title="UniXcoder Code Analyzer") as demo:
    gr.Markdown("# UniXcoder 代码理解与分析服务")
    gr.Markdown("用于代码语义理解、风险提示和质量建议。可通过页面交互,也可通过 Gradio API 调用。")

    with gr.Row():
        # Language label is only echoed into result metadata by the analyzer.
        language = gr.Dropdown(
            choices=["java", "python", "javascript", "cpp", "go", "other"],
            value="java",
            label="Language",
        )
        # Selects which summary wording analyze() produces.
        analysis_type = gr.Dropdown(
            choices=["summary", "risk", "quality"],
            value="summary",
            label="Analysis Type",
        )

    prompt = gr.Textbox(
        label="需求描述 (Prompt)",
        placeholder="例如:检查这段代码是否满足线程安全和异常处理要求",
        lines=3,
    )
    code = gr.Textbox(
        label="待分析代码",
        placeholder="在这里粘贴代码...",
        lines=16,
    )

    run_btn = gr.Button("开始分析", variant="primary")

    # Two outputs: machine-readable JSON and a human-readable markdown report.
    output_json = gr.JSON(label="结构化结果(用于后端API接入)")
    output_md = gr.Markdown(label="可读报告")

    run_btn.click(
        fn=analyze_for_ui,
        inputs=[prompt, language, code, analysis_type],
        outputs=[output_json, output_md],
        api_name="analyze",
    )
204
 
 
 
205
 
206
# Launch the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio>=4.44.1,<5.0.0
2
+ transformers>=4.40.0
3
+ torch>=2.2.0
4
+ numpy>=1.26.0