关键词

| 数据清洗 | 启发式规则 | 模型清洗 | 去重技术 | 质量评分 | 毒性过滤 | 数据分布 | MinHash | SimHash | 文本质量 |


一、启发式规则清洗

1.1 规则清洗概述

启发式规则清洗是数据清洗的基础环节,通过预定义的规则集合对原始数据进行快速过滤。这种方法的核心优势在于:

  • 高效率:规则匹配的计算复杂度低,适合大规模数据处理
  • 可解释性强:每条数据被过滤的原因清晰可追溯
  • 成本低廉:无需训练模型,仅依赖规则执行

规则清洗的局限性

启发式规则难以捕捉语言的细微差别和上下文依赖性,可能产生假阳性(误杀)或假阴性(漏放)。最佳实践是将规则清洗作为初筛,后续配合更精细的模型过滤。

1.2 基础文本质量规则

长度与格式规则

class TextQualityRules:
    """Heuristic text-quality rule set.

    Rules are registered via the ``add_*`` methods and evaluated in order by
    :meth:`apply_rules`.  Each rule is a dict with a ``name``, a ``check``
    callable returning True when the text passes, and an ``action``
    ("reject" or "flag") to take when the check fails.
    """

    def __init__(self):
        self.rules = []

    def add_length_rule(self, min_chars=10, max_chars=50000):
        """Register a rule rejecting texts outside [min_chars, max_chars]."""
        self.rules.append({
            "name": "length_filter",
            "check": lambda text: min_chars <= len(text) <= max_chars,
            # BUG FIX: original expression `"reject" if "reject" else "flag"`
            # always evaluated to "reject"; state the intent directly.
            "action": "reject"
        })

    def add_language_ratio_rule(self, min_chinese_ratio=0.3):
        """Register a rule requiring a minimum ratio of CJK characters
        (for Chinese-language corpora)."""
        def check_chinese_ratio(text):
            # Count characters in the CJK Unified Ideographs block.
            chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
            # Empty text cannot satisfy the ratio; guard also avoids /0.
            return chinese_chars / len(text) >= min_chinese_ratio if text else False

        self.rules.append({
            "name": "chinese_ratio",
            "check": check_chinese_ratio,
            "action": "reject"
        })

    def add_repetition_rule(self, max_word_repetition_ratio=0.3):
        """Register a rule flagging texts with too many repeated words."""
        def check_repetition(text):
            words = text.split()
            # Too short to judge repetition reliably; let it pass.
            if len(words) < 10:
                return True

            unique_words = len(set(words))
            repetition_ratio = 1 - unique_words / len(words)
            return repetition_ratio <= max_word_repetition_ratio

        self.rules.append({
            "name": "repetition_filter",
            "check": check_repetition,
            "action": "flag"
        })

    def add_special_char_ratio(self, max_special_ratio=0.5):
        """Register a rule flagging texts dominated by special characters."""
        def check_special_chars(text):
            # "Special" = neither alphanumeric nor a CJK ideograph.
            special = sum(1 for c in text if not c.isalnum() and not '\u4e00' <= c <= '\u9fff')
            return special / len(text) <= max_special_ratio if text else True

        self.rules.append({
            "name": "special_char_ratio",
            "check": check_special_chars,
            "action": "flag"
        })

    def apply_rules(self, text):
        """Run every registered rule against *text*.

        Returns:
            dict with the original ``text``, an overall ``passed`` flag, the
            list of violated rules (``issues``) and the final ``action``
            ("accept", "flag_for_review" or "reject").
        """
        results = []
        for rule in self.rules:
            try:
                passed = rule["check"](text)
                if not passed:
                    results.append({
                        "rule": rule["name"],
                        "passed": False,
                        "action": rule["action"]
                    })
            except Exception as e:
                # BUG FIX: error entries previously lacked a "passed" key,
                # which made the aggregation below raise KeyError.
                results.append({
                    "rule": rule["name"],
                    "passed": False,
                    "error": str(e),
                    "action": "flag"
                })

        return {
            "text": text,
            # `results` only holds failures, so any entry means not passed.
            "passed": all(r["passed"] for r in results),
            "issues": results,
            "action": self._determine_action(results)
        }

    def _determine_action(self, results):
        """Map rule violations to a final action; "reject" wins over "flag"."""
        if any(r["action"] == "reject" and not r["passed"] for r in results):
            return "reject"
        elif any(r["action"] == "flag" and not r["passed"] for r in results):
            return "flag_for_review"
        return "accept"

内容安全规则

class ContentSafetyRules:
    """Content-safety rule set: blocked regex patterns plus spam heuristics."""

    def __init__(self):
        self.blocked_patterns = self._load_blocked_patterns()
        self.spam_indicators = self._load_spam_indicators()

    def _load_blocked_patterns(self):
        """Return the library of blocked regex patterns, grouped by category."""
        return {
            "contact_info": [
                r'\d{11}',  # CN mobile number
                r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}',  # email
                # WeChat id. NOTE(review): matches ANY 6-20 char alnum run;
                # far too broad for production — tune before deploying.
                r'[信\s]*[::]?\s*[a-zA-Z0-9_]{6,20}',
            ],
            "explicit_content": [
                r'色情|赌博|毒品',
                r'[颜色网站]+',
            ],
            "personal_info": [
                # BUG FIX: alternation needs grouping; the original
                # `身份证号[::]\d{15}|\d{18}` matched any bare 18-digit run.
                r'身份证号[::](?:\d{15}|\d{18})',
                r'银行卡[::]\d{16,19}',
            ]
        }

    def _load_spam_indicators(self):
        """Return thresholds for the spam heuristics."""
        return {
            "excessive_caps_ratio": 0.5,  # max fraction of uppercase chars
            "exclamation_count": 5,  # max number of ! / !
            "url_count": 10,  # max number of URLs
            "repeated_punctuation": 5,  # max runs of 5+ identical chars
        }

    def check_contact_info(self, text):
        """Return matches of contact-info patterns found in *text*."""
        import re  # hoisted out of the loop (was re-imported per pattern)

        found = []
        for pattern in self.blocked_patterns["contact_info"]:
            matches = re.findall(pattern, text)
            if matches:
                found.append({
                    "type": "contact_info",
                    "pattern": pattern,
                    "matches": matches
                })
        return found

    def check_spam_indicators(self, text):
        """Return the list of spam heuristics that *text* trips."""
        import re

        indicators = []

        # Share of uppercase characters.
        caps_ratio = sum(1 for c in text if c.isupper()) / len(text) if text else 0
        if caps_ratio > self.spam_indicators["excessive_caps_ratio"]:
            indicators.append("excessive_caps")

        # Counts both ASCII and full-width exclamation marks.
        exclamation_count = text.count('!') + text.count('!')
        if exclamation_count > self.spam_indicators["exclamation_count"]:
            indicators.append("excessive_exclamation")

        url_count = len(re.findall(r'https?://\S+', text))
        if url_count > self.spam_indicators["url_count"]:
            indicators.append("excessive_urls")

        # Any character repeated 5+ times in a row.
        repeated_punct = len(re.findall(r'(.)\1{4,}', text))
        if repeated_punct > self.spam_indicators["repeated_punctuation"]:
            indicators.append("repeated_characters")

        return indicators

    def apply_safety_rules(self, text):
        """Run all safety checks and aggregate an overall risk verdict."""
        results = {
            "passed": True,
            "risk_level": "safe",
            "issues": []
        }

        # Contact information is treated as a hard failure.
        contact_issues = self.check_contact_info(text)
        if contact_issues:
            results["passed"] = False
            results["risk_level"] = "high"
            results["issues"].extend(contact_issues)

        # Spam heuristics only raise the risk level, never hard-fail.
        spam_indicators = self.check_spam_indicators(text)
        if spam_indicators:
            results["issues"].append({
                "type": "spam_indicators",
                "indicators": spam_indicators
            })
            # BUG FIX: only escalate from "safe"; previously this
            # downgraded an already-"high" contact-info verdict to "medium".
            if len(spam_indicators) >= 2 and results["risk_level"] == "safe":
                results["risk_level"] = "medium"

        return results

1.3 规则引擎执行

class RuleEngine:
    """Priority-ordered executor for registered rule sets.

    A rule set must expose ``apply(text, context) -> {"action": ...}`` and a
    ``get(key)`` accessor (used to read the optional "short_circuit" flag).
    """

    def __init__(self):
        self.rule_sets = {}
        self.execution_order = []  # list of (priority, name), highest first

    def register_rule_set(self, name, rule_set, priority=0):
        """Register *rule_set* under *name*; higher priority runs earlier."""
        self.rule_sets[name] = rule_set
        self.execution_order.append((priority, name))
        self.execution_order.sort(key=lambda x: x[0], reverse=True)

    def execute(self, text, context=None):
        """Run all registered rule sets against *text*.

        Optimisation strategies:
        - short-circuit: once rejected, rule sets marked "short_circuit" are skipped
        - per-rule results are accumulated into *context* for later rules

        Returns:
            dict with the final ``decision`` ("accept"/"flag"/"reject"),
            the accumulated ``context``, and the wall-clock processing time.
        """
        import time

        start = time.perf_counter()
        context = context or {}
        final_decision = "accept"

        for priority, rule_name in self.execution_order:
            rule_set = self.rule_sets[rule_name]

            # Short-circuit: skip opted-in rule sets once rejection is final.
            if final_decision == "reject" and rule_set.get("short_circuit"):
                continue

            result = rule_set.apply(text, context)

            if result["action"] == "reject":
                final_decision = "reject"
            elif result["action"] == "flag" and final_decision == "accept":
                final_decision = "flag"

            # Accumulate per-rule results for downstream consumers.
            context[f"{rule_name}_result"] = result

        return {
            "decision": final_decision,
            "context": context,
            # BUG FIX: was hard-coded to 0; record the real elapsed time.
            "processing_time_ms": (time.perf_counter() - start) * 1000.0
        }

二、基于模型的清洗

2.1 模型清洗概述

随着深度学习技术的发展,基于模型的文本质量评估已成为数据清洗的重要组成部分。相比规则方法,模型方法能够捕捉更复杂的语言模式和语义特征。

模型选择策略

| 场景 | 推荐模型 | 优势 | 劣势 |
| --- | --- | --- | --- |
| 质量评分 | Reward Model / DeBERTa | 语义理解能力强 | 需要标注数据训练 |
| 毒性检测 | Detoxify / Perspective API | 专业领域优化 | 可能产生误判 |
| 重复检测 | SeqModel / Embedding | 语义重复识别 | 计算成本较高 |
| 事实核查 | FActScore / KILT | 百科知识融合 | 仅限已知事实 |

2.2 质量评分模型

Reward Model用于质量评估

class QualityScorer:
    """Reward-model-based quality scorer for instruction/response pairs."""

    def __init__(self, model_path, device="cuda"):
        """Load the sequence-classification reward model and its tokenizer."""
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_path
        ).to(device)
        self.device = device

    def score(self, instruction, response):
        """Score one instruction/response pair.

        Returns:
            dict with "score" (sigmoid of the model logit, in 0-1),
            "quality_level" (coarse label) and "confidence".
        """
        import torch  # BUG FIX: torch was used but never imported in this snippet

        # Build the model input in the expected prompt format.
        text = f"Instruction: {instruction}\nResponse: {response}"
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            score = torch.sigmoid(outputs.logits).item()

        return {
            "score": score,
            "quality_level": self._classify_quality(score),
            "confidence": self._calculate_confidence(outputs.logits)
        }

    def batch_score(self, samples, batch_size=32):
        """Score a list of {"instruction", "response"} dicts in batches.

        Returns:
            flat list of sigmoid scores, one per sample, in input order.
        """
        import torch  # BUG FIX: torch was used but never imported in this snippet

        scores = []

        for i in range(0, len(samples), batch_size):
            batch = samples[i:i+batch_size]

            texts = [
                f"Instruction: {s['instruction']}\nResponse: {s['response']}"
                for s in batch
            ]

            inputs = self.tokenizer(
                texts,
                return_tensors="pt",
                truncation=True,
                max_length=2048,
                padding=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                batch_scores = torch.sigmoid(outputs.logits).squeeze(-1)

            scores.extend(batch_scores.cpu().tolist())

        return scores

    def _classify_quality(self, score):
        """Map a 0-1 score to a coarse quality label."""
        if score >= 0.9:
            return "excellent"
        elif score >= 0.7:
            return "good"
        elif score >= 0.5:
            return "acceptable"
        else:
            return "poor"

    def _calculate_confidence(self, logits):
        """Return the max softmax probability over the logits.

        NOTE(review): for a single-logit regression head this is always 1.0;
        confirm the model actually has more than one output class.
        """
        import torch  # BUG FIX: torch was used but never imported in this snippet

        probs = torch.softmax(logits, dim=-1)
        max_prob = probs.max().item()
        return max_prob

多维度质量评估

class MultiDimensionalQualityAnalyzer:
    """Evaluate an instruction/response pair along several quality axes."""

    def __init__(self):
        # One scorer per quality dimension.
        self.dimensions = {
            "relevance": RelevanceScorer(),
            "helpfulness": HelpfulnessScorer(),
            "safety": SafetyClassifier(),
            "coherence": CoherenceAnalyzer(),
            "fluency": FluencyScorer()
        }

    def analyze(self, instruction, response):
        """Score every dimension and combine them into a weighted overall
        score; also returns per-dimension details and improvement hints."""
        per_dimension = {}
        for name, scorer in self.dimensions.items():
            per_dimension[name] = scorer.evaluate(instruction, response)

        # Fixed weights per dimension, summing to 1.0.
        weights = {
            "relevance": 0.25,
            "helpfulness": 0.30,
            "safety": 0.20,
            "coherence": 0.15,
            "fluency": 0.10,
        }

        overall = 0.0
        for name, weight in weights.items():
            overall += per_dimension[name]["score"] * weight

        return {
            "overall_score": overall,
            "dimensions": per_dimension,
            "suggestions": self._generate_suggestions(per_dimension),
            "pass_threshold": overall >= 0.6
        }

    def _generate_suggestions(self, results):
        """Collect an improvement hint for every dimension scoring below 0.6."""
        return [
            {
                "dimension": name,
                "issue": outcome.get("issue", "需要改进"),
                "suggestion": outcome.get("suggestion", "建议重写")
            }
            for name, outcome in results.items()
            if outcome["score"] < 0.6
        ]

2.3 毒性内容检测

class ToxicityDetector:
    """Toxic-content detector backed by Detoxify or the Perspective API."""

    def __init__(self, model_type="detoxify"):
        if model_type == "detoxify":
            from detoxify import Detoxify
            self.detector = Detoxify('unbiased')
        else:
            # Fall back to the Perspective API.
            import os  # BUG FIX: os was used but never imported in this snippet
            self.api_key = os.getenv("PERSPECTIVE_API_KEY")

    def detect(self, text):
        """Score *text* on every toxicity category.

        Returns:
            dict with per-category "scores", the max score as
            "overall_toxicity", an "is_toxic" flag (max > 0.5), the list of
            "toxic_categories" above 0.5, and a coarse "severity" label.
            Categories (Detoxify): toxicity, severe_toxicity, obscene,
            threat, insult, identity_attack.
        """
        if hasattr(self, 'detector'):
            scores = self.detector.predict(text)

            # Cast backend scalars to plain floats.
            normalized_scores = {
                category: float(score)
                for category, score in scores.items()
            }

            overall = max(normalized_scores.values())

            return {
                "scores": normalized_scores,
                "overall_toxicity": overall,
                "is_toxic": overall > 0.5,
                "toxic_categories": [
                    cat for cat, score in normalized_scores.items()
                    if score > 0.5
                ],
                "severity": self._classify_severity(overall)
            }
        else:
            # Perspective API path.
            return self._perspective_api_check(text)

    def batch_detect(self, texts, threshold=0.5):
        """Detect toxicity for a batch and aggregate summary statistics.

        NOTE(review): *threshold* is currently unused — the 0.5 cutoff baked
        into detect()'s "is_toxic" decides — kept for interface stability.
        """
        results = []

        for text in texts:
            result = self.detect(text)
            result["text"] = text[:100]  # keep only a preview
            result["action"] = "reject" if result["is_toxic"] else "accept"
            results.append(result)

        toxic_count = sum(1 for r in results if r["is_toxic"])
        stats = {
            "total": len(results),
            "toxic_count": toxic_count,
            # BUG FIX: guard against division by zero on an empty batch.
            "toxic_rate": toxic_count / len(results) if results else 0,
            "by_category": self._aggregate_by_category(results)
        }

        return {"results": results, "statistics": stats}

    def _classify_severity(self, score):
        """Map an overall toxicity score to a severity bucket."""
        if score > 0.8:
            return "critical"
        elif score > 0.6:
            return "high"
        elif score > 0.3:
            return "medium"
        else:
            return "low"

    def _aggregate_by_category(self, results):
        """Count how often each toxic category appears across *results*."""
        category_counts = {}

        for result in results:
            for cat in result.get("toxic_categories", []):
                category_counts[cat] = category_counts.get(cat, 0) + 1

        return category_counts

三、去重(Deduplication)技术

3.1 去重的重要性

在训练大规模语言模型时,去重是防止数据污染和模型记忆的关键步骤。研究表明,即使是很小比例的重复数据也可能导致:

  • 评估污染:测试集与训练集重叠导致指标虚高
  • 记忆外推:模型过度记忆训练数据,降低泛化能力
  • 训练不稳定:极端重复数据导致梯度异常

3.2 精确去重

字符串级别去重

class ExactDeduplicator:
    """Exact (hash-based / normalized) duplicate removal."""

    def __init__(self):
        self.seen_hashes = set()  # digests seen by exact_match_dedup
        self.seen_exact = set()   # normalized strings seen by normalize_and_dedup

    def exact_match_dedup(self, texts, hash_func="md5"):
        """Remove exact string duplicates using an O(1) hash-set lookup.

        Raises:
            ValueError: if *hash_func* is not "md5" or "sha256".
        """
        import hashlib

        unique_texts = []
        duplicate_count = 0

        for text in texts:
            if hash_func == "md5":
                text_hash = hashlib.md5(text.encode()).hexdigest()
            elif hash_func == "sha256":
                text_hash = hashlib.sha256(text.encode()).hexdigest()
            else:
                # BUG FIX: an unknown hash_func previously left text_hash
                # unbound, crashing with NameError below.
                raise ValueError(f"unsupported hash_func: {hash_func}")

            if text_hash not in self.seen_hashes:
                self.seen_hashes.add(text_hash)
                unique_texts.append(text)
            else:
                duplicate_count += 1

        return {
            "unique_texts": unique_texts,
            "total": len(texts),
            "duplicates_removed": duplicate_count,
            "dedup_rate": duplicate_count / len(texts) if texts else 0
        }

    def normalize_and_dedup(self, texts):
        """Deduplicate after normalization (ignores case/whitespace/format).

        The first occurrence of each normalized string is kept, in its
        ORIGINAL (un-normalized) form.
        """
        kept = []

        for text in texts:
            norm_text = self._normalize_text(text)
            if norm_text not in self.seen_exact:
                self.seen_exact.add(norm_text)
                kept.append(text)

        return kept

    def _normalize_text(self, text):
        """Normalize text: NFKC fold, collapse whitespace, lowercase."""
        import unicodedata  # (dropped an unused `import re` from the original)

        # Unicode compatibility normalization (full-width -> half-width, etc.)
        text = unicodedata.normalize('NFKC', text)

        # Collapse runs of whitespace into single spaces.
        text = ' '.join(text.split())

        return text.lower()

3.3 近似去重

MinHash去重

class MinHashDeduplicator:
    """Approximate deduplication via MinHash signatures.

    NOTE: pairwise comparison is O(n^2); for large corpora add LSH banding.
    """

    def __init__(self, num_hashes=128, threshold=0.8):
        self.num_hashes = num_hashes  # signature length
        self.threshold = threshold    # estimated-Jaccard cutoff
        self.minhashes = []
        self.texts = []

    def _compute_minhash(self, text, n=5):
        """Compute the MinHash signature of *text* over word n-grams.

        Falls back to character n-grams for short texts, and to the whole
        text as a single shingle when even that yields nothing.
        """
        import hashlib

        # Word-level n-grams.
        words = text.split()
        ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)]

        if not ngrams:
            # Character-level fallback for texts with < n words.
            ngrams = [tuple(text[i:i+n]) for i in range(len(text)-n+1)]

        if not ngrams:
            # BUG FIX: texts shorter than n characters previously produced an
            # empty shingle set, leaving the signature all-infinity — which
            # made every such text "identical" to every other short text.
            ngrams = [(text,)]

        # Initialize the signature with +inf in every slot.
        minhash = [float('inf')] * self.num_hashes

        for ngram in ngrams:
            # Salt the shingle per slot to simulate independent hash functions.
            for i in range(self.num_hashes):
                hash_input = f"{ngram}_{i}".encode()
                hash_value = int(hashlib.sha256(hash_input).hexdigest(), 16)
                minhash[i] = min(minhash[i], hash_value)

        return tuple(minhash)

    def _jaccard_similarity(self, hash1, hash2):
        """Estimate Jaccard similarity as the fraction of matching slots."""
        return sum(h1 == h2 for h1, h2 in zip(hash1, hash2)) / self.num_hashes

    def fit(self, texts):
        """Build the MinHash index over *texts* (replaces any previous index)."""
        self.texts = []
        self.minhashes = []

        for text in texts:
            self.texts.append(text)
            self.minhashes.append(self._compute_minhash(text))

    def find_duplicates(self):
        """Return all pairs whose estimated similarity reaches the threshold."""
        duplicates = []

        for i in range(len(self.minhashes)):
            for j in range(i + 1, len(self.minhashes)):
                sim = self._jaccard_similarity(
                    self.minhashes[i],
                    self.minhashes[j]
                )

                if sim >= self.threshold:
                    duplicates.append({
                        "id1": i,
                        "id2": j,
                        "similarity": sim,
                        "text1": self.texts[i][:100],
                        "text2": self.texts[j][:100]
                    })

        return duplicates

    def deduplicate(self):
        """Collapse near-duplicate clusters (union-find) and keep one text each."""
        parent = list(range(len(self.minhashes)))

        def find(x):
            # Path compression.
            if parent[x] != x:
                parent[x] = find(parent[x])
            return parent[x]

        def union(x, y):
            px, py = find(x), find(y)
            if px != py:
                parent[px] = py

        # Link every pair at or above the similarity threshold.
        for i in range(len(self.minhashes)):
            for j in range(i + 1, len(self.minhashes)):
                if self._jaccard_similarity(
                    self.minhashes[i],
                    self.minhashes[j]
                ) >= self.threshold:
                    union(i, j)

        # Keep the first-seen member of each cluster as its representative.
        clusters = {}
        for i in range(len(self.texts)):
            root = find(i)
            if root not in clusters:
                clusters[root] = i

        unique_indices = list(clusters.values())

        return {
            "unique_texts": [self.texts[i] for i in unique_indices],
            "clusters": len(clusters),
            "removed_count": len(self.texts) - len(unique_indices)
        }

SimHash去重

class SimHashDeduplicator:
    """SimHash-based near-duplicate detector (suited to longer texts)."""

    def __init__(self, num_bits=64, threshold=3):
        self.num_bits = num_bits
        self.threshold = threshold  # max Hamming distance to count as a duplicate
        self.hash_table = {}        # fingerprint -> stored text prefix

    def _simhash(self, text):
        """Build the SimHash fingerprint of *text* from its whitespace tokens."""
        import hashlib

        tokens = text.split()
        accumulator = [0] * self.num_bits
        # Uniform weight per token.
        weight = 1.0 / (1 + len(tokens))

        for token in tokens:
            digest = int(hashlib.md5(token.encode()).hexdigest(), 16)
            for bit_index in range(self.num_bits):
                if (digest >> bit_index) & 1:
                    accumulator[bit_index] += weight
                else:
                    accumulator[bit_index] -= weight

        # Collapse the accumulator into the bit fingerprint.
        fingerprint = 0
        for bit_index, value in enumerate(accumulator):
            if value > 0:
                fingerprint |= 1 << bit_index

        return fingerprint

    def _hamming_distance(self, hash1, hash2):
        """Number of differing bits between two fingerprints."""
        return bin(hash1 ^ hash2).count('1')

    def add(self, text):
        """Register *text*; report any stored near-duplicates first."""
        fingerprint = self._simhash(text)

        # Collect stored texts within the Hamming-distance threshold.
        near_duplicates = [
            stored_text
            for stored_hash, stored_text in self.hash_table.items()
            if self._hamming_distance(fingerprint, stored_hash) <= self.threshold
        ]

        if near_duplicates:
            return {
                "is_duplicate": True,
                "duplicates": near_duplicates,
                "fingerprint": fingerprint
            }

        # No match: remember a preview of this text under its fingerprint.
        self.hash_table[fingerprint] = text[:100]
        return {
            "is_duplicate": False,
            "fingerprint": fingerprint
        }

四、质量评分与过滤

4.1 多维度质量评分

class QualityFilter:
    """Weighted multi-dimension quality filter."""

    def __init__(self):
        self.scorers = {
            "length": LengthScorer(),
            "language": LanguageScorer(),
            "repetition": RepetitionScorer(),
            "nlp": NLPCohesionScorer()
        }
        # Per-dimension weights, summing to 1.0.
        self.weights = {
            "length": 0.15,
            "language": 0.25,
            "repetition": 0.20,
            "nlp": 0.40
        }

    def score(self, text):
        """Return the weighted overall score plus per-dimension scores."""
        scores = {}

        for name, scorer in self.scorers.items():
            scores[name] = scorer.score(text)

        weighted_score = sum(
            scores[name] * self.weights[name]
            for name in scores
        )

        return {
            "overall": weighted_score,
            "dimensions": scores,
            "passed": weighted_score >= 0.5
        }

    def filter_dataset(self, texts, threshold=0.5):
        """Split *texts* into accepted/rejected by overall score.

        BUG FIX: *threshold* was previously ignored (the fixed 0.5 cutoff
        inside score() was used instead); the default keeps old behavior.
        """
        results = []

        for i, text in enumerate(texts):
            score = self.score(text)

            results.append({
                "text": text,
                "index": i,
                "score": score,
                "action": "accept" if score["overall"] >= threshold else "reject"
            })

        accepted = [r for r in results if r["action"] == "accept"]
        rejected = [r for r in results if r["action"] == "reject"]

        return {
            "accepted": accepted,
            "rejected": rejected,
            "acceptance_rate": len(accepted) / len(results) if results else 0
        }

4.2 自适应阈值

class AdaptiveThresholdFilter:
    """Quality filter whose score threshold adapts to the observed distribution."""

    def __init__(self, initial_threshold=0.5):
        self.threshold = initial_threshold
        self.quality_distribution = {}  # filled by update_threshold()

    def update_threshold(self, scored_samples):
        """Recompute the threshold from scored samples.

        Sets the threshold to the 30th percentile of observed scores, so
        roughly the top 70% of samples are retained.

        Args:
            scored_samples: iterable of dicts with a numeric "score" key.

        Raises:
            ValueError: if *scored_samples* is empty.
        """
        import numpy as np  # BUG FIX: np was used but never imported in this snippet

        scores = [s["score"] for s in scored_samples]
        if not scores:
            # BUG FIX: np.mean([]) would return NaN with a RuntimeWarning;
            # fail loudly instead.
            raise ValueError("scored_samples must not be empty")

        # Summary statistics of the score distribution.
        self.quality_distribution = {
            "mean": np.mean(scores),
            "std": np.std(scores),
            "q25": np.percentile(scores, 25),
            "q50": np.percentile(scores, 50),
            "q75": np.percentile(scores, 75),
            "q90": np.percentile(scores, 90),
            "q95": np.percentile(scores, 95)
        }

        # Keep roughly the top 70% of the distribution.
        self.threshold = np.percentile(scores, 30)

        return {
            "new_threshold": self.threshold,
            "distribution": self.quality_distribution,
            "recommended_top_percent": 70
        }

五、数据分布分析

5.1 分布监控指标

class DataDistributionAnalyzer:
    """数据分布分析器"""
    
    def __init__(self):
        self.metrics = {}
        
    def analyze_text_length(self, texts):
        """文本长度分布分析"""
        lengths = [len(t) for t in texts]
        
        return {
            "mean": np.mean(lengths),
            "median": np.median(lengths),
            "std": np.std(lengths),
            "min": np.min(lengths),
            "max": np.max(lengths),
            "percentiles": {
                "p25": np.percentile(lengths, 25),
                "p75": np.percentile(lengths, 75),
                "p95": np.percentile(lengths, 95)
            },
            "distribution": self._create_histogram(lengths)
        }
        
    def analyze_topic_distribution(self, texts):
        """主题分布分析(使用嵌入聚类)"""
        from sklearn.cluster import KMeans
        from sklearn.decomposition import PCA
        
        # 生成文本嵌入
        embeddings = self._get_embeddings(texts[:10000])
        
        # PCA降维
        pca = PCA(n_components=10)
        reduced = pca.fit_transform(embeddings)
        
        # KMeans聚类
        n_clusters = min(20, len(texts) // 100)
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        clusters = kmeans.fit_predict(reduced)
        
        return {
            "n_clusters": n_clusters,
            "cluster_sizes": Counter(clusters),
            "variance_explained": sum(pca.explained_variance_ratio_)
        }
        
    def check_distribution_shift(self, old_texts, new_texts):
        """检测数据分布偏移"""
        old_lengths = [len(t) for t in old_texts]
        new_lengths = [len(t) for t in new_texts]
        
        from scipy import stats
        
        ks_stat, p_value = stats.ks_2samp(old_lengths, new_lengths)
        
        return {
            "ks_statistic": ks_stat,
            "p_value": p_value,
            "significant_shift": p_value < 0.05,
            "recommendation": "需要重新平衡" if p_value < 0.05 else "分布一致"
        }

相关文档