Spaces:
Sleeping
Sleeping
lbtwyk
committed on
Commit
·
f46834b
1
Parent(s):
2adb71d
Remove natural language input mode and simplify prompt building
Browse files
app.py
CHANGED
|
@@ -3,22 +3,10 @@ from pydantic import BaseModel
|
|
| 3 |
from typing import Any, Dict, List, Optional
|
| 4 |
|
| 5 |
import json
|
| 6 |
-
import sys
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
|
| 9 |
import torch
|
| 10 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 11 |
from peft import PeftModel
|
| 12 |
|
| 13 |
-
ROOT_DIR = Path(__file__).resolve().parent
|
| 14 |
-
RL_DIR = ROOT_DIR / "RL"
|
| 15 |
-
for path in (ROOT_DIR, RL_DIR):
|
| 16 |
-
path_str = str(path)
|
| 17 |
-
if path_str not in sys.path:
|
| 18 |
-
sys.path.append(path_str)
|
| 19 |
-
|
| 20 |
-
from RL.battleground_nl_utils import game_state_to_natural_language
|
| 21 |
-
|
| 22 |
|
| 23 |
BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
|
| 24 |
ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"
|
|
@@ -50,55 +38,30 @@ Rules:
|
|
| 50 |
Now here is the game state JSON:
|
| 51 |
"""
|
| 52 |
|
| 53 |
-
INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
|
| 54 |
-
Given the following natural language description of the current game state, choose
|
| 55 |
-
the best full-turn sequence of actions and respond with a single JSON object in
|
| 56 |
-
this exact format:
|
| 57 |
-
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
|
| 58 |
-
Rules:
|
| 59 |
-
1. Respond with JSON only. Do not add explanations or any extra text.
|
| 60 |
-
2. The top-level object must have exactly one key: "actions".
|
| 61 |
-
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
|
| 62 |
-
atomic action objects.
|
| 63 |
-
4. Use 0-based integers for indices or null when not used.
|
| 64 |
-
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
|
| 65 |
-
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
|
| 66 |
-
6. "card_name" must exactly match a card name from the game state when required,
|
| 67 |
-
otherwise null.
|
| 68 |
-
Now here is the description of the game state:
|
| 69 |
-
"""
|
| 70 |
|
| 71 |
|
| 72 |
class GenerateRequest(BaseModel):
|
| 73 |
phase: Optional[str] = None
|
| 74 |
turn: Optional[int] = None
|
| 75 |
state: Dict[str, Any]
|
| 76 |
-
input_mode: str = "json" # "json" or "nl"
|
| 77 |
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
|
| 78 |
temperature: float = DEFAULT_TEMPERATURE
|
| 79 |
|
| 80 |
|
| 81 |
-
def build_prompt(example: Dict[str, Any]
|
|
|
|
| 82 |
state = example.get("state", {}) or {}
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
"phase": phase,
|
| 95 |
-
"turn": turn,
|
| 96 |
-
"state": state,
|
| 97 |
-
}
|
| 98 |
-
state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
|
| 99 |
-
prefix = INSTRUCTION_PREFIX
|
| 100 |
-
|
| 101 |
-
return prefix + "\n" + state_text
|
| 102 |
|
| 103 |
|
| 104 |
def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
|
|
@@ -199,7 +162,7 @@ def generate_actions(req: GenerateRequest):
|
|
| 199 |
"turn": req.turn,
|
| 200 |
"state": req.state,
|
| 201 |
}
|
| 202 |
-
prompt = build_prompt(example
|
| 203 |
|
| 204 |
inputs = tokenizer(prompt, return_tensors="pt")
|
| 205 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
|
|
| 3 |
from typing import Any, Dict, List, Optional
|
| 4 |
|
| 5 |
import json
|
|
|
|
|
|
|
|
|
|
| 6 |
import torch
|
| 7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 8 |
from peft import PeftModel
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
|
| 12 |
ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"
|
|
|
|
| 38 |
Now here is the game state JSON:
|
| 39 |
"""
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
class GenerateRequest(BaseModel):
|
| 44 |
phase: Optional[str] = None
|
| 45 |
turn: Optional[int] = None
|
| 46 |
state: Dict[str, Any]
|
|
|
|
| 47 |
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
|
| 48 |
temperature: float = DEFAULT_TEMPERATURE
|
| 49 |
|
| 50 |
|
| 51 |
+
def build_prompt(example: Dict[str, Any]) -> str:
    """Build a JSON-mode prompt (the only mode supported by this Space).

    Args:
        example: Mapping with a "state" dict and optional "phase" / "turn"
            entries. "phase" and "turn" may be missing or explicitly None.

    Returns:
        INSTRUCTION_PREFIX, a newline, and the compact JSON encoding of the
        task object ("task", "phase", "turn", "state").
    """
    state = example.get("state", {}) or {}
    gs = state.get("game_state", {}) or {}

    # The request handler passes "phase"/"turn" keys even when the client
    # omitted them (the value is then None), so a dict.get() default would
    # never fire. Fall back on None explicitly instead of emitting null.
    phase = example.get("phase")
    if phase is None:
        phase = gs.get("phase", "PlayerTurn")

    turn = example.get("turn")
    if turn is None:
        turn = gs.get("turn_number", 0)

    obj = {
        "task": "battlegrounds_policy_v1",
        "phase": phase,
        "turn": turn,
        "state": state,
    }
    # Compact separators keep the prompt short; ensure_ascii=False keeps
    # non-ASCII text (e.g. card names) readable rather than \u-escaped.
    state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
    return INSTRUCTION_PREFIX + "\n" + state_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
|
|
|
|
| 162 |
"turn": req.turn,
|
| 163 |
"state": req.state,
|
| 164 |
}
|
| 165 |
+
prompt = build_prompt(example)
|
| 166 |
|
| 167 |
inputs = tokenizer(prompt, return_tensors="pt")
|
| 168 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|