multimodalart (HF Staff) committed

Commit 6e8ca12 · verified · 1 Parent(s): dc8a33c

Create app.py

Files changed (1):
  1. app.py  +198 -0
app.py ADDED
@@ -0,0 +1,198 @@
from __future__ import annotations
import os
from typing import List, Tuple, Dict, Any
import spaces

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# ----------------------
# Config
# ----------------------
MODEL_ID = os.getenv("MODEL_ID", "microsoft/UserLM-8b")
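# Usage note: because MODEL_ID is read from the environment, the model can be swapped
# without editing this file, e.g.  MODEL_ID=<some-hub-model-id> python app.py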
DEFAULT_SYSTEM_PROMPT = (
    "You are a user who wants to implement a special type of sequence. "
    "The sequence sums up the two previous numbers in the sequence and adds 1 to the result. "
    "The first two numbers in the sequence are 1 and 1."
)

device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model(model_id: str = MODEL_ID):
    """Load tokenizer and model, with a reasonable dtype and device fallback."""
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    dtype = torch.float16 if device == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=dtype,
    )

    # Special tokens for stopping / filtering
    end_token = "<|eot_id|>"
    end_conv_token = "<|endconversation|>"
    end_token_ids = tokenizer.encode(end_token, add_special_tokens=False)
    end_conv_token_ids = tokenizer.encode(end_conv_token, add_special_tokens=False)

    # Some models may not include these tokens; handle gracefully
    eos_token_id = end_token_ids[0] if len(end_token_ids) > 0 else tokenizer.eos_token_id
    bad_words_ids = (
        [[tid] for tid in end_conv_token_ids] if len(end_conv_token_ids) > 0 else None
    )

    return tokenizer, model, eos_token_id, bad_words_ids
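
# Downstream, eos_token_id makes generate() stop once <|eot_id|> is produced, while
# bad_words_ids keeps <|endconversation|> from being sampled at all.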


tokenizer, model, EOS_TOKEN_ID, BAD_WORDS_IDS = load_model()
model = model.to(device)
model.eval()

# ----------------------
# Generation helper
# ----------------------

def build_messages(system_prompt: str, history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    """Transform Gradio history [(user, assistant), ...] into chat template messages."""
    messages: List[Dict[str, str]] = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt.strip()})
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages
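
# Example of the structure this helper produces (illustrative values):
#   build_messages("Be brief.", [("Hi", "Hello!")])
#   -> [{"role": "system", "content": "Be brief."},
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"}]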


@spaces.GPU
def generate_reply(
    messages: List[Dict[str, str]],
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.9,
) -> str:
    """Run a single generate() step and return the model's text reply."""
    # Prepare input ids using the model's chat template
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            eos_token_id=EOS_TOKEN_ID,
            pad_token_id=tokenizer.eos_token_id,
            bad_words_ids=BAD_WORDS_IDS,
        )

    # Slice off the prompt tokens to get only the new text
    generated = outputs[0][inputs.shape[1]:]
    text = tokenizer.decode(generated, skip_special_tokens=True).strip()
    return text
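
# Illustrative call once the model is loaded; only the newly generated text is returned:
#   generate_reply([{"role": "user", "content": "Hello"}], max_new_tokens=64)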


# ----------------------
# Gradio UI callbacks
# ----------------------

def respond(user_message: str, chat_history: List[Tuple[str, str]], system_prompt: str,
            max_new_tokens: int, temperature: float, top_p: float):
    # Build messages including prior turns
    messages = build_messages(system_prompt, chat_history + [(user_message, "")])

    try:
        reply = generate_reply(
            messages,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
    except Exception as e:
        reply = f"(Generation error: {e})"

    chat_history = chat_history + [(user_message, reply)]
    return chat_history, chat_history
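
# Note: respond() returns the updated history twice; the _submit wrapper below only
# forwards the second copy to the visible Chatbot.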


def clear_state():
    return [], DEFAULT_SYSTEM_PROMPT


# ----------------------
# Build the Gradio App
# ----------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🧪 Transformers × Gradio: Multi-turn Chat Demo

Model: **{model}** on **{device}**

Change the system prompt, then chat. Sliders control sampling.
""".format(model=MODEL_ID, device=device))

    with gr.Row():
        system_box = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=3,
            placeholder="Enter a system instruction to steer the assistant",
        )

    chatbot = gr.Chatbot(height=420, label="Chat")

    with gr.Row():
        msg = gr.Textbox(
            label="Your message",
            placeholder="Type a message and press Enter",
        )

    with gr.Accordion("Generation Settings", open=False):
        max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="max_new_tokens")
        temperature = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")

    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")

    state = gr.State([])  # chat history state: List[Tuple[user, assistant]]

    def _submit(user_text, history, system_prompt, mnt, temp, tp):
        if not user_text or not user_text.strip():
            return gr.update(), history
        new_history, visible = respond(user_text.strip(), history, system_prompt, mnt, temp, tp)
        return "", visible

    submit_btn.click(
        fn=_submit,
        inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
        outputs=[msg, chatbot],
    )
    msg.submit(
        fn=_submit,
        inputs=[msg, state, system_box, max_new_tokens, temperature, top_p],
        outputs=[msg, chatbot],
    )

    # Keep state in sync with the visible Chatbot
    def _sync_state(chat):
        return chat

    chatbot.change(_sync_state, inputs=[chatbot], outputs=[state])

    def _clear():
        history, sys = clear_state()
        return history, sys, history, ""

    clear_btn.click(_clear, outputs=[state, system_box, chatbot, msg])

if __name__ == "__main__":
    demo.queue().launch()  # enable queuing for concurrency