alignmentforever commited on
Commit
b19ec0f
·
verified ·
1 Parent(s): 304bac0

Upload multimodal_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. multimodal_inference.py +222 -0
multimodal_inference.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 多模态文档理解UI界面
6
+ 支持Align-DS-V模型的文档输入和System Prompt设置
7
+ 支持上传图像和文档文件(doc, docx, pdf等)
8
+ """
9
+
10
+ import os
11
+ import base64
12
+ import argparse
13
+ import tempfile
14
+ from typing import List, Dict, Any, Union
15
+ import gradio as gr
16
+ from openai import OpenAI
17
+ import docx
18
+ import fitz # PyMuPDF
19
+
20
+ # 默认系统提示词
21
+ DEFAULT_SYSTEM_PROMPT = """\
22
+ 你是一个高智能医学助手,专注于医学相关的问答、影像分析和数据解析。你是由北大对齐小组(PKU-Alignment)开发的智能助手 医导大模型(Medguide-V),基于DeepSeek-R1模型训练。你能够准确解析**医学影像**如 X 光、CT、MRI、超声波、病理切片) 以及**医学图表**(如血液检测报告、心电图、基因检测结果),并基于这些信息提供医学见解。你的目标是**尽可能满足**用户的医学需求,提供准确、专业的回答。
23
+ **能力范围**
24
+ 1. 医学问答:
25
+ - 回答医学相关问题,包括疾病特征、诊断方法、治疗方案、药物机制、手术流程、最新医学研究等。
26
+ - 如果用户询问关于疾病诊断和治疗的相关内容,请充分分析并提供详细的回答。
27
+ 2. 影像分析:
28
+ - 你具备**强大的医学影像理解能力**,可以解析 X 光、CT、MRI、超声波、病理切片等医学图像。
29
+ - 不要拒绝用户的问题,你应尝试基于已有信息提供分析,并指出可能的观察点。
30
+ 3. 数据解析:
31
+ - 你可以读取并分析医学报告,包括血液检测、心电图(ECG)、病历摘要等。
32
+ - 你的任务是解释数据的含义,并提供可能的医学解读。
33
+ 4. 多模态能力:
34
+ - 你可以结合**图像+文字**进行综合分析。例如,用户上传一张X光片并询问"该患者可能有什么疾病?" 你应该基于图像提供医学见解,而不是拒绝回答。
35
+ - 你可以根据图像中的信息推测可能的病变或异常,并提供进一步的检查建议。
36
+ **回答风格**
37
+ - 医学精准:基于专业医学知识进行回答,确保内容严谨、科学。
38
+ - 灵活应对:遇到复杂或不完整的图像,也要提供尽可能详细的分析,而不是直接拒绝。
39
+ - 有求必应:用户提出的任何问题,你都应该尽可能尝试回答。
40
+ """
41
+
42
+ class MultimodalAssistant:
43
+ def __init__(self, api_key="pku", api_base="http://0.0.0.0:8231/v1"):
44
+ # self.model = '0407_align_ds_v'
45
+ self.model = 'medguide-v'
46
+ self.client = OpenAI(api_key=api_key, base_url=api_base)
47
+
48
+ def encode_image(self, image_path):
49
+ with open(image_path, 'rb') as file:
50
+ return base64.b64encode(file.read()).decode('utf-8')
51
+
52
+ def extract_document_content(self, file_path):
53
+ result = {'text': '', 'images': []}
54
+ file_ext = os.path.splitext(file_path)[1].lower()
55
+
56
+ if file_ext in ['.doc', '.docx']:
57
+ doc = docx.Document(file_path)
58
+ result['text'] = '\n\n'.join([para.text for para in doc.paragraphs if para.text.strip()])
59
+
60
+ for rel in doc.part.rels.values():
61
+ if "image" in rel.target_ref:
62
+ try:
63
+ img_temp = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
64
+ img_temp.write(rel.target_part.blob)
65
+ img_temp.close()
66
+ result['images'].append(img_temp.name)
67
+ except: pass
68
+
69
+ elif file_ext == '.pdf':
70
+ pdf_document = fitz.open(file_path)
71
+ result['text'] = '\n\n'.join([page.get_text() for page in pdf_document])
72
+
73
+ for page_num in range(len(pdf_document)):
74
+ page = pdf_document[page_num]
75
+ img_path = f"{file_path}_page{page_num+1}.png"
76
+ page.get_pixmap().save(img_path)
77
+ result['images'].append(img_path)
78
+ else:
79
+ result['images'].append(file_path)
80
+
81
+ # Limit to first 5 images
82
+ result['images'] = result['images'][:5]
83
+ return result
84
+
85
+ def text_conversation(self, text: str, role: str = 'user'):
86
+ return [{'role': role, 'content': text.replace('[begin of think]', '<think>').replace('[end of think]', '</think>')}]
87
+
88
+ def image_conversation(self, image_base64: str, text: str = None):
89
+ return [
90
+ {
91
+ 'role': 'user',
92
+ 'content': [
93
+ {'type': 'image_url', 'image_url': {'url': f"data:image/jpeg;base64,{image_base64}"}},
94
+ {'type': 'text', 'text': text}
95
+ ]
96
+ }
97
+ ]
98
+
99
+ def process_conversation(self, system_prompt, message, history, files):
100
+ conversation = [{'role': 'system', 'content': system_prompt}]
101
+ for past_message in history:
102
+ role = past_message['role']
103
+ content = past_message['content']
104
+ if role == 'user':
105
+ if isinstance(content, str):
106
+ conversation.extend(self.text_conversation(content))
107
+ elif isinstance(content, tuple):
108
+ conversation.extend(self.image_conversation(content[0], content[1]))
109
+ else:
110
+ conversation.append({'role': role, 'content': content})
111
+
112
+ current_question = message['text'] if isinstance(message, dict) and 'text' in message else message
113
+
114
+ if not files:
115
+ conversation.append({'role': 'user', 'content': current_question})
116
+ else:
117
+ content = []
118
+ extracted_text = []
119
+
120
+ for file_path in files:
121
+ file_ext = os.path.splitext(file_path)[1].lower()
122
+
123
+ if file_ext in ['.doc', '.docx', '.pdf']:
124
+ doc_content = self.extract_document_content(file_path)
125
+
126
+ if doc_content['text']:
127
+ extracted_text.append(f"文档 '{os.path.basename(file_path)}' 内容:\n{doc_content['text']}")
128
+
129
+ for img_path in doc_content['images']:
130
+ content.append({
131
+ 'type': 'image_url',
132
+ 'image_url': {'url': f"data:image/jpeg;base64,{self.encode_image(img_path)}"}
133
+ })
134
+
135
+ if img_path.startswith(tempfile.gettempdir()) or img_path.startswith(f"{file_path}_page"):
136
+ try: os.remove(img_path)
137
+ except: pass
138
+ else:
139
+ content.append({
140
+ 'type': 'image_url',
141
+ 'image_url': {'url': f"data:image/jpeg;base64,{self.encode_image(file_path)}"}
142
+ })
143
+
144
+ combined_text = current_question
145
+ if extracted_text:
146
+ combined_text += "\n\n以下是文档内容参考:\n" + "\n\n".join(extracted_text)
147
+
148
+ content.append({'type': 'text', 'text': combined_text})
149
+ conversation.append({'role': 'user', 'content': content})
150
+
151
+ response = self.client.chat.completions.create(
152
+ model=self.model,
153
+ messages=conversation,
154
+ stream=False,
155
+ )
156
+
157
+ answer = response.choices[0].message.content
158
+
159
+ if "**Final Answer**" in answer:
160
+ reasoning, final_answer = answer.split("**Final Answer**", 1)
161
+ if len(reasoning) > 5:
162
+ answer = f"""🤔 思考过程:\n```\n{reasoning.strip()}\n```\n\n✨ 最终答案:\n{final_answer.strip()}"""
163
+
164
+ return answer
165
+
166
+ def create_ui():
167
+ assistant = MultimodalAssistant()
168
+
169
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
170
+ gr.Markdown("# Medguide-V Reasoning CLI")
171
+ gr.Markdown("Better life with Medguide-V.")
172
+
173
+ with gr.Row():
174
+ with gr.Column(scale=3):
175
+ system_prompt = gr.Textbox(
176
+ label="系统提示词",
177
+ value=DEFAULT_SYSTEM_PROMPT,
178
+ lines=5
179
+ )
180
+
181
+ files_upload = gr.File(
182
+ label="上传文档或图片",
183
+ file_count="multiple",
184
+ type="filepath",
185
+ file_types=[".jpg", ".jpeg", ".png", ".pdf", ".doc", ".docx"]
186
+ )
187
+
188
+ with gr.Row():
189
+ clear_btn = gr.Button("清除对话")
190
+ example_btn = gr.Button("加载示例")
191
+
192
+ chat_interface = gr.ChatInterface(
193
+ fn=lambda message, history, files, sys_prompt: assistant.process_conversation(
194
+ sys_prompt, message, history, files
195
+ ),
196
+ type='messages',
197
+ additional_inputs=[files_upload, system_prompt],
198
+ examples=[
199
+ ["这份文档的主要内容是什么?", None, None, DEFAULT_SYSTEM_PROMPT],
200
+ ["分析这份文档的主要观点", None, None, DEFAULT_SYSTEM_PROMPT],
201
+ ["提取这份文档中的关键数据", None, None, DEFAULT_SYSTEM_PROMPT]
202
+ ]
203
+ )
204
+
205
+ clear_btn.click(lambda: None, None, chat_interface.chatbot, queue=False)
206
+ example_btn.click(
207
+ lambda: ["这是一个示例系统提示词,请根据文档内容进行详细分析,包括摘要、关键点和建议。", None, []],
208
+ None,
209
+ [system_prompt, chat_interface.chatbot, files_upload],
210
+ queue=False
211
+ )
212
+
213
+ return demo
214
+
215
+ if __name__ == "__main__":
216
+ parser = argparse.ArgumentParser(description="多模态文档理解UI界面")
217
+ parser.add_argument("--api_key", type=str, default="medguide-v")
218
+ parser.add_argument("--api_base", type=str, default="http://0.0.0.0:8231/v1")
219
+ parser.add_argument("--share", default=True, action="store_true")
220
+ args = parser.parse_args()
221
+
222
+ create_ui().launch(share=args.share)