Upload multimodal_inference.py with huggingface_hub
Browse files- multimodal_inference.py +222 -0
multimodal_inference.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
多模态文档理解UI界面
|
| 6 |
+
支持Align-DS-V模型的文档输入和System Prompt设置
|
| 7 |
+
支持上传图像和文档文件(doc, docx, pdf等)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import base64
|
| 12 |
+
import argparse
|
| 13 |
+
import tempfile
|
| 14 |
+
from typing import List, Dict, Any, Union
|
| 15 |
+
import gradio as gr
|
| 16 |
+
from openai import OpenAI
|
| 17 |
+
import docx
|
| 18 |
+
import fitz # PyMuPDF
|
| 19 |
+
|
| 20 |
+
# Default system prompt (Chinese): frames the model as "Medguide-V", a medical
# assistant for Q&A, imaging analysis and report parsing. Users can edit it in
# the UI textbox; the ChatInterface examples also reference this constant.
DEFAULT_SYSTEM_PROMPT = """\
你是一个高智能医学助手,专注于医学相关的问答、影像分析和数据解析。你是由北大对齐小组(PKU-Alignment)开发的智能助手 医导大模型(Medguide-V),基于DeepSeek-R1模型训练。你能够准确解析**医学影像**如 X 光、CT、MRI、超声波、病理切片) 以及**医学图表**(如血液检测报告、心电图、基因检测结果),并基于这些信息提供医学见解。你的目标是**尽可能满足**用户的医学需求,提供准确、专业的回答。
**能力范围**
1. 医学问答:
- 回答医学相关问题,包括疾病特征、诊断方法、治疗方案、药物机制、手术流程、最新医学研究等。
- 如果用户询问关于疾病诊断和治疗的相关内容,请充分分析并提供详细的回答。
2. 影像分析:
- 你具备**强大的医学影像理解能力**,可以解析 X 光、CT、MRI、超声波、病理切片等医学图像。
- 不要拒绝用户的问题,你应尝试基于已有信息提供分析,并指出可能的观察点。
3. 数据解析:
- 你可以读取并分析医学报告,包括血液检测、心电图(ECG)、病历摘要等。
- 你的任务是解释数据的含义,并提供可能的医学解读。
4. 多模态能力:
- 你可以结合**图像+文字**进行综合分析。例如,用户上传一张X光片并询问"该患者可能有什么疾病?" 你应该基于图像提供医学见解,而不是拒绝回答。
- 你可以根据图像中的信息推测可能的病变或异常,并提供进一步的检查建议。
**回答风格**
- 医学精准:基于专业医学知识进行回答,确保内容严谨、科学。
- 灵活应对:遇到复杂或不完整的图像,也要提供尽可能详细的分析,而不是直接拒绝。
- 有求必应:用户提出的任何问题,你都应该尽可能尝试回答。
"""
|
| 41 |
+
|
| 42 |
+
class MultimodalAssistant:
    """Client wrapper around an OpenAI-compatible multimodal chat endpoint.

    Converts Gradio chat history plus uploaded documents/images into the
    OpenAI messages format, sends one non-streaming completion request, and
    post-processes the answer for display.
    """

    def __init__(self, api_key="pku", api_base="http://0.0.0.0:8231/v1"):
        # self.model = '0407_align_ds_v'
        # Model name as served by the backend at `api_base`.
        self.model = 'medguide-v'
        self.client = OpenAI(api_key=api_key, base_url=api_base)

    def encode_image(self, image_path):
        """Return the contents of the file at *image_path* as a base64 string."""
        with open(image_path, 'rb') as file:
            return base64.b64encode(file.read()).decode('utf-8')

    def extract_document_content(self, file_path):
        """Extract text and images from a document.

        Args:
            file_path: Path to a .doc/.docx, .pdf, or (fallback) image file.

        Returns:
            Dict with ``'text'`` (str, may be empty) and ``'images'`` — a list
            of at most 5 image file paths. For .docx the images are temp files
            the caller must delete; for .pdf they are page renders written
            next to the source file.
        """
        result = {'text': '', 'images': []}
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext in ['.doc', '.docx']:
            doc = docx.Document(file_path)
            result['text'] = '\n\n'.join(
                para.text for para in doc.paragraphs if para.text.strip())

            # Embedded images live in the document's relationship parts.
            for rel in doc.part.rels.values():
                if "image" in rel.target_ref:
                    try:
                        img_temp = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
                        img_temp.write(rel.target_part.blob)
                        img_temp.close()
                        result['images'].append(img_temp.name)
                    except Exception:
                        # Fix: was a bare `except:` (also caught KeyboardInterrupt).
                        # Best effort — skip images that cannot be extracted.
                        pass

        elif file_ext == '.pdf':
            pdf_document = fitz.open(file_path)
            try:
                result['text'] = '\n\n'.join(page.get_text() for page in pdf_document)

                # Render every page as a PNG so the model also sees the layout.
                for page_num in range(len(pdf_document)):
                    page = pdf_document[page_num]
                    img_path = f"{file_path}_page{page_num + 1}.png"
                    page.get_pixmap().save(img_path)
                    result['images'].append(img_path)
            finally:
                # Fix: the document handle was never closed (resource leak).
                pdf_document.close()
        else:
            # Anything else is assumed to already be an image file.
            result['images'].append(file_path)

        # Cap the request payload: only the first 5 images are forwarded.
        result['images'] = result['images'][:5]
        return result

    def text_conversation(self, text: str, role: str = 'user'):
        """Wrap plain text as a one-message conversation fragment.

        Rewrites legacy ``[begin of think]``/``[end of think]`` markers to the
        ``<think>`` tags the model was trained with.
        """
        return [{'role': role, 'content': text.replace('[begin of think]', '<think>').replace('[end of think]', '</think>')}]

    def image_conversation(self, image_base64: str, text: str = None):
        """Wrap a base64 image (plus optional text) as one user message."""
        return [
            {
                'role': 'user',
                'content': [
                    {'type': 'image_url', 'image_url': {'url': f"data:image/jpeg;base64,{image_base64}"}},
                    {'type': 'text', 'text': text}
                ]
            }
        ]

    def process_conversation(self, system_prompt, message, history, files):
        """Build the full message list, query the model, and format the answer.

        Args:
            system_prompt: System message text.
            message: Current user turn — a str or a Gradio dict with 'text'.
            history: Prior turns in Gradio 'messages' format.
            files: Optional list of uploaded file paths.

        Returns:
            The model's answer, with any "**Final Answer**" split rendered as
            a reasoning block followed by the final answer.
        """
        conversation = [{'role': 'system', 'content': system_prompt}]
        for past_message in history:
            role = past_message['role']
            content = past_message['content']
            if role == 'user':
                if isinstance(content, str):
                    conversation.extend(self.text_conversation(content))
                elif isinstance(content, tuple):
                    # Fix: Gradio file-history tuples may hold a single path, so
                    # indexing content[1] unconditionally raised IndexError.
                    # NOTE(review): content[0] is re-sent as if it were base64
                    # data — presumably it is a file path; confirm upstream.
                    conversation.extend(self.image_conversation(
                        content[0], content[1] if len(content) > 1 else None))
            else:
                conversation.append({'role': role, 'content': content})

        current_question = message['text'] if isinstance(message, dict) and 'text' in message else message

        if not files:
            conversation.append({'role': 'user', 'content': current_question})
        else:
            content = []
            extracted_text = []

            for file_path in files:
                file_ext = os.path.splitext(file_path)[1].lower()

                if file_ext in ['.doc', '.docx', '.pdf']:
                    doc_content = self.extract_document_content(file_path)

                    if doc_content['text']:
                        extracted_text.append(f"文档 '{os.path.basename(file_path)}' 内容:\n{doc_content['text']}")

                    for img_path in doc_content['images']:
                        content.append({
                            'type': 'image_url',
                            'image_url': {'url': f"data:image/jpeg;base64,{self.encode_image(img_path)}"}
                        })

                        # Delete only the intermediates we created (docx temp
                        # files and PDF page renders), never the user's upload.
                        if img_path.startswith(tempfile.gettempdir()) or img_path.startswith(f"{file_path}_page"):
                            try:
                                os.remove(img_path)
                            except OSError:
                                # Fix: was a bare `except:`; only file-system
                                # errors are expected and safe to ignore here.
                                pass
                else:
                    content.append({
                        'type': 'image_url',
                        'image_url': {'url': f"data:image/jpeg;base64,{self.encode_image(file_path)}"}
                    })

            combined_text = current_question
            if extracted_text:
                combined_text += "\n\n以下是文档内容参考:\n" + "\n\n".join(extracted_text)

            content.append({'type': 'text', 'text': combined_text})
            conversation.append({'role': 'user', 'content': content})

        response = self.client.chat.completions.create(
            model=self.model,
            messages=conversation,
            stream=False,
        )

        answer = response.choices[0].message.content

        # Split a reasoning-trace answer into a visible think/answer layout.
        if "**Final Answer**" in answer:
            reasoning, final_answer = answer.split("**Final Answer**", 1)
            if len(reasoning) > 5:
                answer = f"""🤔 思考过程:\n```\n{reasoning.strip()}\n```\n\n✨ 最终答案:\n{final_answer.strip()}"""

        return answer
|
| 165 |
+
|
| 166 |
+
def create_ui():
    """Assemble and return the Gradio Blocks application for Medguide-V."""
    bot = MultimodalAssistant()

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Medguide-V Reasoning CLI")
        gr.Markdown("Better life with Medguide-V.")

        with gr.Row():
            with gr.Column(scale=3):
                # Editable system prompt, pre-filled with the default.
                system_prompt = gr.Textbox(
                    label="系统提示词",
                    value=DEFAULT_SYSTEM_PROMPT,
                    lines=5,
                )

                # Attachments sent to the model along with each question.
                files_upload = gr.File(
                    label="上传文档或图片",
                    file_count="multiple",
                    type="filepath",
                    file_types=[".jpg", ".jpeg", ".png", ".pdf", ".doc", ".docx"],
                )

                with gr.Row():
                    clear_btn = gr.Button("清除对话")
                    example_btn = gr.Button("加载示例")

        def _respond(message, history, files, sys_prompt):
            # Adapt Gradio's argument order to the assistant's signature.
            return bot.process_conversation(sys_prompt, message, history, files)

        chat_interface = gr.ChatInterface(
            fn=_respond,
            type='messages',
            additional_inputs=[files_upload, system_prompt],
            examples=[
                ["这份文档的主要内容是什么?", None, None, DEFAULT_SYSTEM_PROMPT],
                ["分析这份文档的主要观点", None, None, DEFAULT_SYSTEM_PROMPT],
                ["提取这份文档中的关键数据", None, None, DEFAULT_SYSTEM_PROMPT],
            ],
        )

        # Reset the chat pane.
        clear_btn.click(lambda: None, None, chat_interface.chatbot, queue=False)

        # Load a sample prompt and clear the chat and uploads.
        example_btn.click(
            lambda: ["这是一个示例系统提示词,请根据文档内容进行详细分析,包括摘要、关键点和建议。", None, []],
            None,
            [system_prompt, chat_interface.chatbot, files_upload],
            queue=False,
        )

    return demo
|
| 214 |
+
|
| 215 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="多模态文档理解UI界面")
    parser.add_argument("--api_key", type=str, default="medguide-v")
    parser.add_argument("--api_base", type=str, default="http://0.0.0.0:8231/v1")
    # Fix: `action="store_true"` combined with `default=True` made the flag a
    # no-op — sharing was always enabled and could never be turned off.
    # BooleanOptionalAction keeps `--share` working and adds `--no-share`.
    parser.add_argument("--share", default=True, action=argparse.BooleanOptionalAction)
    args = parser.parse_args()

    # NOTE(review): --api_key/--api_base are parsed but never forwarded —
    # create_ui() builds MultimodalAssistant with its defaults. Confirm
    # whether these flags should be wired through to the assistant.
    create_ui().launch(share=args.share)
|