Skip to content

第54天:微调数据准备

学习目标

  • 掌握数据收集方法
  • 学习数据清洗技术
  • 理解数据格式化
  • 掌握数据增强方法
  • 了解数据质量评估

数据收集

数据源识别

python
from typing import List, Dict

class DataSourceIdentifier:
    """Catalog of candidate data sources, queryable by domain and task."""

    def __init__(self):
        # Static catalog, grouped by where the data originates.
        self.data_sources = {
            "public_datasets": [
                "Hugging Face Datasets",
                "Common Crawl",
                "Wikipedia",
                "arXiv",
                "PubMed"
            ],
            "company_data": [
                "Internal documents",
                "Customer support logs",
                "Product documentation",
                "Sales transcripts",
                "Internal wikis"
            ],
            "user_generated": [
                "Social media",
                "Forums",
                "Reviews",
                "Q&A sites"
            ],
            "synthetic": [
                "LLM-generated",
                "Template-based",
                "Rule-based"
            ]
        }

    def identify_sources(self, domain: str,
                      task: str) -> List[str]:
        """Return domain-specific sources followed by task-specific ones.

        Unknown domains/tasks contribute nothing, so the result may be
        empty.
        """
        sources_by_domain = {
            "medical": ["PubMed", "Medical journals", "Clinical guidelines"],
            "legal": ["Legal cases", "Law reviews", "Regulations"],
            "finance": ["Financial reports", "Market data",
                        "Economic indicators"],
        }
        sources_by_task = {
            "question_answering": ["Q&A datasets", "FAQs", "Knowledge bases"],
            "summarization": ["Articles", "Reports", "Documents"],
        }
        return (list(sources_by_domain.get(domain, []))
                + list(sources_by_task.get(task, [])))

数据收集器

python
import requests
from pathlib import Path

class DataCollector:
    """Collects raw samples from Hugging Face datasets, HTTP APIs, local
    files, and web pages.

    Every record is a dict carrying a ``source`` tag, provenance fields,
    and the raw payload under ``data``; all records accumulate in
    ``self.collected_data`` across calls.
    """

    # Seconds before an outbound HTTP request is abandoned; the original
    # had no timeout, so a dead endpoint could hang the collector forever.
    DEFAULT_TIMEOUT = 30

    def __init__(self, config: Dict):
        self.config = config
        self.collected_data = []  # running list across all collect_* calls

    def collect_from_huggingface(self, dataset_name: str) -> List[Dict]:
        """Download every split of *dataset_name* from the Hub.

        Raises:
            ImportError: if the optional ``datasets`` package is missing.
        """
        # Only the import can raise ImportError; keep the try minimal so a
        # genuine ImportError from inside the dataset is not masked.
        try:
            from datasets import load_dataset
        except ImportError:
            raise ImportError("Install datasets: pip install datasets")

        dataset = load_dataset(dataset_name)
        collected = [
            {
                "source": "huggingface",
                "dataset": dataset_name,
                "split": split,
                "data": item,
            }
            for split in dataset.keys()
            for item in dataset[split]
        ]
        self.collected_data.extend(collected)
        return collected

    def collect_from_api(self, api_url: str,
                         params: Dict = None,
                         timeout: float = DEFAULT_TIMEOUT) -> List[Dict]:
        """Fetch JSON from *api_url*; each top-level element becomes a record.

        Raises:
            Exception: on any non-200 HTTP status.
        """
        response = requests.get(api_url, params=params, timeout=timeout)
        if response.status_code != 200:
            raise Exception(f"API request failed: {response.status_code}")

        collected = [
            {"source": "api", "url": api_url, "data": item}
            for item in response.json()
        ]
        self.collected_data.extend(collected)
        return collected

    def collect_from_files(self, directory: str,
                         file_patterns: List[str] = None) -> List[Dict]:
        """Read files matching *file_patterns* under *directory* (recursive).

        Unreadable files are reported and skipped (best-effort).

        Raises:
            FileNotFoundError: if *directory* does not exist.
        """
        path = Path(directory)
        if not path.exists():
            raise FileNotFoundError(f"Directory not found: {directory}")

        if file_patterns is None:
            file_patterns = ["*.txt", "*.json", "*.csv"]

        collected = []
        for pattern in file_patterns:
            for file_path in path.rglob(pattern):
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        content = f.read()
                except Exception as e:
                    # Best-effort: one bad file must not abort the scan.
                    print(f"Error reading {file_path}: {e}")
                    continue
                collected.append({
                    "source": "file",
                    "path": str(file_path),
                    "data": {"content": content}
                })
        self.collected_data.extend(collected)
        return collected

    def collect_from_web(self, urls: List[str],
                         timeout: float = DEFAULT_TIMEOUT) -> List[Dict]:
        """Scrape visible text from each URL; failures are printed and
        skipped (best-effort). Requires the optional ``bs4`` package."""
        collected = []
        for url in urls:
            try:
                response = requests.get(url, timeout=timeout)
                if response.status_code == 200:
                    from bs4 import BeautifulSoup
                    soup = BeautifulSoup(response.text, 'html.parser')
                    collected.append({
                        "source": "web",
                        "url": url,
                        "data": {"content": soup.get_text()}
                    })
            except Exception as e:
                print(f"Error fetching {url}: {e}")
        self.collected_data.extend(collected)
        return collected

数据合成

python
class DataSynthesizer:
    """Uses an LLM to synthesize fine-tuning data (QA pairs, instruction
    examples, conversations) and parses the raw responses back into
    structured records."""

    def __init__(self, llm):
        # llm is expected to expose generate(prompt) -> str; the exact
        # type is up to the caller.
        self.llm = llm

    def synthesize_qa_pairs(self, topic: str,
                           n_pairs: int = 10) -> List[Dict]:
        """Ask the LLM for *n_pairs* Q/A pairs about *topic* and parse them."""
        prompt = f"""
        Generate {n_pairs} question-answer pairs about {topic}.
        
        Format each pair as:
        Q: [question]
        A: [answer]
        
        Make questions diverse and challenging.
        """
        
        response = self.llm.generate(prompt)
        
        return self._parse_qa_pairs(response)
    
    def synthesize_instruction_data(self, task: str, 
                                  n_examples: int = 10) -> List[Dict]:
        """Ask the LLM for *n_examples* instruction examples and parse them."""
        prompt = f"""
        Generate {n_examples} examples for the task: {task}
        
        Format each example as:
        Instruction: [instruction]
        Input: [input]
        Output: [output]
        
        Make examples diverse and realistic.
        """
        
        response = self.llm.generate(prompt)
        
        return self._parse_instruction_data(response)
    
    def synthesize_conversations(self, scenario: str, 
                               n_conversations: int = 5) -> List[Dict]:
        """Ask the LLM for conversations in *scenario* and parse them."""
        prompt = f"""
        Generate {n_conversations} conversations for the scenario: {scenario}
        
        Format each conversation as:
        User: [user message]
        Assistant: [assistant response]
        
        Make conversations natural and helpful.
        """
        
        response = self.llm.generate(prompt)
        
        return self._parse_conversations(response)
    
    def _parse_qa_pairs(self, response: str) -> List[Dict]:
        """Parse 'Q:'/'A:' lines into {question, answer} dicts.

        An 'A:' line without a preceding 'Q:' line is ignored (the
        original emitted answer-only records for malformed input).
        """
        pairs = []
        current_pair = {}
        for line in response.split('\n'):
            line = line.strip()
            if line.startswith('Q:'):
                current_pair = {"question": line[len('Q:'):].strip()}
            elif line.startswith('A:') and "question" in current_pair:
                current_pair["answer"] = line[len('A:'):].strip()
                pairs.append(current_pair)
                current_pair = {}
        return pairs
    
    def _parse_instruction_data(self, response: str) -> List[Dict]:
        """Parse 'Instruction:'/'Input:'/'Output:' lines into examples.

        An example is emitted only once its 'Output:' line is seen, and
        only if an 'Instruction:' line started it.
        """
        examples = []
        current_example = {}
        for line in response.split('\n'):
            line = line.strip()
            if line.startswith('Instruction:'):
                current_example = {
                    "instruction": line[len('Instruction:'):].strip()
                }
            elif line.startswith('Input:'):
                current_example["input"] = line[len('Input:'):].strip()
            elif line.startswith('Output:') and "instruction" in current_example:
                current_example["output"] = line[len('Output:'):].strip()
                examples.append(current_example)
                current_example = {}
        return examples
    
    def _parse_conversations(self, response: str) -> List[Dict]:
        """Parse alternating 'User:'/'Assistant:' lines into one
        conversation record.

        NOTE: all turns are merged into a single conversation because the
        raw text carries no delimiter between separate conversations.
        """
        turns = []
        for line in response.split('\n'):
            line = line.strip()
            if line.startswith('User:'):
                turns.append({
                    "role": "user",
                    "content": line[len('User:'):].strip()
                })
            elif line.startswith('Assistant:'):
                # len('Assistant:') == 10; the original sliced at 11 and
                # silently dropped the first character whenever no space
                # followed the colon.
                turns.append({
                    "role": "assistant",
                    "content": line[len('Assistant:'):].strip()
                })
        return [{"messages": turns}] if turns else []

数据清洗

文本清洗

python
import re

class TextCleaner:
    """Normalizes raw text: strips HTML tags, URLs, emails, phone numbers,
    configured special characters, and collapses whitespace."""

    # Patterns compiled once at class creation instead of on every call
    # (the original rebuilt them per invocation). Behavior is unchanged.
    _HTML_RE = re.compile(r'<.*?>')
    _URL_RE = re.compile(r'http\S+|www\.\S+')
    _EMAIL_RE = re.compile(r'\S+@\S+')
    _PHONE_RE = re.compile(r'\d{3}[-.\s]?\d{3}[-.\s]?\d{4}')
    _WS_RE = re.compile(r'\s+')

    def __init__(self, config: Dict = None):
        # config may carry "special_chars_to_remove": list of substrings
        # to delete verbatim in the final pass.
        self.config = config or {}

    def clean(self, text: str) -> str:
        """Run the full cleaning pipeline and return the cleaned text.

        Order matters: markup and PII are removed first, then whitespace
        is collapsed, then configured special characters are deleted.
        """
        text = self._remove_html_tags(text)
        text = self._remove_urls(text)
        text = self._remove_emails(text)
        text = self._remove_phone_numbers(text)
        text = self._normalize_whitespace(text)
        text = self._remove_special_chars(text)
        return text

    def _remove_html_tags(self, text: str) -> str:
        """Drop anything between '<' and the next '>' (non-greedy)."""
        return self._HTML_RE.sub('', text)

    def _remove_urls(self, text: str) -> str:
        """Drop http(s) and www URLs up to the next whitespace."""
        return self._URL_RE.sub('', text)

    def _remove_emails(self, text: str) -> str:
        """Drop anything shaped like token@token."""
        return self._EMAIL_RE.sub('', text)

    def _remove_phone_numbers(self, text: str) -> str:
        """Drop US-style 3-3-4 digit phone numbers."""
        return self._PHONE_RE.sub('', text)

    def _normalize_whitespace(self, text: str) -> str:
        """Collapse runs of whitespace to one space and trim the ends."""
        return self._WS_RE.sub(' ', text).strip()

    def _remove_special_chars(self, text: str) -> str:
        """Delete each configured substring verbatim (may be multi-char,
        so str.replace is used rather than str.translate)."""
        for char in self.config.get("special_chars_to_remove", []):
            text = text.replace(char, '')
        return text

数据去重

python
class DataDeduplicator:
    """Removes duplicate records, either exactly (content digest) or
    approximately (Jaccard word-overlap similarity)."""

    def deduplicate_by_hash(self, data: List[Dict]) -> List[Dict]:
        """Keep the first occurrence of each exact text payload.

        Uses a SHA-256 digest of the serialized payload instead of the
        builtin hash(): hash() is randomized per process for strings and
        can collide, silently dropping distinct items; the digest is
        stable and collision-negligible.
        """
        import hashlib

        seen = set()
        deduplicated = []
        for item in data:
            payload = str(item.get("data", item))
            digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()
            if digest not in seen:
                seen.add(digest)
                deduplicated.append(item)
        return deduplicated

    def deduplicate_by_similarity(self, data: List[Dict],
                                threshold: float = 0.95) -> List[Dict]:
        """Keep items whose Jaccard similarity to every already-kept item
        is at or below *threshold*.

        NOTE: O(n^2) pairwise comparison — intended for small/medium sets.
        """
        deduplicated = []
        for item in data:
            is_duplicate = False
            for existing in deduplicated:
                similarity = self._calculate_similarity(
                    str(item.get("data", item)),
                    str(existing.get("data", existing))
                )
                if similarity > threshold:
                    is_duplicate = True
                    break
            if not is_duplicate:
                deduplicated.append(item)
        return deduplicated

    def _calculate_similarity(self, text1: str,
                             text2: str) -> float:
        """Jaccard similarity over lowercased word sets; 0.0 when both
        texts are empty."""
        words1 = set(text1.lower().split())
        words2 = set(text2.lower().split())
        union = words1 | words2
        if not union:
            return 0.0
        return len(words1 & words2) / len(union)

数据过滤

python
class DataFilter:
    """Drops low-value samples by word count, detected language, or a
    heuristic quality score in [0, 1]."""

    def __init__(self, config: Dict = None):
        self.config = config or {}

    def filter_by_length(self, data: List[Dict],
                         min_length: int = 10,
                         max_length: int = 10000) -> List[Dict]:
        """Keep items whose word count is within [min_length, max_length]."""
        filtered = []
        for item in data:
            n_words = len(str(item.get("data", item)).split())
            if min_length <= n_words <= max_length:
                filtered.append(item)
        return filtered

    def filter_by_language(self, data: List[Dict],
                         language: str = "en") -> List[Dict]:
        """Keep items whose detected language equals *language*.

        Raises:
            ImportError: if the optional ``langdetect`` package is missing.
        """
        try:
            from langdetect import detect
        except ImportError:
            raise ImportError("Install langdetect: pip install langdetect")

        filtered = []
        for item in data:
            text = str(item.get("data", item))
            try:
                if detect(text) == language:
                    filtered.append(item)
            except Exception:
                # langdetect raises on empty/undetectable text; skip the
                # item rather than abort. (The original bare `except:`
                # also swallowed KeyboardInterrupt/SystemExit.)
                pass
        return filtered

    def filter_by_quality(self, data: List[Dict],
                         min_quality_score: float = 0.5) -> List[Dict]:
        """Keep items scoring at least *min_quality_score*."""
        return [item for item in data
                if self._assess_quality(item) >= min_quality_score]

    def _assess_quality(self, item: Dict) -> float:
        """Average of the three heuristic sub-scores below."""
        text = str(item.get("data", item))
        scores = [
            self._check_length_quality(text),
            self._check_structure_quality(text),
            self._check_content_quality(text),
        ]
        return sum(scores) / len(scores)

    def _check_length_quality(self, text: str) -> float:
        """1.0 for 50-500 words, 0.7 for near-misses, else 0.3."""
        length = len(text.split())
        if 50 <= length <= 500:
            return 1.0
        elif 10 <= length < 50 or 500 < length <= 1000:
            return 0.7
        else:
            return 0.3

    def _check_structure_quality(self, text: str) -> float:
        """Reward sentence punctuation (+0.3) and paragraph breaks (+0.2)
        on a 0.5 base; empty text scores 0.0."""
        if not text:
            return 0.0
        has_sentences = any(char in text for char in '.!?')
        has_paragraphs = '\n\n' in text
        score = 0.5
        if has_sentences:
            score += 0.3
        if has_paragraphs:
            score += 0.2
        return score

    def _check_content_quality(self, text: str) -> float:
        """Score by average word length: 3-6 chars reads as natural prose
        (1.0), near-misses score 0.7, everything else 0.3."""
        words = text.split()
        if not words:
            return 0.0
        avg_word_length = sum(len(word) for word in words) / len(words)
        if 3 <= avg_word_length <= 6:
            return 1.0
        elif 2 <= avg_word_length < 3 or 6 < avg_word_length <= 8:
            return 0.7
        else:
            return 0.3

数据格式化

指令格式化

python
class InstructionFormatter:
    """Renders instruction/input/output records through a text template."""

    def __init__(self, template: str = None):
        # A falsy template (None or "") falls back to the built-in one.
        self.template = template or self._default_template()

    def _default_template(self) -> str:
        return """Instruction: {instruction}
Input: {input}
Output: {output}"""

    def format(self, data: Dict) -> str:
        """Fill the template from *data*; missing keys become ''."""
        fields = {
            "instruction": data.get("instruction", ""),
            "input": data.get("input", ""),
            "output": data.get("output", ""),
        }
        return self.template.format(**fields)

    def format_batch(self, data_list: List[Dict]) -> List[str]:
        """Format every record, preserving order."""
        return list(map(self.format, data_list))

对话格式化

python
class ConversationFormatter:
    """Renders a {"messages": [...]} conversation as 'role: content' lines."""

    def __init__(self, template: str = None):
        # A falsy template (None or "") falls back to the built-in one.
        self.template = template or self._default_template()

    def _default_template(self) -> str:
        return """{messages}"""

    def format(self, conversation: Dict) -> str:
        """Join each turn as 'role: content' and inject into the template.

        Missing 'role' defaults to 'user'; missing 'content' to ''.
        """
        rendered_turns = [
            f"{turn.get('role', 'user')}: {turn.get('content', '')}"
            for turn in conversation.get("messages", [])
        ]
        return self.template.format(messages="\n".join(rendered_turns))

    def format_batch(self, conversations: List[Dict]) -> List[str]:
        """Format every conversation, preserving order."""
        return list(map(self.format, conversations))

QA格式化

python
class QAFormatter:
    """Renders {question, answer} pairs through a text template."""

    def __init__(self, template: str = None):
        # A falsy template (None or "") falls back to the built-in one.
        self.template = template or self._default_template()

    def _default_template(self) -> str:
        return """Question: {question}
Answer: {answer}"""

    def format(self, qa_pair: Dict) -> str:
        """Fill the template from *qa_pair*; missing keys become ''."""
        return self.template.format(
            question=qa_pair.get("question", ""),
            answer=qa_pair.get("answer", ""),
        )

    def format_batch(self, qa_pairs: List[Dict]) -> List[str]:
        """Format every pair, preserving order."""
        return list(map(self.format, qa_pairs))

数据增强

同义词替换

python
import random

class SynonymAugmenter:
    """Creates text variants by swapping known words for random synonyms."""

    def __init__(self):
        self.synonym_dict = self._build_synonym_dict()

    def _build_synonym_dict(self) -> Dict:
        # Small built-in thesaurus; lookups are case-insensitive on the key.
        return {
            "good": ["excellent", "great", "wonderful", "amazing"],
            "bad": ["terrible", "awful", "poor", "horrible"],
            "big": ["large", "huge", "enormous", "massive"],
            "small": ["tiny", "little", "mini", "compact"],
            "fast": ["quick", "rapid", "swift", "speedy"],
            "slow": ["sluggish", "leisurely", "unhurried", "gradual"]
        }

    def augment(self, text: str,
               n_augmentations: int = 1) -> List[str]:
        """Return [original] + n_augmentations randomized variants.

        Words absent from the thesaurus pass through unchanged, so text
        with no known words yields identical copies.
        """
        variants = [text]
        tokens = text.split()
        for _ in range(n_augmentations):
            swapped = [
                random.choice(self.synonym_dict[tok.lower()])
                if tok.lower() in self.synonym_dict else tok
                for tok in tokens
            ]
            variants.append(" ".join(swapped))
        return variants

回译增强

python
class BackTranslationAugmenter:
    """Augments text by translating to a pivot language and back."""

    def __init__(self, translator):
        # translator must expose translate(text, target_lang=...) -> str
        self.translator = translator

    def augment(self, text: str,
               intermediate_lang: str = "fr") -> List[str]:
        """Return the original text plus its round-trip translation."""
        pivot = self.translator.translate(text, target_lang=intermediate_lang)
        round_trip = self.translator.translate(pivot, target_lang="en")
        return [text, round_trip]

随机插入

python
class RandomInsertionAugmenter:
    """Creates text variants by inserting filler words at random positions."""

    def __init__(self, insertion_words: List[str] = None):
        # Falsy (None/empty) insertion lists fall back to the defaults.
        self.insertion_words = insertion_words or [
            "actually", "basically", "essentially", "really",
            "very", "quite", "rather", "somewhat"
        ]

    def augment(self, text: str,
               n_augmentations: int = 1) -> List[str]:
        """Return [original] + n_augmentations variants, each with 1-3
        filler words inserted at random positions."""
        variants = [text]
        for _ in range(n_augmentations):
            tokens = text.split()
            how_many = random.randint(1, 3)
            for _ in range(how_many):
                # randint's upper bound is inclusive, so insertion at the
                # very end of the token list is possible.
                position = random.randint(0, len(tokens))
                tokens.insert(position, random.choice(self.insertion_words))
            variants.append(" ".join(tokens))
        return variants

数据质量评估

质量指标

python
class DataQualityAssessor:
    """Computes corpus-level quality metrics: length statistics, language
    mix, heuristic quality scores, and vocabulary diversity."""

    def assess(self, data: List[Dict]) -> Dict:
        """Return every metric group for *data*; safe on an empty list
        (the original raised ZeroDivisionError)."""
        return {
            "length_distribution": self._assess_length_distribution(data),
            "language_distribution": self._assess_language_distribution(data),
            "quality_scores": self._assess_quality_scores(data),
            "diversity": self._assess_diversity(data),
        }

    def _assess_length_distribution(self, data: List[Dict]) -> Dict:
        """Word-count statistics; all-zero stats for an empty corpus."""
        lengths = [len(str(item.get("data", item)).split()) for item in data]
        if not lengths:
            return {"mean": 0.0, "median": 0, "min": 0, "max": 0, "std": 0.0}
        return {
            "mean": sum(lengths) / len(lengths),
            # NOTE: upper median for even-sized samples (kept from original).
            "median": sorted(lengths)[len(lengths) // 2],
            "min": min(lengths),
            "max": max(lengths),
            "std": self._std_dev(lengths)
        }

    def _assess_language_distribution(self, data: List[Dict]) -> Dict:
        """Histogram of detected languages; requires optional langdetect."""
        try:
            from langdetect import detect
        except ImportError:
            return {"error": "langdetect not installed"}

        languages = {}
        for item in data:
            text = str(item.get("data", item))
            try:
                lang = detect(text)
                languages[lang] = languages.get(lang, 0) + 1
            except Exception:
                # Detection fails on e.g. empty or numeric-only text;
                # skip the item. (Bare `except:` narrowed to Exception.)
                pass
        return languages

    def _assess_quality_scores(self, data: List[Dict]) -> Dict:
        """Aggregate per-item heuristic scores; zeros for empty input."""
        scores = [self._assess_single_quality(item) for item in data]
        if not scores:
            return {"mean": 0.0, "min": 0.0, "max": 0.0}
        return {
            "mean": sum(scores) / len(scores),
            "min": min(scores),
            "max": max(scores)
        }

    def _assess_single_quality(self, item: Dict) -> float:
        """Score one item in [0, 1] from simple surface heuristics."""
        text = str(item.get("data", item))
        score = 0.0
        if len(text.split()) >= 10:              # enough words
            score += 0.3
        if any(char in text for char in '.!?'):  # sentence punctuation
            score += 0.3
        if len(text.split('\n')) > 1:            # multi-line structure
            score += 0.2
        if len(text) > 100:                      # non-trivial length
            score += 0.2
        return score

    def _assess_diversity(self, data: List[Dict]) -> Dict:
        """Corpus-wide vocabulary size; zeros for empty input.

        'avg_unique_per_sample' is total unique words divided by sample
        count, not a per-sample average (kept from original).
        """
        if not data:
            return {"unique_words": 0, "total_samples": 0,
                    "avg_unique_per_sample": 0.0}
        all_words = set()
        for item in data:
            all_words.update(str(item.get("data", item)).lower().split())
        return {
            "unique_words": len(all_words),
            "total_samples": len(data),
            "avg_unique_per_sample": len(all_words) / len(data)
        }

    def _std_dev(self, values: List[float]) -> float:
        """Population standard deviation; 0.0 for fewer than two values."""
        if len(values) < 2:
            return 0.0
        mean = sum(values) / len(values)
        variance = sum((x - mean) ** 2 for x in values) / len(values)
        return variance ** 0.5

实践练习

练习1:实现完整的数据准备流程

python
class DataPreparationPipeline:
    """End-to-end pipeline: collect -> clean -> dedupe -> filter -> format."""

    def __init__(self, config):
        self.config = config
        self.collector = DataCollector(config)
        self.cleaner = TextCleaner(config)
        self.deduplicator = DataDeduplicator()
        self.filter = DataFilter(config)
        self.formatter = InstructionFormatter()

    def prepare(self, sources: List[str]) -> List[str]:
        """Run every stage over *sources* (URLs or directories) and return
        the formatted training strings."""
        gathered = self._collect(sources)
        cleaned = self._clean(gathered)
        survivors = self.filter.filter_by_length(
            self.deduplicator.deduplicate_by_hash(cleaned)
        )
        return self.formatter.format_batch(survivors)

    def _collect(self, sources: List[str]) -> List[Dict]:
        # URLs go through the API collector; anything else is treated as
        # a local directory.
        gathered = []
        for source in sources:
            if source.startswith("http"):
                gathered.extend(self.collector.collect_from_api(source))
            else:
                gathered.extend(self.collector.collect_from_files(source))
        return gathered

    def _clean(self, items: List[Dict]) -> List[Dict]:
        # Replace each item's payload with its cleaned text, in place.
        for item in items:
            raw = str(item.get("data", item))
            item["data"] = {"content": self.cleaner.clean(raw)}
        return items

总结

本节我们学习了微调数据准备:

  1. 数据收集方法(数据源识别、收集器、合成)
  2. 数据清洗技术(文本清洗、去重、过滤)
  3. 数据格式化(指令、对话、QA)
  4. 数据增强方法(同义词、回译、随机插入)
  5. 数据质量评估(长度、语言、质量、多样性)

高质量的数据是微调成功的关键。

参考资源