# 第54天:微调数据准备

## 学习目标
- 掌握数据收集方法
- 学习数据清洗技术
- 理解数据格式化
- 掌握数据增强方法
- 了解数据质量评估
数据收集
数据源识别
python
from typing import List, Dict
class DataSourceIdentifier:
    """Suggest relevant data sources for a given domain and task."""

    def __init__(self):
        # Reference catalog of source categories; kept for lookup and
        # documentation purposes (identify_sources does not consult it).
        self.data_sources = {
            "public_datasets": [
                "Hugging Face Datasets",
                "Common Crawl",
                "Wikipedia",
                "arXiv",
                "PubMed",
            ],
            "company_data": [
                "Internal documents",
                "Customer support logs",
                "Product documentation",
                "Sales transcripts",
                "Internal wikis",
            ],
            "user_generated": [
                "Social media",
                "Forums",
                "Reviews",
                "Q&A sites",
            ],
            "synthetic": [
                "LLM-generated",
                "Template-based",
                "Rule-based",
            ],
        }

    def identify_sources(self, domain: str,
                         task: str) -> List[str]:
        """Return domain-specific sources followed by task-specific ones.

        Unknown domains/tasks contribute nothing, so the result may be empty.
        """
        by_domain = {
            "medical": ["PubMed", "Medical journals", "Clinical guidelines"],
            "legal": ["Legal cases", "Law reviews", "Regulations"],
            "finance": ["Financial reports", "Market data", "Economic indicators"],
        }
        by_task = {
            "question_answering": ["Q&A datasets", "FAQs", "Knowledge bases"],
            "summarization": ["Articles", "Reports", "Documents"],
        }
        return list(by_domain.get(domain, [])) + list(by_task.get(task, []))

# 数据收集器 (data collector)
python
import requests
from pathlib import Path
class DataCollector:
    """Collect raw training data from HF datasets, HTTP APIs, local files, and web pages.

    Every collect_* method returns the records it gathered AND appends them
    to self.collected_data, the running accumulator across calls.
    """

    def __init__(self, config: Dict):
        self.config = config
        self.collected_data = []  # all records collected so far, in order

    def collect_from_huggingface(self, dataset_name: str) -> List[Dict]:
        """Load every split of a Hugging Face dataset into record dicts.

        Raises ImportError with install instructions when `datasets` is missing.
        """
        try:
            from datasets import load_dataset
        except ImportError:
            raise ImportError("Install datasets: pip install datasets")
        dataset = load_dataset(dataset_name)
        collected = [
            {"source": "huggingface", "dataset": dataset_name,
             "split": split, "data": item}
            for split in dataset.keys()
            for item in dataset[split]
        ]
        self.collected_data.extend(collected)
        return collected

    def collect_from_api(self, api_url: str,
                         params: Dict = None,
                         timeout: float = 30.0) -> List[Dict]:
        """Fetch a JSON list from an API endpoint and wrap each element.

        `timeout` (new, defaulted) prevents the request from hanging forever.
        Raises Exception on any non-200 status, as before.
        """
        response = requests.get(api_url, params=params, timeout=timeout)
        if response.status_code != 200:
            raise Exception(f"API request failed: {response.status_code}")
        collected = [{"source": "api", "url": api_url, "data": item}
                     for item in response.json()]
        self.collected_data.extend(collected)
        return collected

    def collect_from_files(self, directory: str,
                           file_patterns: List[str] = None) -> List[Dict]:
        """Recursively read text files under `directory` matching glob patterns.

        Defaults to *.txt / *.json / *.csv. Unreadable or non-UTF-8 files are
        logged and skipped rather than aborting the whole collection.
        """
        path = Path(directory)
        if not path.exists():
            raise FileNotFoundError(f"Directory not found: {directory}")
        if file_patterns is None:
            file_patterns = ["*.txt", "*.json", "*.csv"]
        collected = []
        for pattern in file_patterns:
            for file_path in path.rglob(pattern):
                try:
                    content = file_path.read_text(encoding='utf-8')
                except (OSError, UnicodeDecodeError) as e:
                    # Narrowed from a broad `except Exception`; best-effort skip.
                    print(f"Error reading {file_path}: {e}")
                    continue
                collected.append({
                    "source": "file",
                    "path": str(file_path),
                    "data": {"content": content},
                })
        self.collected_data.extend(collected)
        return collected

    def collect_from_web(self, urls: List[str],
                         timeout: float = 30.0) -> List[Dict]:
        """Download pages and extract their visible text with BeautifulSoup.

        Failures (network errors, non-200 responses) are logged/skipped so one
        bad URL does not abort the batch. `timeout` is new and defaulted.
        """
        collected = []
        for url in urls:
            try:
                response = requests.get(url, timeout=timeout)
                if response.status_code == 200:
                    from bs4 import BeautifulSoup
                    soup = BeautifulSoup(response.text, 'html.parser')
                    collected.append({
                        "source": "web",
                        "url": url,
                        "data": {"content": soup.get_text()},
                    })
            except Exception as e:
                print(f"Error fetching {url}: {e}")
        self.collected_data.extend(collected)
        return collected

# 数据合成 (data synthesis)
python
class DataSynthesizer:
    """Generate synthetic training data (QA pairs, instructions, dialogues) via an LLM.

    `llm` must expose generate(prompt: str) -> str.
    """

    def __init__(self, llm):
        self.llm = llm

    def synthesize_qa_pairs(self, topic: str,
                            n_pairs: int = 10) -> List[Dict]:
        """Ask the LLM for Q/A pairs about `topic` and parse its reply."""
        prompt = f"""
Generate {n_pairs} question-answer pairs about {topic}.
Format each pair as:
Q: [question]
A: [answer]
Make questions diverse and challenging.
"""
        return self._parse_qa_pairs(self.llm.generate(prompt))

    def synthesize_instruction_data(self, task: str,
                                    n_examples: int = 10) -> List[Dict]:
        """Ask the LLM for instruction/input/output examples for `task`."""
        prompt = f"""
Generate {n_examples} examples for the task: {task}
Format each example as:
Instruction: [instruction]
Input: [input]
Output: [output]
Make examples diverse and realistic.
"""
        return self._parse_instruction_data(self.llm.generate(prompt))

    def synthesize_conversations(self, scenario: str,
                                 n_conversations: int = 5) -> List[Dict]:
        """Ask the LLM for user/assistant dialogues for `scenario`."""
        prompt = f"""
Generate {n_conversations} conversations for the scenario: {scenario}
Format each conversation as:
User: [user message]
Assistant: [assistant response]
Make conversations natural and helpful.
"""
        return self._parse_conversations(self.llm.generate(prompt))

    def _parse_qa_pairs(self, response: str) -> List[Dict]:
        """Parse 'Q:'/'A:' lines; a pair is emitted when its answer arrives."""
        pairs = []
        current_pair = {}
        for line in response.split('\n'):
            line = line.strip()
            if line.startswith('Q:'):
                current_pair = {"question": line[len('Q:'):].strip()}
            elif line.startswith('A:'):
                current_pair["answer"] = line[len('A:'):].strip()
                pairs.append(current_pair)
        return pairs

    def _parse_instruction_data(self, response: str) -> List[Dict]:
        """Parse Instruction/Input/Output triples; emitted on the Output line."""
        examples = []
        current_example = {}
        for line in response.split('\n'):
            line = line.strip()
            if line.startswith('Instruction:'):
                current_example = {"instruction": line[len('Instruction:'):].strip()}
            elif line.startswith('Input:'):
                current_example["input"] = line[len('Input:'):].strip()
            elif line.startswith('Output:'):
                current_example["output"] = line[len('Output:'):].strip()
                examples.append(current_example)
        return examples

    def _parse_conversations(self, response: str) -> List[Dict]:
        """Parse User/Assistant turns into a single conversation.

        Bug fix: the original sliced 'Assistant:' lines with line[11:], but the
        prefix is only 10 characters, so the first content character was lost
        whenever no space followed the colon. All offsets now use len(prefix).
        Note: all turns go into one conversation; no boundary detection.
        """
        conversations = []
        current_conversation = []
        for line in response.split('\n'):
            line = line.strip()
            if line.startswith('User:'):
                current_conversation.append({
                    "role": "user",
                    "content": line[len('User:'):].strip(),
                })
            elif line.startswith('Assistant:'):
                current_conversation.append({
                    "role": "assistant",
                    "content": line[len('Assistant:'):].strip(),
                })
        if current_conversation:
            conversations.append({"messages": current_conversation})
        return conversations

# 数据清洗 (data cleaning)
文本清洗
python
import re
class TextCleaner:
    """Normalize raw text by stripping markup, URLs, contact info, and noise."""

    def __init__(self, config: Dict = None):
        self.config = config or {}

    def clean(self, text: str) -> str:
        """Run all cleaning passes over `text` in a fixed order."""
        passes = (
            self._remove_html_tags,
            self._remove_urls,
            self._remove_emails,
            self._remove_phone_numbers,
            self._normalize_whitespace,
            self._remove_special_chars,
        )
        for apply_pass in passes:
            text = apply_pass(text)
        return text

    def _remove_html_tags(self, text: str) -> str:
        # Naive stripper: drops anything between '<' and the next '>'.
        return re.sub('<.*?>', '', text)

    def _remove_urls(self, text: str) -> str:
        return re.sub(r'http\S+|www\.\S+', '', text)

    def _remove_emails(self, text: str) -> str:
        # Crude: any token containing '@' is treated as an email address.
        return re.sub(r'\S+@\S+', '', text)

    def _remove_phone_numbers(self, text: str) -> str:
        # Matches common 10-digit formats like 555-123-4567 / 555.123.4567.
        return re.sub(r'\d{3}[-.\s]?\d{3}[-.\s]?\d{4}', '', text)

    def _normalize_whitespace(self, text: str) -> str:
        # Collapse all whitespace runs to single spaces and trim the ends.
        return re.sub(r'\s+', ' ', text).strip()

    def _remove_special_chars(self, text: str) -> str:
        # Only characters explicitly listed in the config are removed.
        for char in self.config.get("special_chars_to_remove", []):
            text = text.replace(char, '')
        return text

# 数据去重 (data deduplication)
python
class DataDeduplicator:
    """Remove duplicate records, either exactly (hash) or fuzzily (Jaccard)."""

    def deduplicate_by_hash(self, data: List[Dict]) -> List[Dict]:
        """Keep the first occurrence of each exact textual representation."""
        seen = set()
        unique = []
        for record in data:
            digest = hash(str(record.get("data", record)))
            if digest in seen:
                continue
            seen.add(digest)
            unique.append(record)
        return unique

    def deduplicate_by_similarity(self, data: List[Dict],
                                  threshold: float = 0.95) -> List[Dict]:
        """Keep a record only when no already-kept record exceeds `threshold`.

        NOTE: O(n^2) pairwise comparison — fine for small datasets only.
        """
        unique = []
        for record in data:
            text = str(record.get("data", record))
            near_duplicate = any(
                self._calculate_similarity(text, str(kept.get("data", kept))) > threshold
                for kept in unique
            )
            if not near_duplicate:
                unique.append(record)
        return unique

    def _calculate_similarity(self, text1: str,
                              text2: str) -> float:
        """Jaccard similarity over lowercase word sets; 0.0 when both empty."""
        words_a = set(text1.lower().split())
        words_b = set(text2.lower().split())
        union = words_a | words_b
        return len(words_a & words_b) / len(union) if union else 0.0

# 数据过滤 (data filtering)
python
class DataFilter:
    """Filter collected records by length, language, and heuristic quality."""

    def __init__(self, config: Dict = None):
        self.config = config or {}

    def filter_by_length(self, data: List[Dict],
                         min_length: int = 10,
                         max_length: int = 10000) -> List[Dict]:
        """Keep records whose whitespace word count is in [min_length, max_length]."""
        filtered = []
        for item in data:
            word_count = len(str(item.get("data", item)).split())
            if min_length <= word_count <= max_length:
                filtered.append(item)
        return filtered

    def filter_by_language(self, data: List[Dict],
                           language: str = "en") -> List[Dict]:
        """Keep records langdetect classifies as `language`.

        Undetectable records are silently dropped. Raises ImportError with
        install instructions if langdetect is not available.
        """
        try:
            from langdetect import detect
        except ImportError:
            raise ImportError("Install langdetect: pip install langdetect")
        filtered = []
        for item in data:
            text = str(item.get("data", item))
            try:
                if detect(text) == language:
                    filtered.append(item)
            except Exception:
                # Was a bare `except:`, which also swallowed KeyboardInterrupt
                # and SystemExit; narrowed to Exception.
                pass
        return filtered

    def filter_by_quality(self, data: List[Dict],
                          min_quality_score: float = 0.5) -> List[Dict]:
        """Keep records whose heuristic quality score >= min_quality_score."""
        return [item for item in data
                if self._assess_quality(item) >= min_quality_score]

    def _assess_quality(self, item: Dict) -> float:
        """Average the three heuristic sub-scores for one record."""
        text = str(item.get("data", item))
        scores = [
            self._check_length_quality(text),
            self._check_structure_quality(text),
            self._check_content_quality(text),
        ]
        return sum(scores) / len(scores)

    def _check_length_quality(self, text: str) -> float:
        """Score by word count: 50-500 ideal, 10-1000 acceptable, else poor."""
        length = len(text.split())
        if 50 <= length <= 500:
            return 1.0
        if 10 <= length < 50 or 500 < length <= 1000:
            return 0.7
        return 0.3

    def _check_structure_quality(self, text: str) -> float:
        """Base 0.5, +0.3 for sentence punctuation, +0.2 for paragraph breaks."""
        if not text:
            return 0.0
        score = 0.5
        if any(char in text for char in '.!?'):
            score += 0.3
        if '\n\n' in text:
            score += 0.2
        return score

    def _check_content_quality(self, text: str) -> float:
        """Score by average word length; 3-6 chars is typical natural text."""
        words = text.split()
        if not words:
            return 0.0
        avg_word_length = sum(len(word) for word in words) / len(words)
        if 3 <= avg_word_length <= 6:
            return 1.0
        if 2 <= avg_word_length < 3 or 6 < avg_word_length <= 8:
            return 0.7
        return 0.3

# 数据格式化 (data formatting)
指令格式化
python
class InstructionFormatter:
    """Render instruction/input/output records through a text template."""

    def __init__(self, template: str = None):
        # Falsy (None/empty) template falls back to the built-in default.
        self.template = template or self._default_template()

    def _default_template(self) -> str:
        return """Instruction: {instruction}
Input: {input}
Output: {output}"""

    def format(self, data: Dict) -> str:
        """Fill the template; missing keys default to empty strings."""
        return self.template.format(
            instruction=data.get("instruction", ""),
            input=data.get("input", ""),
            output=data.get("output", ""),
        )

    def format_batch(self, data_list: List[Dict]) -> List[str]:
        """Format each record in order."""
        rendered = []
        for record in data_list:
            rendered.append(self.format(record))
        return rendered

# 对话格式化 (conversation formatting)
python
class ConversationFormatter:
    """Render a {'messages': [...]} conversation as 'role: content' lines."""

    def __init__(self, template: str = None):
        # Falsy (None/empty) template falls back to the built-in default.
        self.template = template or self._default_template()

    def _default_template(self) -> str:
        return """{messages}"""

    def format(self, conversation: Dict) -> str:
        """Join each turn as 'role: content'; role defaults to 'user'."""
        turns = [
            f"{message.get('role', 'user')}: {message.get('content', '')}"
            for message in conversation.get("messages", [])
        ]
        return self.template.format(messages="\n".join(turns))

    def format_batch(self, conversations: List[Dict]) -> List[str]:
        """Format each conversation in order."""
        return [self.format(conversation) for conversation in conversations]

# QA格式化 (QA formatting)
python
class QAFormatter:
    """Render question/answer pairs through a text template."""

    def __init__(self, template: str = None):
        # Falsy (None/empty) template falls back to the built-in default.
        self.template = template or self._default_template()

    def _default_template(self) -> str:
        return """Question: {question}
Answer: {answer}"""

    def format(self, qa_pair: Dict) -> str:
        """Fill the template; missing keys default to empty strings."""
        return self.template.format(
            question=qa_pair.get("question", ""),
            answer=qa_pair.get("answer", ""),
        )

    def format_batch(self, qa_pairs: List[Dict]) -> List[str]:
        """Format each pair in order."""
        rendered = []
        for pair in qa_pairs:
            rendered.append(self.format(pair))
        return rendered

# 数据增强 (data augmentation)
同义词替换
python
import random
class SynonymAugmenter:
    """Produce text variants by swapping known words for random synonyms."""

    def __init__(self):
        self.synonym_dict = self._build_synonym_dict()

    def _build_synonym_dict(self) -> Dict:
        # Tiny demo lexicon; lookups are case-insensitive on the input word.
        return {
            "good": ["excellent", "great", "wonderful", "amazing"],
            "bad": ["terrible", "awful", "poor", "horrible"],
            "big": ["large", "huge", "enormous", "massive"],
            "small": ["tiny", "little", "mini", "compact"],
            "fast": ["quick", "rapid", "swift", "speedy"],
            "slow": ["sluggish", "leisurely", "unhurried", "gradual"],
        }

    def augment(self, text: str,
                n_augmentations: int = 1) -> List[str]:
        """Return [original] + n_augmentations variants of `text`.

        Each variant is derived from the original independently; words not
        in the lexicon pass through unchanged.
        """
        variants = [text]
        for _ in range(n_augmentations):
            swapped = [
                random.choice(self.synonym_dict[word.lower()])
                if word.lower() in self.synonym_dict
                else word
                for word in text.split()
            ]
            variants.append(" ".join(swapped))
        return variants

# 回译增强 (back-translation augmentation)
python
class BackTranslationAugmenter:
    """Augment text by translating to an intermediate language and back."""

    def __init__(self, translator):
        # `translator` must expose translate(text, target_lang=...) -> str.
        self.translator = translator

    def augment(self, text: str,
                intermediate_lang: str = "fr") -> List[str]:
        """Return [original, round-tripped] versions of `text`."""
        intermediate = self.translator.translate(text, target_lang=intermediate_lang)
        round_trip = self.translator.translate(intermediate, target_lang="en")
        return [text, round_trip]

# 随机插入 (random insertion)
python
class RandomInsertionAugmenter:
    """Augment text by inserting filler words at random positions."""

    def __init__(self, insertion_words: List[str] = None):
        # Falsy (None/empty) list falls back to the built-in fillers.
        self.insertion_words = insertion_words or [
            "actually", "basically", "essentially", "really",
            "very", "quite", "rather", "somewhat",
        ]

    def augment(self, text: str,
                n_augmentations: int = 1) -> List[str]:
        """Return [original] + variants, each with 1-3 random insertions."""
        variants = [text]
        base_words = text.split()
        for _ in range(n_augmentations):
            words = list(base_words)
            for _ in range(random.randint(1, 3)):
                position = random.randint(0, len(words))
                words.insert(position, random.choice(self.insertion_words))
            variants.append(" ".join(words))
        return variants

# 数据质量评估 (data quality assessment)
质量指标
python
class DataQualityAssessor:
    """Compute aggregate quality metrics over a collected dataset."""

    def assess(self, data: List[Dict]) -> Dict:
        """Return length, language, quality, and diversity metrics for `data`.

        Now safe on an empty dataset (previously raised ZeroDivisionError).
        """
        return {
            "length_distribution": self._assess_length_distribution(data),
            "language_distribution": self._assess_language_distribution(data),
            "quality_scores": self._assess_quality_scores(data),
            "diversity": self._assess_diversity(data),
        }

    def _assess_length_distribution(self, data: List[Dict]) -> Dict:
        """Word-count statistics; all zeros for an empty dataset."""
        lengths = [len(str(item.get("data", item)).split()) for item in data]
        if not lengths:
            return {"mean": 0.0, "median": 0, "min": 0, "max": 0, "std": 0.0}
        return {
            "mean": sum(lengths) / len(lengths),
            # Upper-middle element; exact median only for odd-sized samples.
            "median": sorted(lengths)[len(lengths) // 2],
            "min": min(lengths),
            "max": max(lengths),
            "std": self._std_dev(lengths),
        }

    def _assess_language_distribution(self, data: List[Dict]) -> Dict:
        """Histogram of detected languages; requires the langdetect package."""
        try:
            from langdetect import detect
        except ImportError:
            return {"error": "langdetect not installed"}
        languages = {}
        for item in data:
            text = str(item.get("data", item))
            try:
                lang = detect(text)
            except Exception:
                # Was a bare `except:`; undetectable samples are skipped.
                continue
            languages[lang] = languages.get(lang, 0) + 1
        return languages

    def _assess_quality_scores(self, data: List[Dict]) -> Dict:
        """Summary of per-record heuristic scores; zeros for empty input."""
        scores = [self._assess_single_quality(item) for item in data]
        if not scores:
            return {"mean": 0.0, "min": 0.0, "max": 0.0}
        return {
            "mean": sum(scores) / len(scores),
            "min": min(scores),
            "max": max(scores),
        }

    def _assess_single_quality(self, item: Dict) -> float:
        """Score one record in [0.0, 1.0] from four simple signals."""
        text = str(item.get("data", item))
        score = 0.0
        if len(text.split()) >= 10:              # enough words
            score += 0.3
        if any(char in text for char in '.!?'):  # sentence punctuation
            score += 0.3
        if len(text.split('\n')) > 1:            # multi-line structure
            score += 0.2
        if len(text) > 100:                      # non-trivial length
            score += 0.2
        return score

    def _assess_diversity(self, data: List[Dict]) -> Dict:
        """Vocabulary size across the dataset; safe on empty input."""
        vocabulary = set()
        for item in data:
            vocabulary.update(str(item.get("data", item)).lower().split())
        n_samples = len(data)
        return {
            "unique_words": len(vocabulary),
            "total_samples": n_samples,
            "avg_unique_per_sample": len(vocabulary) / n_samples if n_samples else 0.0,
        }

    def _std_dev(self, values: List[float]) -> float:
        """Population standard deviation; 0.0 for fewer than two values."""
        if len(values) < 2:
            return 0.0
        mean = sum(values) / len(values)
        variance = sum((x - mean) ** 2 for x in values) / len(values)
        return variance ** 0.5

# 实践练习 (hands-on exercise)
练习1:实现完整的数据准备流程
python
class DataPreparationPipeline:
    """End-to-end pipeline: collect -> clean -> deduplicate -> filter -> format."""

    def __init__(self, config):
        self.config = config
        self.collector = DataCollector(config)
        self.cleaner = TextCleaner(config)
        self.deduplicator = DataDeduplicator()
        self.filter = DataFilter(config)
        self.formatter = InstructionFormatter()

    def prepare(self, sources: List[str]) -> List[str]:
        """Run every stage over `sources` and return formatted strings.

        Sources beginning with 'http' are fetched via the API collector;
        everything else is treated as a local directory.
        """
        collected = []
        for source in sources:
            if source.startswith("http"):
                collected.extend(self.collector.collect_from_api(source))
            else:
                collected.extend(self.collector.collect_from_files(source))

        cleaned = []
        for record in collected:
            raw_text = str(record.get("data", record))
            # Records are mutated in place, mirroring the collector's output.
            record["data"] = {"content": self.cleaner.clean(raw_text)}
            cleaned.append(record)

        deduplicated = self.deduplicator.deduplicate_by_hash(cleaned)
        filtered = self.filter.filter_by_length(deduplicated)
        return self.formatter.format_batch(filtered)

# 总结 (summary)
本节我们学习了微调数据准备:
- 数据收集方法(数据源识别、收集器、合成)
- 数据清洗技术(文本清洗、去重、过滤)
- 数据格式化(指令、对话、QA)
- 数据增强方法(同义词、回译、随机插入)
- 数据质量评估(长度、语言、质量、多样性)
高质量的数据是微调成功的关键。
