RobertoBarrosoLuque commited on
Commit
5515ef5
·
1 Parent(s): 7a920b1

Add notebook with evals

Browse files
assets/Accuracy-precision-recall.png ADDED
assets/Accuracy.png ADDED
notebooks/02-model-evals.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
src/modules/evals.py CHANGED
@@ -6,80 +6,151 @@ from sklearn.metrics import (
6
  accuracy_score,
7
  classification_report,
8
  )
9
- from tqdm import tqdm
 
10
 
11
- from src.modules.vlm_inference import analyze_product_image
12
  from src.modules.data_processing import image_to_base64
 
13
 
 
14
 
15
- def run_inference_on_dataframe(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  df: pd.DataFrame,
17
  model: str = "accounts/fireworks/models/qwen2p5-vl-72b-instruct",
18
  api_key: Optional[str] = None,
19
  provider: str = "Fireworks",
20
- image_col: str = "image",
21
- id_col: str = "id",
22
  ) -> pd.DataFrame:
23
  """
24
- Run VLM inference on entire dataframe of images
25
 
26
  Args:
27
  df: DataFrame containing images
28
  model: Model to use for inference
29
  api_key: API key for the provider
30
  provider: Provider to use (Fireworks or OpenAI)
31
- image_col: Column name containing images
32
- id_col: Column name containing image IDs
33
 
34
  Returns:
35
  pd.DataFrame: Results with columns:
36
  - id: Image ID
37
- - pred_master_category: Predicted master category
38
  - pred_gender: Predicted gender
39
- - pred_sub_category: Predicted sub-category
40
  - pred_description: Predicted description
41
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  results = []
 
 
 
43
 
44
- for idx, row_id, row_image in tqdm(
45
- df.itertuples(index=True, name="columns"),
46
- total=len(df),
47
- desc="Running inference",
48
- ):
49
- try:
50
- img_b64 = image_to_base64(row_image)
51
 
52
- prediction = analyze_product_image(
53
- image_url=img_b64,
54
- model=model,
55
- api_key=api_key,
56
- provider=provider,
57
- )
58
 
59
- results.append(
60
- {
61
- "id": row_id,
62
- "pred_master_category": prediction.master_category,
63
- "pred_gender": prediction.gender,
64
- "pred_sub_category": prediction.sub_category,
65
- "pred_description": prediction.description,
66
- }
67
- )
68
 
69
- except Exception as e:
70
- print(f"Error processing row {idx} (ID: {row_id}): {e}")
71
- # Append placeholder for failed predictions
72
- results.append(
73
- {
74
- "id": row_id,
75
- "pred_master_category": None,
76
- "pred_gender": None,
77
- "pred_sub_category": None,
78
- "pred_description": f"Error: {str(e)}",
79
- }
80
- )
81
 
82
- return pd.DataFrame(results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
 
85
  def calculate_metrics(
@@ -225,3 +296,29 @@ def create_evaluation_summary(results: dict) -> pd.DataFrame:
225
  )
226
 
227
  return pd.DataFrame(summary_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  accuracy_score,
7
  classification_report,
8
  )
9
+ from tqdm.asyncio import tqdm as async_tqdm
10
+ import asyncio
11
 
12
+ from src.modules.vlm_inference import analyze_product_image_async
13
  from src.modules.data_processing import image_to_base64
14
+ from pathlib import Path
15
 
16
+ DATA_PATH = Path(__file__).parents[2] / "data"
17
 
18
+
19
+ async def _process_single_row(
20
+ row_data: dict,
21
+ model: str,
22
+ api_key: str,
23
+ provider: str,
24
+ semaphore: asyncio.Semaphore,
25
+ ) -> dict:
26
+ """
27
+ Process a single row with semaphore control
28
+
29
+ Args:
30
+ row_data: Dictionary with 'image' and 'id' keys
31
+ model: Model to use for inference
32
+ api_key: API key for the provider
33
+ provider: Provider to use
34
+ semaphore: Asyncio semaphore for rate limiting
35
+
36
+ Returns:
37
+ dict: Prediction result
38
+ """
39
+ async with semaphore:
40
+ try:
41
+ img_b64 = image_to_base64(row_data["image"])
42
+ prediction = await analyze_product_image_async(
43
+ image_url=img_b64,
44
+ model=model,
45
+ api_key=api_key,
46
+ provider=provider,
47
+ )
48
+ return {
49
+ "id": row_data["id"],
50
+ "pred_masterCategory": prediction.master_category,
51
+ "pred_gender": prediction.gender,
52
+ "pred_subCategory": prediction.sub_category,
53
+ "pred_description": prediction.description,
54
+ }
55
+ except Exception as e:
56
+ return {
57
+ "id": row_data["id"],
58
+ "pred_masterCategory": None,
59
+ "pred_gender": None,
60
+ "pred_subCategory": None,
61
+ "pred_description": f"Error: {str(e)}",
62
+ }
63
+
64
+
65
+ async def run_inference_on_dataframe_async(
66
  df: pd.DataFrame,
67
  model: str = "accounts/fireworks/models/qwen2p5-vl-72b-instruct",
68
  api_key: Optional[str] = None,
69
  provider: str = "Fireworks",
70
+ max_concurrent_requests: int = 10,
 
71
  ) -> pd.DataFrame:
72
  """
73
+ Run VLM inference on entire dataframe of images with concurrent requests
74
 
75
  Args:
76
  df: DataFrame containing images
77
  model: Model to use for inference
78
  api_key: API key for the provider
79
  provider: Provider to use (Fireworks or OpenAI)
80
+ max_concurrent_requests: Maximum number of concurrent API requests (default: 10)
 
81
 
82
  Returns:
83
  pd.DataFrame: Results with columns:
84
  - id: Image ID
85
+ - pred_masterCategory: Predicted master category
86
  - pred_gender: Predicted gender
87
+ - pred_subCategory: Predicted sub-category
88
  - pred_description: Predicted description
89
  """
90
+ # Create semaphore for rate limiting
91
+ semaphore = asyncio.Semaphore(max_concurrent_requests)
92
+
93
+ # Prepare all rows as dictionaries
94
+ rows_data = [
95
+ {"image": row.image, "id": row.id}
96
+ for row in df.itertuples(index=False, name="columns")
97
+ ]
98
+
99
+ # Create all tasks (coroutines, not awaited yet)
100
+ tasks = [
101
+ _process_single_row(row_data, model, api_key, provider, semaphore)
102
+ for row_data in rows_data
103
+ ]
104
+
105
+ _model = model.split("/")[-1]
106
+ # Run all tasks concurrently with progress bar
107
  results = []
108
+ for task in async_tqdm.as_completed(tasks, total=len(tasks)):
109
+ result = await task
110
+ results.append(result)
111
 
112
+ if len(results) % 10 == 0:
113
+ df_pred = pd.DataFrame(results)
114
+ file_name = DATA_PATH / f"df_pred_{provider}_{_model}.csv"
115
+ df_pred.to_csv(file_name, index=False)
 
 
 
116
 
117
+ # Final save
118
+ df_pred = pd.DataFrame(results)
119
+ file_name = DATA_PATH / f"df_pred_{provider}_{_model}.csv"
120
+ df_pred.to_csv(file_name, index=False)
 
 
121
 
122
+ print(f"\nPrediction successful, dataset saved to {file_name}")
123
+ return df_pred
 
 
 
 
 
 
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ def run_inference_on_dataframe(
127
+ df: pd.DataFrame,
128
+ model: str = "accounts/fireworks/models/qwen2p5-vl-72b-instruct",
129
+ api_key: Optional[str] = None,
130
+ provider: str = "Fireworks",
131
+ max_concurrent_requests: int = 10,
132
+ ) -> pd.DataFrame:
133
+ """
134
+ Run VLM inference on entire dataframe of images (sync wrapper for async function)
135
+
136
+ Args:
137
+ df: DataFrame containing images
138
+ model: Model to use for inference
139
+ api_key: API key for the provider
140
+ provider: Provider to use (Fireworks or OpenAI)
141
+ max_concurrent_requests: Maximum number of concurrent API requests (default: 10)
142
+
143
+ Returns:
144
+ pd.DataFrame: Results with columns:
145
+ - id: Image ID
146
+ - pred_masterCategory: Predicted master category
147
+ - pred_gender: Predicted gender
148
+ - pred_subCategory: Predicted sub-category
149
+ - pred_description: Predicted description
150
+ """
151
+ return asyncio.run(
152
+ run_inference_on_dataframe_async(df, model, api_key, provider, max_concurrent_requests)
153
+ )
154
 
155
 
156
  def calculate_metrics(
 
296
  )
297
 
298
  return pd.DataFrame(summary_data)
299
+
300
+
301
+ def extract_metrics(results_dict, model_name):
302
+ """
303
+ Extract accuracy, precision, and recall for each category.
304
+
305
+ Args:
306
+ results_dict: Dictionary containing evaluation metrics
307
+ model_name: Name of the model for identification
308
+
309
+ Returns:
310
+ List of dictionaries with metrics per category
311
+ """
312
+ metrics_list = []
313
+
314
+ for category, metrics in results_dict.items():
315
+ metrics_list.append({
316
+ 'model': model_name,
317
+ 'category': category,
318
+ 'accuracy': metrics['accuracy'],
319
+ 'precision': metrics['precision'],
320
+ 'recall': metrics['recall'],
321
+ 'num_samples': metrics['num_samples']
322
+ })
323
+
324
+ return metrics_list
src/modules/vlm_inference.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from openai import OpenAI
3
  from pydantic import BaseModel, Field
4
  from typing import Optional, Literal
5
 
@@ -86,13 +86,12 @@ def analyze_product_image(
86
  Returns:
87
  ProductClassification: Structured classification and description
88
  """
89
- if provider == "Fireworks":
90
- # Initialize OpenAI client
91
  client = OpenAI(
92
  api_key=api_key or os.getenv("FIREWORKS_API_KEY"),
93
  base_url="https://api.fireworks.ai/inference/v1",
94
  )
95
- elif provider == "OpenAI":
96
  client = OpenAI(
97
  api_key=api_key or os.getenv("OPENAI_API_KEY"),
98
  )
@@ -119,6 +118,56 @@ def analyze_product_image(
119
  return completion.choices[0].message.parsed
120
 
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  def batch_analyze_products(
123
  image_urls: list[str],
124
  model: str = "accounts/fireworks/models/qwen2p5-vl-72b-instruct",
 
1
  import os
2
+ from openai import OpenAI, AsyncOpenAI
3
  from pydantic import BaseModel, Field
4
  from typing import Optional, Literal
5
 
 
86
  Returns:
87
  ProductClassification: Structured classification and description
88
  """
89
+ if provider.lower() in ["fireworks", "fireworksai"]:
 
90
  client = OpenAI(
91
  api_key=api_key or os.getenv("FIREWORKS_API_KEY"),
92
  base_url="https://api.fireworks.ai/inference/v1",
93
  )
94
+ elif provider.lower() == "openai":
95
  client = OpenAI(
96
  api_key=api_key or os.getenv("OPENAI_API_KEY"),
97
  )
 
118
  return completion.choices[0].message.parsed
119
 
120
 
121
+ async def analyze_product_image_async(
122
+ image_url: str,
123
+ model: str = "accounts/fireworks/models/qwen2p5-vl-72b-instruct",
124
+ api_key: Optional[str] = None,
125
+ provider: str = "Fireworks",
126
+ ) -> ProductClassification:
127
+ """
128
+ Async version of analyze_product_image for concurrent processing
129
+
130
+ Args:
131
+ image_url: URL or base64-encoded image string (with data:image prefix)
132
+ model: Model to use for inference (default: Qwen2.5 VL 72B)
133
+ api_key: API key (defaults to provider-specific env variable)
134
+ provider: Provider to use for inference (default: Fireworks)
135
+
136
+ Returns:
137
+ ProductClassification: Structured classification and description
138
+ """
139
+ if provider.lower() in ["fireworks", "fireworksai"]:
140
+ client = AsyncOpenAI(
141
+ api_key=api_key or os.getenv("FIREWORKS_API_KEY"),
142
+ base_url="https://api.fireworks.ai/inference/v1",
143
+ )
144
+ elif provider.lower() == "openai":
145
+ client = AsyncOpenAI(
146
+ api_key=api_key or os.getenv("OPENAI_API_KEY"),
147
+ )
148
+ else:
149
+ raise ValueError(f"Unknown provider: {provider}")
150
+
151
+ # Call the API with structured output
152
+ completion = await client.beta.chat.completions.parse(
153
+ model=model,
154
+ messages=[
155
+ {"role": "system", "content": SYSTEM_PROMPT},
156
+ {
157
+ "role": "user",
158
+ "content": [
159
+ {"type": "image_url", "image_url": {"url": image_url}},
160
+ {"type": "text", "text": USER_PROMPT},
161
+ ],
162
+ },
163
+ ],
164
+ response_format=ProductClassification,
165
+ )
166
+
167
+ # Extract and return the structured output
168
+ return completion.choices[0].message.parsed
169
+
170
+
171
  def batch_analyze_products(
172
  image_urls: list[str],
173
  model: str = "accounts/fireworks/models/qwen2p5-vl-72b-instruct",