RobertoBarrosoLuque committed on
Commit
fefbd93
·
1 Parent(s): 5515ef5

Update working version

Browse files
.pre-commit-config.yaml CHANGED
@@ -5,8 +5,6 @@ repos:
5
  - id: trailing-whitespace
6
  - id: end-of-file-fixer
7
  exclude: docs/badges
8
- - id: check-added-large-files
9
- args: ["--maxkb=1024"] # allow up to 1MB
10
  - id: check-json
11
  - id: check-yaml
12
  args: ["--unsafe"] # needed for some mkdocs extensions
 
5
  - id: trailing-whitespace
6
  - id: end-of-file-fixer
7
  exclude: docs/badges
 
 
8
  - id: check-json
9
  - id: check-yaml
10
  args: ["--unsafe"] # needed for some mkdocs extensions
assets/Accuracy-precision-recall.png CHANGED
assets/Accuracy.png CHANGED
notebooks/02-model-evals.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
src/app.py CHANGED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from pathlib import Path
4
+ from typing import Optional
5
+ import os
6
+ from datasets import load_dataset
7
+ from PIL import Image
8
+ import io
9
+
10
+ from src.modules.vlm_inference import analyze_product_image
11
+ from src.modules.data_processing import pil_to_base64
12
+ from src.modules.evals import run_inference_on_dataframe
13
+
14
# Constants

# Repository-relative locations: FILE_PATH is the directory two levels above
# this file's path entry (i.e. the parent of the folder holding this module).
FILE_PATH = Path(__file__).parents[1]
ASSETS_PATH = FILE_PATH.joinpath("assets")
EXAMPLE_IMAGES_DIR = Path("data/examples")

# Default ceiling on simultaneous API requests during batch processing.
MAX_CONCURRENT_REQUESTS = 10

# Display name shown in the model dropdown -> model identifier sent to the API.
AVAILABLE_MODELS = {
    "Qwen2.5-VL-32B": "accounts/fireworks/models/qwen2p5-vl-32b-instruct",
    "Llama Maverick": "accounts/fireworks/models/llama4-maverick-instruct-basic",
    "Llama Scout": "accounts/fireworks/models/llama4-scout-instruct-basic",
}
26
+
27
+
28
def analyze_single_image(
    image_input, model_name: str, api_key: Optional[str] = None
) -> tuple[str, str, str, str]:
    """
    Process a single product image and return classification results

    Args:
        image_input: Image to classify; it is handed to ``pil_to_base64``.
            ``None`` means nothing was uploaded.
        model_name: Selected model name (a key of ``AVAILABLE_MODELS``)
        api_key: Optional API key override. ``None`` or a blank string (what an
            empty textbox submits) falls back to the ``FIREWORKS_API_KEY``
            environment variable.

    Returns:
        tuple: (master_category, gender, sub_category, description). When no
        image is given or any step raises, the first element carries the
        message and the remaining three are empty strings.
    """
    if image_input is None:
        return "No image provided", "", "", ""

    try:
        # Encode the image so it can be sent inline as the request's image URL
        img_b64 = pil_to_base64(image_input)

        # Map the dropdown's display name to the provider model identifier
        model_id = AVAILABLE_MODELS[model_name]

        # A blank textbox yields "" rather than None, so test for emptiness
        # (after trimming stray whitespace) before using the environment key.
        api_key = (api_key or "").strip() or os.getenv("FIREWORKS_API_KEY")

        result = analyze_product_image(
            image_url=img_b64, model=model_id, api_key=api_key, provider="Fireworks"
        )

        return (
            result.master_category,
            result.gender,
            result.sub_category,
            result.description,
        )

    except Exception as e:
        # UI boundary: surface the failure in the first output box instead of
        # letting the exception propagate to the Gradio event handler.
        return f"Error: {str(e)}", "", "", ""
71
+
72
+
73
def process_batch_dataset(
    csv_file,
    model_name: str,
    api_key: Optional[str] = None,
    max_concurrent: int = MAX_CONCURRENT_REQUESTS,
) -> tuple[Optional[pd.DataFrame], str]:
    """
    Process uploaded CSV dataset with product images

    Args:
        csv_file: Uploaded CSV file. Either a filesystem path string or an
            object exposing the path as ``.name``; ``None`` when nothing was
            uploaded.
        model_name: Selected model name (a key of ``AVAILABLE_MODELS``)
        api_key: Optional API key override. ``None`` or a blank string falls
            back to the ``FIREWORKS_API_KEY`` environment variable.
        max_concurrent: Max concurrent API requests (coerced to ``int``)

    Returns:
        tuple: (results_dataframe, summary_statistics). The dataframe is
        ``None`` when no file was given, required columns are missing, or any
        step raises; the string then holds the explanation.
    """
    if csv_file is None:
        return None, "No dataset uploaded"

    try:
        # Accept both a plain filepath and a file-like wrapper with `.name`
        csv_path = getattr(csv_file, "name", csv_file)
        df = pd.read_csv(csv_path)

        # Validate required columns
        required_cols = ["id", "image"]
        if any(col not in df.columns for col in required_cols):
            return None, f"Dataset must contain columns: {required_cols}"

        # Map the dropdown's display name to the provider model identifier
        model_id = AVAILABLE_MODELS[model_name]

        # A blank textbox yields "" rather than None, so test for emptiness
        api_key = (api_key or "").strip() or os.getenv("FIREWORKS_API_KEY")

        # Run batch inference
        results_df = run_inference_on_dataframe(
            df=df,
            model=model_id,
            api_key=api_key,
            provider="Fireworks",
            # UI sliders may deliver a float; the limit is a count
            max_concurrent_requests=int(max_concurrent),
        )

        # Rows with a non-null predicted master category count as successes
        total_processed = len(results_df)
        successful = int(results_df["pred_masterCategory"].notna().sum())
        failed = total_processed - successful
        # Guard the empty-result case so the rate is 0.0 instead of 0/0
        success_rate = (successful / total_processed) * 100 if total_processed else 0.0

        summary = f"""
        Batch Processing Complete:
        - Total images: {total_processed}
        - Successfully classified: {successful}
        - Failed: {failed}
        - Success rate: {success_rate:.1f}%
        """

        return results_df, summary

    except Exception as e:
        # UI boundary: report the failure in the summary box
        return None, f"Error processing dataset: {str(e)}"
136
+
137
+
138
def load_example_data() -> pd.DataFrame:
    """Load example product images from HuggingFace dataset

    Returns:
        A 20-row frame (fixed seed, fresh 0..19 index) with the columns
        ``id``, ``masterCategory``, ``gender``, ``subCategory`` plus
        ``image_data`` copied from the dataset's ``image`` column.
    """
    metadata_columns = ["id", "masterCategory", "gender", "subCategory"]

    # Pull the train split into pandas
    train_df = load_dataset("ceyda/fashion-products-small")["train"].to_pandas()

    # Deterministic sample so the examples table is the same on every start
    picked = train_df.sample(n=20, random_state=42).reset_index(drop=True)

    # Metadata for display, with the raw image payload alongside
    return picked[metadata_columns].assign(image_data=picked["image"])
152
+
153
+
154
def get_image_from_row(
    examples_df: pd.DataFrame, evt: gr.SelectData
) -> Optional[Image.Image]:
    """Get PIL Image from selected row in examples table

    Args:
        examples_df: Examples frame holding an ``image_data`` column
        evt: Selection event; ``evt.index[0]`` is used as the row position

    Returns:
        The image for the selected row, or ``None`` when nothing is selected,
        the row position is out of range, or the stored dict carries neither
        usable bytes nor a path.
    """
    selection = evt.index
    if selection is None or len(selection) == 0:
        return None

    row_idx = selection[0]
    # Reject positions outside the frame (including negatives, which iloc
    # would otherwise resolve from the end).
    if not 0 <= row_idx < len(examples_df):
        return None

    # Get the image data from the stored row
    image_data = examples_df.iloc[row_idx]["image_data"]

    # Dict payloads may contain both keys with one of them set to None, so
    # check the values rather than mere key presence.
    if isinstance(image_data, dict):
        raw_bytes = image_data.get("bytes")
        if raw_bytes:
            return Image.open(io.BytesIO(raw_bytes))

        image_path = image_data.get("path")
        if image_path:
            return Image.open(image_path)

        # Nothing loadable in the dict; do not hand a dict to the image output
        return None

    # Return as-is if it is not a dict (e.g. already a PIL Image)
    return image_data
175
+
176
+
177
def create_demo_interface():
    """
    Create the Gradio interface with custom theme and layout

    Loads the example products once via ``load_example_data`` and builds a
    ``gr.Blocks`` app containing:

    - a header, a model dropdown and an API-key textbox whose values are
      passed to both processing handlers,
    - single-image analysis wired to ``analyze_single_image``,
    - an examples table whose row selection feeds ``get_image_from_row``,
    - CSV batch processing wired to ``process_batch_dataset``,
    - a "Model Performance" tab showing two chart images from ``ASSETS_PATH``.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched here)
    """
    # Load example data at startup
    example_data = load_example_data()
    model_names = list(AVAILABLE_MODELS)

    with gr.Blocks(
        title="Product Catalog Cleansing",
        theme=gr.themes.Soft(),
    ) as demo:
        # Keep the full examples frame (including image payloads) in state;
        # the visible table below shows only the metadata columns.
        examples_state = gr.State(value=example_data)

        # Header
        gr.Markdown(
            """
            # Product Catalog Cleansing

            Automate product classification, attribute extraction, and catalog enrichment
            using state-of-the-art multimodal AI. Fine-tuned SOTA OSS models on FireworksAI.
            """
        )

        # Model Selection (shared across tabs)
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Powered by")
                gr.Image(
                    value=str(ASSETS_PATH / "fireworks_logo.png"),
                    height=60,
                    width=200,
                    show_label=False,
                    show_download_button=False,
                    container=False,
                    show_fullscreen_button=False,
                    show_share_button=False,
                )

            model_selector = gr.Dropdown(
                choices=model_names,
                value=model_names[0],
                label="Select Model",
            )
            api_key_input = gr.Textbox(
                label="API Key",
                type="password",
            )

        with gr.Tabs():
            with gr.TabItem("📸 Single Image Analysis"):
                gr.Markdown("### Upload a product image for instant classification")

                with gr.Row():
                    # Left column - Input
                    with gr.Column(scale=1):
                        image_input = gr.Image(
                            label="Upload Product Image", type="pil", height=400
                        )
                        analyze_btn = gr.Button(
                            "🔍 Analyze Product", variant="primary", size="lg"
                        )

                    # Right column - Results
                    with gr.Column(scale=1):
                        gr.Markdown("### Classification Results")
                        master_category_output = gr.Textbox(
                            label="Master Category", interactive=False
                        )
                        gender_output = gr.Textbox(label="Gender", interactive=False)
                        subcategory_output = gr.Textbox(
                            label="Sub-Category", interactive=False
                        )
                        description_output = gr.Textbox(
                            label="AI-Generated Description", interactive=False, lines=4
                        )

                # Example Products Table
                gr.Markdown("### 📚 Example Products (Click a row to load image)")
                examples_table = gr.Dataframe(
                    value=example_data[
                        ["id", "masterCategory", "gender", "subCategory"]
                    ],
                    label="Select a product to analyze",
                    interactive=False,
                    wrap=True,
                )

                # Wire up single image analysis
                analyze_btn.click(
                    fn=analyze_single_image,
                    inputs=[image_input, model_selector, api_key_input],
                    outputs=[
                        master_category_output,
                        gender_output,
                        subcategory_output,
                        description_output,
                    ],
                )

                # Allow clicking table row to load image
                examples_table.select(
                    fn=get_image_from_row,
                    inputs=[examples_state],
                    outputs=[image_input],
                )

                # NOTE(review): the batch-processing controls below have no tab
                # header of their own even though the next tab is commented as
                # "Tab 3"; they are kept inside this tab - confirm the intended
                # layout.
                with gr.Row():
                    # Left - Upload
                    with gr.Column(scale=1):
                        dataset_upload = gr.File(
                            label="Upload Dataset (CSV)", file_types=[".csv"]
                        )
                        concurrent_slider = gr.Slider(
                            minimum=1,
                            maximum=50,
                            # Same default the batch handler uses
                            value=MAX_CONCURRENT_REQUESTS,
                            step=1,
                            label="Concurrent Requests",
                            info="Higher = faster but may hit rate limits",
                        )
                        process_btn = gr.Button(
                            "⚡ Process Dataset", variant="primary", size="lg"
                        )

                    # Right - Results summary
                    with gr.Column(scale=1):
                        summary_output = gr.Textbox(
                            label="Processing Summary", interactive=False, lines=8
                        )

                # Results dataframe
                results_dataframe = gr.Dataframe(
                    label="Classification Results", interactive=False, wrap=True
                )

                # Wire up batch processing
                process_btn.click(
                    fn=process_batch_dataset,
                    inputs=[
                        dataset_upload,
                        model_selector,
                        api_key_input,
                        concurrent_slider,
                    ],
                    outputs=[results_dataframe, summary_output],
                )

            # Tab 3: Model Evaluation (show uploaded charts)
            with gr.TabItem("📈 Model Performance"):
                gr.Markdown(
                    """
                    ### Evaluation Results on Fashion Product Dataset

                    Model fine tuned on over 14k images and tested on a validation set of 1000 images.

                    Images pulled from [HuggingFace Datasets](https://huggingface.co/datasets/ceyda/fashion-products-small)
                    """
                )

                # Display uploaded evaluation charts
                with gr.Row():
                    gr.Image(
                        value=str(ASSETS_PATH / "Accuracy.png"),
                        interactive=False,
                        show_label=False,
                    )
                    gr.Image(
                        value=str(ASSETS_PATH / "Accuracy-precision-recall.png"),
                        interactive=False,
                        show_label=False,
                    )

                gr.Markdown(
                    """
                    **Key Findings:**
                    - Qwen2.5-VL-72B-SFT achieves >95% accuracy on masterCategory
                    - Fine-tuned model shows 18% improvement on subCategory vs base model
                    - All models maintain >90% precision and recall on gender classification
                    """
                )

    return demo
360
+
361
+
362
if __name__ == "__main__":
    # Launch demo: build the interface, then serve it with fixed host/port
    # settings, no public share link, and errors shown in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo = create_demo_interface()
    demo.launch(**launch_options)