网站首页 > 知识剖析 正文
一、项目背景
在医疗领域,利用人工智能辅助诊断和分析是当下的热门趋势。准确的医疗图像分析和医学文本理解对于疾病的诊断、治疗方案的制定至关重要。本项目旨在借助 Gemma3 多模态大模型,通过对医疗数据集的微调,构建一个能精准处理医疗图像与文本信息的智能模型,为医疗工作者提供更有效的辅助。
二、核心技术
(一)QLoRA 技术
QLoRA 是一种参数高效的微调技术,在医疗模型微调中发挥着重要作用。它可以对大模型进行 4 位量化训练,大大减少了显存的使用,同时通过训练低秩自适应(LoRA)适配器来调整模型参数,在不显著增加计算资源的情况下,让模型更好地适应医疗领域的特定任务。
(二)相关库的使用
结合 Hugging Face Transformers 和 TRL 库,能够方便地加载预训练模型、处理数据集以及进行模型的微调。这些库提供了丰富的工具和接口,使得整个微调过程更加高效和便捷。
三、数据准备
(一)数据集选择
我们使用放射科专用数据集,该数据集包含了大量的医学图像以及对应的专业诊断描述。医学图像是医生诊断疾病的重要依据,不同的疾病在图像上会呈现出不同的特征,例如肺部疾病可能会有阴影,冠状动脉造影显示右冠状动脉降支夹层等表现。而对应的诊断描述则是医生根据图像信息给出的专业判断,包含了疾病的名称、可能的病因、严重程度等关键信息。
(二)数据预处理
1.图像标准化:将医学图像(如 X 光片)转换成 RGB 格式。在医疗图像分析中,统一的图像格式有助于模型更稳定地学习图像特征,避免因格式差异导致的学习困难。
2.构建对话模板:构建符合模型输入要求的对话模板。例如:
3.数据格式转换:把数据变成模型能看懂的对话格式,使模型能够更好地处理多模态信息。
四、使用 TRL 和 SFTTrainer 微调 gemma3
现在,您可以对模型进行微调了。借助 Hugging Face TRL SFTTrainer,您可以轻松监督微调开放式 LLM。SFTTrainer 是 transformers 库中的 Trainer 的子类,支持所有相同的功能(包括日志记录、评估和检查点),但还添加了其他实用功能,包括:
1.数据集格式设置,包括对话格式和指令格式
2.仅根据完成情况进行训练,忽略提示
3.打包数据集以提高训练效率
4.支持参数高效微调 (PEFT),包括 QloRA
5.准备模型和分词器以进行对话式微调(例如添加特殊标记)
以下代码会从 Hugging Face 加载 Gemma 模型和分词器,并初始化量化配置。
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
# Hugging Face model id
model_id = "google/gemma-3-4b-pt" # or `google/gemma-3-12b-pt`, `google/gemma-3-27-pt`
# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] < 8:
raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")
# Define model init arguments
model_kwargs = dict(
attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
torch_dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
device_map="auto", # Let torch decide how to load the model
)
# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)
# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-27b-it")
在展开训练伊始,您务必要明确需于 SFTConfig 中运用的超参数,还有用于处置视觉处理的自定义 collate_fn。collate_fn 能够把涵盖文本和图片的消息转化成模型能够领会的格式。
from trl import SFTConfig
args = SFTConfig(
output_dir="gemma-product-description", # 保存模型的目录以及仓库ID
num_train_epochs=1, # 训练的轮数
per_device_train_batch_size=1, # 训练期间每个设备的批次大小
gradient_accumulation_steps=4, # 执行反向传播/更新操作前的步数
gradient_checkpointing=True, # 使用梯度检查点以节省内存
optim="adamw_torch_fused", # 使用融合的adamw优化器
logging_steps=5, # 每5步记录一次日志
save_strategy="epoch", # 每个训练轮次保存一次检查点
learning_rate=2e-4, # 学习率,基于QLoRA论文设置
bf16=True, # 使用bfloat16精度
max_grad_norm=0.3, # 最大梯度范数,基于QLoRA论文设置
warmup_ratio=0.03, # 热身比率,基于QLoRA论文设置
lr_scheduler_type="constant", # 使用恒定学习率调度器
push_to_hub=True, # 将模型推送到Hugging Face Hub
report_to="tensorboard", # 将指标报告给TensorBoard
gradient_checkpointing_kwargs={
"use_reentrant": False
}, # 使用非重入式检查点
dataset_text_field="", # 数据整理器需要一个虚拟字段
dataset_kwargs={"skip_prepare_dataset": True}, # 对数据整理器很重要
)
args.remove_unused_columns = False # 对数据整理器很重要
# 创建一个数据整理器来编码文本和图像对
def collate_fn(examples):
texts = []
images = []
for example in examples:
image_inputs = process_vision_info(example["messages"])
text = processor.apply_chat_template(
example["messages"], add_generation_prompt=False, tokenize=False
)
texts.append(text.strip())
images.append(image_inputs)
# 对文本进行分词并处理图像
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# 标签就是输入ID,并且在损失计算中屏蔽填充标记和图像标记
labels = batch["input_ids"].clone()
# 屏蔽图像标记
image_token_id = [
processor.tokenizer.convert_tokens_to_ids(
processor.tokenizer.special_tokens_map["boi_token"]
)
]
# 屏蔽在损失计算中不使用的标记
labels[labels == processor.tokenizer.pad_token_id] = -100
labels[labels == image_token_id] = -100
labels[labels == 262144] = -100
batch["labels"] = labels
return batch
五、训练模型
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
args=args,
train_dataset=dataset,
peft_config=peft_config,
processing_class=processor,
data_collator=collate_fn,
)
trainer.train()
trainer.save_model()
#训练之后释放模型内存
del model
del trainer
torch.cuda.empty_cache()
六、测试模型
训练完成后,您需要评估和测试模型。您可以从测试数据集中加载不同的样本,并针对这些样本评估模型。
def generate_description(sample, model, processor):
messages = [
{ "role": "user",
"content" : [
{"type" : "text", "text" : system_message},
{"type" : "image", "image" : image} ]
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(model.device)
stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8,
eos_token_id=stop_token_ids, disable_compile=True)
generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0]
description = generate_description(sample, model, processor)
print(description)
七、注意事项:AI 医生的 "实习禁忌"
1.显存监控:看片太投入容易 "烧显卡",建议搭配显存监控工具使用。
2.数据质量:模糊 X 光片会让模型变成 "睁眼瞎",数据清洗很重要。
3.保持谦逊:诊断结果仅供参考,人类医生才是最终决策者。
4.显卡要求:A40,显存48G,数据集使用800张影片,一个小时左右微调完成。
八、未来展望:让 AI 成为 "放射科瑞士军刀"
1.三维影像分析:让模型读懂 CT/MRI,解锁 3D 诊断技能。
2.多模态融合:结合病历数据,实现 "图像 + 文本" 联合诊断。
最后附完整代码:
from huggingface_hub import login
login("xxxxxxxxxxxxxxxxxxxxx")
from datasets import load_dataset
from PIL import Image
system_message = "You are an expert radiographer. Describe accurately what you see in this image."
def format_data(sample):
conversation = [
{ "role": "user",
"content" : [
{"type" : "text", "text" : system_message},
{"type" : "image", "image" : sample["image"]} ]
},
{ "role" : "assistant",
"content" : [
{"type" : "text", "text" : sample["caption"]} ]
},
]
return { "messages" : conversation }
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
image_inputs = []
# Iterate through each conversation
for msg in messages:
content = msg.get("content", [])
if not isinstance(content, list):
content = [content]
for element in content:
if isinstance(element, dict) and (
"image" in element or element.get("type") == "image"
):
# Get the image and convert to RGB
if "image" in element:
image = element["image"]
else:
image = element
image_inputs.append(image.convert("RGB"))
return image_inputs
dataset = load_dataset("unsloth/Radiology_mini", split="train")
image_id = dataset[0]['image']
result = []
i = 1
for sample in dataset:
if i>800:
break
result.append(format_data(sample))
i = i+1
dataset = result
print(dataset[345]["messages"])
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
model_id = "/root/google/gemma-3-27-pt" # or `google/gemma-3-12b-pt`, `google/gemma-3-27-pt`
if torch.cuda.get_device_capability()[0] < 8:
raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")
model_kwargs = dict(
attn_implementation="eager",
torch_dtype=torch.bfloat16,
device_map="auto",
)
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-27b-it")
from peft import LoraConfig
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.05,
r=16,
bias="none",
target_modules="all-linear",
task_type="CAUSAL_LM",
modules_to_save=[
"lm_head",
"embed_tokens",
],
)
from trl import SFTConfig
args = SFTConfig(
output_dir="gemma-product-description", # 保存模型的目录以及仓库ID
num_train_epochs=1, # 训练的轮数
per_device_train_batch_size=1, # 训练期间每个设备的批次大小
gradient_accumulation_steps=4, # 执行反向传播/更新操作前的步数
gradient_checkpointing=True, # 使用梯度检查点以节省内存
optim="adamw_torch_fused", # 使用融合的adamw优化器
logging_steps=5, # 每5步记录一次日志
save_strategy="epoch", # 每个训练轮次保存一次检查点
learning_rate=2e-4, # 学习率,基于QLoRA论文设置
bf16=True, # 使用bfloat16精度
max_grad_norm=0.3, # 最大梯度范数,基于QLoRA论文设置
warmup_ratio=0.03, # 热身比率,基于QLoRA论文设置
lr_scheduler_type="constant", # 使用恒定学习率调度器
push_to_hub=True, # 将模型推送到Hugging Face Hub
report_to="tensorboard", # 将指标报告给TensorBoard
gradient_checkpointing_kwargs={
"use_reentrant": False
}, # 使用非重入式检查点
dataset_text_field="", # 数据整理器需要一个虚拟字段
dataset_kwargs={"skip_prepare_dataset": True}, # 对数据整理器很重要
)
args.remove_unused_columns = False # 对数据整理器很重要
# 创建一个数据整理器来编码文本和图像对
def collate_fn(examples):
texts = []
images = []
for example in examples:
image_inputs = process_vision_info(example["messages"])
text = processor.apply_chat_template(
example["messages"], add_generation_prompt=False, tokenize=False
)
texts.append(text.strip())
images.append(image_inputs)
# 对文本进行分词并处理图像
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# 标签就是输入ID,并且在损失计算中屏蔽填充标记和图像标记
labels = batch["input_ids"].clone()
# 屏蔽图像标记
image_token_id = [
processor.tokenizer.convert_tokens_to_ids(
processor.tokenizer.special_tokens_map["boi_token"]
)
]
# 屏蔽在损失计算中不使用的标记
labels[labels == processor.tokenizer.pad_token_id] = -100
labels[labels == image_token_id] = -100
labels[labels == 262144] = -100
batch["labels"] = labels
return batch
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
args=args,
train_dataset=dataset,
peft_config=peft_config,
processing_class=processor,
data_collator=collate_fn,
)
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()
# Save the final model again to the Hugging Face Hub
trainer.save_model()
# free the memory again
del model
del trainer
torch.cuda.empty_cache()
from peft import PeftModel
# Load Model base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)
# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")
processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")
import torch
model = AutoModelForImageTextToText.from_pretrained(
args.output_dir,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained(args.output_dir)
import requests
from PIL import Image
image = image_id
instruction = "You are an expert radiographer. Describe accurately what you see in this image."
def generate_description(sample, model, processor):
messages = [
{ "role": "user",
"content" : [
{"type" : "text", "text" : system_message},
{"type" : "image", "image" : image} ]
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(model.device)
stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8,
eos_token_id=stop_token_ids, disable_compile=True)
generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0]
description = generate_description(sample, model, processor)
print(description)
猜你喜欢
- 2025-05-23 HTML5语法中需要掌握的要点
- 2025-05-23 文档在线预览(三)使用前端实现word、excel、pdf、ppt 在线预览
- 2025-05-23 c# 10 教程:24 本机和 COM 互操作性
- 2025-05-23 《每日电讯报》研发数字工具,教你更有效率地报道新闻
- 2025-05-23 SpringBoot五步构建RAG服务:2025最新AI+向量数据库实战
- 2025-05-23 vLLM的参数列表及其中文说明
- 2025-05-23 15. LangChain多模态应用开发:融合文本、图像与语音
- 2025-05-23 用离散标记重塑人体姿态:VQ-VAE实现关键点组合关系编码
- 2025-05-23 144项大神级ppt制作技术
- 2025-05-23 前端分享-少年了解过iframe么
- 最近发表
- 标签列表
-
- xml (46)
- css animation (57)
- array_slice (60)
- htmlspecialchars (54)
- position: absolute (54)
- datediff函数 (47)
- array_pop (49)
- jsmap (52)
- toggleclass (43)
- console.time (63)
- .sql (41)
- ahref (40)
- js json.parse (59)
- html复选框 (60)
- css 透明 (44)
- css 颜色 (47)
- php replace (41)
- css nth-child (48)
- min-height (40)
- xml schema (44)
- css 最后一个元素 (46)
- location.origin (44)
- table border (49)
- html tr (40)
- video controls (49)