d82b7a7c1d
gguf_new_metadata.py reads data from the reader. The reader doesn't byteswap tensors to native endianness, but the writer expects tensors in native endianness so that it can convert them into the requested endianness. There are two ways to fix this: update the reader to convert to native endianness and back, or skip the endianness conversion in the writer for this particular use case. gguf_editor_gui.py doesn't allow editing or viewing tensor data, so let's go with skipping the unnecessary byteswapping. If the capability to view or edit tensor data is eventually added, tensor data should instead be byteswapped when it is read.
1622 lines
63 KiB
Python
Executable File
1622 lines
63 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import argparse
|
||
import os
|
||
import sys
|
||
import numpy
|
||
import enum
|
||
from pathlib import Path
|
||
from typing import Any, Optional, Tuple, Type
|
||
import warnings
|
||
|
||
import numpy as np
|
||
from PySide6.QtWidgets import (
|
||
QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
|
||
QPushButton, QLabel, QLineEdit, QFileDialog, QTableWidget,
|
||
QTableWidgetItem, QComboBox, QMessageBox, QTabWidget,
|
||
QTextEdit, QFormLayout,
|
||
QHeaderView, QDialog, QDialogButtonBox
|
||
)
|
||
from PySide6.QtCore import Qt
|
||
|
||
# Prefer the in-tree gguf package over an installed one, unless the
# NO_LOCAL_GGUF environment variable is set.
if "NO_LOCAL_GGUF" not in os.environ:
    if (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
        sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||
|
||
import gguf
|
||
from gguf import GGUFReader, GGUFWriter, GGUFValueType, ReaderField
|
||
from gguf.constants import TokenType, RopeScalingType, PoolingType, GGMLQuantizationType
|
||
|
||
logger = logging.getLogger("gguf-editor-gui")

# Metadata keys whose integer values stand for members of an enum; the
# editor uses this table to show and edit them by name.
KEY_TO_ENUM_TYPE = {
    gguf.Keys.Tokenizer.TOKEN_TYPE: TokenType,
    gguf.Keys.Rope.SCALING_TYPE: RopeScalingType,
    gguf.Keys.LLM.POOLING_TYPE: PoolingType,
    gguf.Keys.General.FILE_TYPE: GGMLQuantizationType,
}

# Tokenizer keys that are edited together in a single dialog.
TOKENIZER_LINKED_KEYS = [
    gguf.Keys.Tokenizer.LIST,
    gguf.Keys.Tokenizer.TOKEN_TYPE,
    gguf.Keys.Tokenizer.SCORES,
]
|
||
|
||
|
||
class TokenizerEditorDialog(QDialog):
    """Dialog for editing the linked tokenizer arrays in one table.

    Tokens, token types and scores are shown side by side in a paginated,
    filterable table. The dialog works on copies of the lists it is given;
    call get_data() after the dialog was accepted to obtain the result.
    """

    def __init__(self, tokens, token_types, scores, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Edit Tokenizer Data")
        self.resize(900, 600)

        self.tokens = tokens.copy() if tokens else []
        self.token_types = token_types.copy() if token_types else []
        self.scores = scores.copy() if scores else []

        # Ensure all arrays have the same length (extending by an empty
        # list is a no-op, so no length checks are needed)
        max_len = max(len(self.tokens), len(self.token_types), len(self.scores))
        self.tokens.extend([""] * (max_len - len(self.tokens)))
        self.token_types.extend([0] * (max_len - len(self.token_types)))
        self.scores.extend([0.0] * (max_len - len(self.scores)))

        # True while the table is being filled or corrected by code, so that
        # handle_item_changed() only reacts to edits made by the user
        self._populating = False

        layout = QVBoxLayout(self)

        # Add filter controls
        filter_layout = QHBoxLayout()
        filter_layout.addWidget(QLabel("Filter:"))
        self.filter_edit = QLineEdit()
        self.filter_edit.setPlaceholderText("Type to filter tokens...")
        self.filter_edit.textChanged.connect(self.apply_filter)
        filter_layout.addWidget(self.filter_edit)

        # Add page controls
        self.page_size = 100  # Show 100 items per page
        self.current_page = 0
        self.total_pages = max(1, (len(self.tokens) + self.page_size - 1) // self.page_size)

        self.page_label = QLabel(f"Page 1 of {self.total_pages}")
        filter_layout.addWidget(self.page_label)

        prev_page = QPushButton("Previous")
        prev_page.clicked.connect(self.previous_page)
        filter_layout.addWidget(prev_page)

        next_page = QPushButton("Next")
        next_page.clicked.connect(self.next_page)
        filter_layout.addWidget(next_page)

        layout.addLayout(filter_layout)

        # Tokenizer data table
        self.tokens_table = QTableWidget()
        self.tokens_table.setColumnCount(4)
        self.tokens_table.setHorizontalHeaderLabels(["Index", "Token", "Type", "Score"])
        header = self.tokens_table.horizontalHeader()
        header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
        header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
        header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
        header.setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents)

        # Connect the table signals exactly once. Connecting inside
        # load_page() would add another connection on every page load and
        # make a single double-click open the type dialog several times.
        self.tokens_table.cellDoubleClicked.connect(self.handle_cell_double_click)
        self.tokens_table.itemChanged.connect(self.handle_item_changed)

        layout.addWidget(self.tokens_table)

        # Controls
        controls_layout = QHBoxLayout()

        add_button = QPushButton("Add Token")
        add_button.clicked.connect(self.add_token)
        controls_layout.addWidget(add_button)

        remove_button = QPushButton("Remove Selected")
        remove_button.clicked.connect(self.remove_selected)
        controls_layout.addWidget(remove_button)

        controls_layout.addStretch()

        layout.addLayout(controls_layout)

        # Buttons
        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
        buttons.accepted.connect(self.accept)
        buttons.rejected.connect(self.reject)
        layout.addWidget(buttons)

        # Initialize the filtered values
        self.filtered_indices = list(range(len(self.tokens)))

        # Load data for the first page
        self.load_page()

    def _rebuild_filtered_indices(self):
        """Recompute filtered_indices from the current filter text."""
        filter_text = self.filter_edit.text().lower()
        if not filter_text:
            # No filter, show all values
            self.filtered_indices = list(range(len(self.tokens)))
        else:
            self.filtered_indices = [
                i for i, token in enumerate(self.tokens)
                if filter_text in str(token).lower()
            ]

    def _recount_pages(self):
        """Update total_pages to match the number of filtered entries."""
        self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)

    def _refresh_page(self):
        """Update the page label and redraw the current page."""
        self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
        self.load_page()

    def apply_filter(self):
        """Filter the tokens based on the search text."""
        self._rebuild_filtered_indices()

        # Reset to first page and reload
        self._recount_pages()
        self.current_page = 0
        self._refresh_page()

    def previous_page(self):
        """Go to the previous page of results."""
        if self.current_page > 0:
            self.current_page -= 1
            self._refresh_page()

    def next_page(self):
        """Go to the next page of results."""
        if self.current_page < self.total_pages - 1:
            self.current_page += 1
            self._refresh_page()

    def load_page(self):
        """Load the current page of tokenizer data."""
        # Filling the table emits itemChanged for every cell; flag it so the
        # change handler does not treat that as user input.
        self._populating = True
        try:
            self.tokens_table.setRowCount(0)  # Clear the table

            # Calculate start and end indices for the current page
            start_idx = self.current_page * self.page_size
            end_idx = min(start_idx + self.page_size, len(self.filtered_indices))

            # Pre-allocate rows for better performance
            self.tokens_table.setRowCount(end_idx - start_idx)

            for row, orig_idx in enumerate(self.filtered_indices[start_idx:end_idx]):
                # Index
                index_item = QTableWidgetItem(str(orig_idx))
                index_item.setData(Qt.ItemDataRole.UserRole, orig_idx)  # Store original index
                index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
                self.tokens_table.setItem(row, 0, index_item)

                # Token
                token_item = QTableWidgetItem(str(self.tokens[orig_idx]))
                self.tokens_table.setItem(row, 1, token_item)

                # Token Type
                token_type = self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0
                try:
                    enum_val = TokenType(token_type)
                    display_text = f"{enum_val.name} ({token_type})"
                except (ValueError, KeyError):
                    display_text = f"Unknown ({token_type})"

                type_item = QTableWidgetItem(display_text)
                type_item.setData(Qt.ItemDataRole.UserRole, token_type)

                # The type cell is not edited inline; a double-click opens
                # a selection dialog instead
                type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
                self.tokens_table.setItem(row, 2, type_item)

                # Score
                score = self.scores[orig_idx] if orig_idx < len(self.scores) else 0.0
                score_item = QTableWidgetItem(str(score))
                self.tokens_table.setItem(row, 3, score_item)
        finally:
            self._populating = False

    def handle_item_changed(self, item):
        """Write an inline edit of a token or score cell back to the data.

        Without this, text typed into the table would be lost because
        get_data() returns the lists, not the table contents.
        """
        if self._populating:
            return

        column = item.column()
        if column not in (1, 3):  # Only Token and Score are edited inline
            return

        index_item = self.tokens_table.item(item.row(), 0)
        if index_item is None:
            return
        orig_idx = index_item.data(Qt.ItemDataRole.UserRole)

        text = item.text()
        if column == 1:
            if orig_idx < len(self.tokens) and text != str(self.tokens[orig_idx]):
                self.tokens[orig_idx] = text
            return

        # Score column
        if orig_idx >= len(self.scores) or text == str(self.scores[orig_idx]):
            return
        try:
            self.scores[orig_idx] = float(text)
        except ValueError:
            # Not a number: put the stored score back into the cell
            self._populating = True
            try:
                item.setText(str(self.scores[orig_idx]))
            finally:
                self._populating = False

    def handle_cell_double_click(self, row, column):
        """Handle double-click on a cell, specifically for token type editing."""
        if column != 2:  # Only the Token Type column reacts
            return
        orig_item = self.tokens_table.item(row, 0)
        if orig_item:
            orig_idx = orig_item.data(Qt.ItemDataRole.UserRole)
            self.edit_token_type(row, orig_idx)

    def edit_token_type(self, row, orig_idx):
        """Edit a token type using a dialog with a dropdown of all enum options."""
        current_value = self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0

        # Create a dialog with enum options
        dialog = QDialog(self)
        dialog.setWindowTitle("Select Token Type")
        layout = QVBoxLayout(dialog)

        combo = QComboBox()
        for enum_val in TokenType:
            combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value)

        # Set current value
        try:
            if isinstance(current_value, int):
                enum_val = TokenType(current_value)
                combo.setCurrentText(f"{enum_val.name} ({current_value})")
        except (ValueError, KeyError):
            pass

        layout.addWidget(combo)

        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
        buttons.accepted.connect(dialog.accept)
        buttons.rejected.connect(dialog.reject)
        layout.addWidget(buttons)

        if dialog.exec() == QDialog.DialogCode.Accepted:
            # Get the selected value
            new_value = combo.currentData()
            enum_val = TokenType(new_value)
            display_text = f"{enum_val.name} ({new_value})"

            # Update the display
            type_item = self.tokens_table.item(row, 2)
            if type_item:
                type_item.setText(display_text)
                type_item.setData(Qt.ItemDataRole.UserRole, new_value)

            # Update the actual value
            self.token_types[orig_idx] = new_value

    def add_token(self):
        """Add a new token to the end of the list."""
        # Add to the end of the arrays
        self.tokens.append("")
        self.token_types.append(0)  # Default to normal token
        self.scores.append(0.0)

        orig_idx = len(self.tokens) - 1

        # The new token is an empty string, so it can only match when no
        # filter is active
        if not self.filter_edit.text():
            self.filtered_indices.append(orig_idx)

        # Update pagination
        self._recount_pages()

        # Go to the last page to show the new item
        self.current_page = self.total_pages - 1
        self._refresh_page()

    def remove_selected(self):
        """Remove selected tokens from all arrays."""
        selected_rows = {item.row() for item in self.tokens_table.selectedItems()}
        if not selected_rows:
            return

        # Get original indices in descending order to avoid index shifting
        orig_indices = []
        for row in selected_rows:
            orig_item = self.tokens_table.item(row, 0)
            if orig_item:
                orig_indices.append(orig_item.data(Qt.ItemDataRole.UserRole))
        orig_indices.sort(reverse=True)

        # Remove from all arrays
        for idx in orig_indices:
            if idx < len(self.tokens):
                del self.tokens[idx]
            if idx < len(self.token_types):
                del self.token_types[idx]
            if idx < len(self.scores):
                del self.scores[idx]

        # Positions shifted, so the filtered view has to be rebuilt
        self._rebuild_filtered_indices()

        # Update pagination
        self._recount_pages()
        self.current_page = min(self.current_page, self.total_pages - 1)
        self._refresh_page()

    def get_data(self):
        """Return the edited tokenizer data."""
        return self.tokens, self.token_types, self.scores
|
||
|
||
|
||
class ArrayEditorDialog(QDialog):
    """Dialog for editing the values of one array metadata field.

    Values are shown in a paginated, filterable table. When the key is
    listed in KEY_TO_ENUM_TYPE and the element type is INT32, values are
    displayed by enum name and edited through a dropdown; otherwise the
    value cell is edited inline. The dialog works on a copy of the given
    values; call get_array_values() after the dialog was accepted.
    """

    # Types treated as integers when resolving a value to an enum member
    _INT_TYPES = (int, numpy.signedinteger)

    def __init__(self, array_values, element_type, key=None, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Edit Array Values")
        self.resize(700, 500)

        # Copy, so that cancelling the dialog does not leave the caller's
        # list modified
        self.array_values = list(array_values)
        self.element_type = element_type
        self.key = key

        # True while the table is being filled or corrected by code, so that
        # handle_item_changed() only reacts to edits made by the user
        self._populating = False

        # Get enum type for this array if applicable
        self.enum_type = None
        if key in KEY_TO_ENUM_TYPE and element_type == GGUFValueType.INT32:
            self.enum_type = KEY_TO_ENUM_TYPE[key]

        layout = QVBoxLayout(self)

        # Add enum type information if applicable
        if self.enum_type is not None:
            enum_info_layout = QHBoxLayout()
            enum_label = QLabel(f"Editing {self.enum_type.__name__} values:")
            enum_info_layout.addWidget(enum_label)

            # Add a legend for the enum values
            enum_values = ", ".join([f"{e.name}={e.value}" for e in self.enum_type])
            enum_values_label = QLabel(f"Available values: {enum_values}")
            enum_values_label.setWordWrap(True)
            enum_info_layout.addWidget(enum_values_label, 1)

            layout.addLayout(enum_info_layout)

        # Add search/filter controls
        filter_layout = QHBoxLayout()
        filter_layout.addWidget(QLabel("Filter:"))
        self.filter_edit = QLineEdit()
        self.filter_edit.setPlaceholderText("Type to filter values...")
        self.filter_edit.textChanged.connect(self.apply_filter)
        filter_layout.addWidget(self.filter_edit)

        # Add page controls for large arrays
        self.page_size = 100  # Show 100 items per page
        self.current_page = 0
        self.total_pages = max(1, (len(self.array_values) + self.page_size - 1) // self.page_size)

        self.page_label = QLabel(f"Page 1 of {self.total_pages}")
        filter_layout.addWidget(self.page_label)

        prev_page = QPushButton("Previous")
        prev_page.clicked.connect(self.previous_page)
        filter_layout.addWidget(prev_page)

        next_page = QPushButton("Next")
        next_page.clicked.connect(self.next_page)
        filter_layout.addWidget(next_page)

        layout.addLayout(filter_layout)

        # Array items table
        self.items_table = QTableWidget()

        # Set up columns based on whether we have an enum type
        header = self.items_table.horizontalHeader()
        if self.enum_type is not None:
            self.items_table.setColumnCount(3)
            self.items_table.setHorizontalHeaderLabels(["Index", "Value", "Actions"])
            header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
            header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
            header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
        else:
            self.items_table.setColumnCount(2)
            self.items_table.setHorizontalHeaderLabels(["Index", "Value"])
            header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
            header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)

        # Inline edits of plain values are written back through this signal
        self.items_table.itemChanged.connect(self.handle_item_changed)

        layout.addWidget(self.items_table)

        # Controls
        controls_layout = QHBoxLayout()

        add_button = QPushButton("Add Item")
        add_button.clicked.connect(self.add_item)
        controls_layout.addWidget(add_button)

        remove_button = QPushButton("Remove Selected")
        remove_button.clicked.connect(self.remove_selected)
        controls_layout.addWidget(remove_button)

        # Add bulk edit button for enum arrays
        if self.enum_type is not None:
            bulk_edit_button = QPushButton("Bulk Edit Selected")
            bulk_edit_button.clicked.connect(self.bulk_edit_selected)
            controls_layout.addWidget(bulk_edit_button)

        controls_layout.addStretch()

        layout.addLayout(controls_layout)

        # Buttons
        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
        buttons.accepted.connect(self.accept)
        buttons.rejected.connect(self.reject)
        layout.addWidget(buttons)

        # Initialize the filtered values
        self.filtered_indices = list(range(len(self.array_values)))

        # Load array values for the first page
        self.load_page()

    def _search_text(self, value):
        """Return the lower-case text a filter is matched against for value.

        Enum arrays are searched by "NAME (value)"; everything else, and
        integers that are not valid enum members, by their plain string
        form.
        """
        if self.enum_type is not None and isinstance(value, self._INT_TYPES):
            try:
                return f"{self.enum_type(value).name} ({value})".lower()
            except (ValueError, KeyError):
                pass
        return str(value).lower()

    def _rebuild_filtered_indices(self):
        """Recompute filtered_indices from the current filter text."""
        filter_text = self.filter_edit.text().lower()
        if not filter_text:
            # No filter, show all values
            self.filtered_indices = list(range(len(self.array_values)))
        else:
            self.filtered_indices = [
                i for i, value in enumerate(self.array_values)
                if filter_text in self._search_text(value)
            ]

    def _recount_pages(self):
        """Update total_pages to match the number of filtered entries."""
        self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)

    def _refresh_page(self):
        """Update the page label and redraw the current page."""
        self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
        self.load_page()

    def _parse_value(self, text):
        """Convert cell text to a value matching element_type.

        Raises ValueError when the text is not a valid number for a numeric
        element type. No range check against the element width is done here.
        """
        if self.element_type == GGUFValueType.STRING:
            return text
        if self.element_type in (GGUFValueType.FLOAT32, GGUFValueType.FLOAT64):
            return float(text)
        if self.element_type == GGUFValueType.BOOL:
            return text.strip().lower() in ('true', 'yes', '1')
        return int(text)

    def apply_filter(self):
        """Filter the array values based on the search text."""
        self._rebuild_filtered_indices()

        # Reset to first page and reload
        self._recount_pages()
        self.current_page = 0
        self._refresh_page()

    def previous_page(self):
        """Go to the previous page of results."""
        if self.current_page > 0:
            self.current_page -= 1
            self._refresh_page()

    def next_page(self):
        """Go to the next page of results."""
        if self.current_page < self.total_pages - 1:
            self.current_page += 1
            self._refresh_page()

    def load_page(self):
        """Load the current page of array values."""
        # Filling the table emits itemChanged for every cell; flag it so the
        # change handler does not treat that as user input.
        self._populating = True
        try:
            self.items_table.setRowCount(0)  # Clear the table

            # Calculate start and end indices for the current page
            start_idx = self.current_page * self.page_size
            end_idx = min(start_idx + self.page_size, len(self.filtered_indices))

            # Pre-allocate rows for better performance
            self.items_table.setRowCount(end_idx - start_idx)

            for row, orig_idx in enumerate(self.filtered_indices[start_idx:end_idx]):
                value = self.array_values[orig_idx]

                # Index
                index_item = QTableWidgetItem(str(orig_idx))
                index_item.setData(Qt.ItemDataRole.UserRole, orig_idx)  # Store original index
                index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
                self.items_table.setItem(row, 0, index_item)

                # Value
                if self.enum_type is None:
                    self.items_table.setItem(row, 1, QTableWidgetItem(str(value)))
                    continue

                # Display enum value and name
                try:
                    if isinstance(value, self._INT_TYPES):
                        enum_val = self.enum_type(value)
                        display_text = f"{enum_val.name} ({value})"
                    else:
                        display_text = str(value)
                except (ValueError, KeyError):
                    display_text = f"Unknown ({value})"

                # Store the enum value in the item
                value_item = QTableWidgetItem(display_text)
                value_item.setData(Qt.ItemDataRole.UserRole, value)
                value_item.setFlags(value_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
                self.items_table.setItem(row, 1, value_item)

                # Add an edit button in a separate column; the row number is
                # stored on the button so the click handler can find it
                edit_button = QPushButton("Edit")
                edit_button.setProperty("row", row)
                edit_button.clicked.connect(self.edit_array_enum_value)

                # Create a widget to hold the button
                button_widget = QWidget()
                button_layout = QHBoxLayout(button_widget)
                button_layout.setContentsMargins(2, 2, 2, 2)
                button_layout.addWidget(edit_button)
                button_layout.addStretch()

                self.items_table.setCellWidget(row, 2, button_widget)
        finally:
            self._populating = False

    def handle_item_changed(self, item):
        """Write an inline edit of a plain value cell back to array_values.

        Enum arrays are skipped: their cells are read-only and are updated
        by edit_enum_value() and bulk_edit_selected().
        """
        if self._populating or self.enum_type is not None or item.column() != 1:
            return

        index_item = self.items_table.item(item.row(), 0)
        if index_item is None:
            return
        orig_idx = index_item.data(Qt.ItemDataRole.UserRole)

        text = item.text()
        if text == str(self.array_values[orig_idx]):
            return
        try:
            self.array_values[orig_idx] = self._parse_value(text)
        except ValueError:
            # Invalid input: put the stored value back into the cell
            self._populating = True
            try:
                item.setText(str(self.array_values[orig_idx]))
            finally:
                self._populating = False

    def edit_array_enum_value(self):
        """Handle editing an enum value in the array editor."""
        button = self.sender()
        row = button.property("row")

        # Get the original index from the table item
        orig_item = self.items_table.item(row, 0)
        new_item = self.items_table.item(row, 1)
        if orig_item and new_item and self.enum_type and self.edit_enum_value(row, self.enum_type):
            orig_idx = orig_item.data(Qt.ItemDataRole.UserRole)
            new_value = new_item.data(Qt.ItemDataRole.UserRole)
            # Update the stored value in the array
            if isinstance(new_value, (int, float, str, bool)):
                self.array_values[orig_idx] = new_value

    def bulk_edit_selected(self):
        """Edit multiple enum values at once."""
        if not self.enum_type:
            return

        selected_rows = {item.row() for item in self.items_table.selectedItems()}

        if not selected_rows:
            QMessageBox.information(self, "No Selection", "Please select at least one row to edit.")
            return

        # Create a dialog with enum options
        dialog = QDialog(self)
        dialog.setWindowTitle(f"Bulk Edit {self.enum_type.__name__} Values")
        layout = QVBoxLayout(dialog)

        layout.addWidget(QLabel(f"Set {len(selected_rows)} selected items to:"))

        combo = QComboBox()
        for enum_val in self.enum_type:
            combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value)

        layout.addWidget(combo)

        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
        buttons.accepted.connect(dialog.accept)
        buttons.rejected.connect(dialog.reject)
        layout.addWidget(buttons)

        if dialog.exec() != QDialog.DialogCode.Accepted:
            return

        # Get the selected value
        new_value = combo.currentData()
        enum_val = self.enum_type(new_value)
        display_text = f"{enum_val.name} ({new_value})"

        # Update all selected rows
        for row in selected_rows:
            orig_item = self.items_table.item(row, 0)
            new_item = self.items_table.item(row, 1)
            if orig_item and new_item:
                orig_idx = orig_item.data(Qt.ItemDataRole.UserRole)
                self.array_values[orig_idx] = new_value

                # Update the display
                new_item.setText(display_text)
                new_item.setData(Qt.ItemDataRole.UserRole, new_value)

    def add_item(self):
        """Append a default value to the array and show it on the last page."""
        # Add to the end of the array
        orig_idx = len(self.array_values)

        # Add default value based on type
        if self.enum_type is not None:
            # Default to first enum value
            self.array_values.append(list(self.enum_type)[0].value)
        elif self.element_type == GGUFValueType.STRING:
            self.array_values.append("")
        else:
            self.array_values.append(0)

        # The new item is always listed, even when it does not match the
        # active filter, so that it is visible right after being added
        self.filtered_indices.append(orig_idx)

        # Update pagination
        self._recount_pages()

        # Go to the last page to show the new item
        self.current_page = self.total_pages - 1
        self._refresh_page()

    def remove_selected(self):
        """Remove the selected rows from the array."""
        selected_rows = {item.row() for item in self.items_table.selectedItems()}
        if not selected_rows:
            return

        # Get original indices in descending order to avoid index shifting
        orig_indices = []
        for row in selected_rows:
            orig_item = self.items_table.item(row, 0)
            if orig_item:
                orig_indices.append(orig_item.data(Qt.ItemDataRole.UserRole))
        orig_indices.sort(reverse=True)

        # Remove from array_values
        for idx in orig_indices:
            del self.array_values[idx]

        # Positions shifted, so the filtered view has to be rebuilt
        self._rebuild_filtered_indices()

        # Update pagination
        self._recount_pages()
        self.current_page = min(self.current_page, self.total_pages - 1)
        self._refresh_page()

    def edit_enum_value(self, row: int, enum_type: Type[enum.Enum]):
        """Edit an enum value using a dialog with a dropdown of all enum options.

        Returns True when the user accepted a value (the table cell and
        array_values are updated), False otherwise.
        """
        # Get the original index from the table item
        orig_item = self.items_table.item(row, 0)
        if not orig_item:
            return False
        orig_idx = orig_item.data(Qt.ItemDataRole.UserRole)
        current_value = self.array_values[orig_idx]

        # Create a dialog with enum options
        dialog = QDialog(self)
        dialog.setWindowTitle(f"Select {enum_type.__name__} Value")
        layout = QVBoxLayout(dialog)

        # Add description
        description = QLabel(f"Select a {enum_type.__name__} value:")
        layout.addWidget(description)

        # Use a combo box for quick selection
        combo = QComboBox()
        for enum_val in enum_type:
            combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value)

        # Preselect the current value (numpy integers included, matching
        # what load_page() displays)
        try:
            if isinstance(current_value, self._INT_TYPES):
                enum_val = enum_type(current_value)
                combo.setCurrentText(f"{enum_val.name} ({current_value})")
        except (ValueError, KeyError):
            pass

        layout.addWidget(combo)

        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
        buttons.accepted.connect(dialog.accept)
        buttons.rejected.connect(dialog.reject)
        layout.addWidget(buttons)

        if dialog.exec() != QDialog.DialogCode.Accepted:
            return False

        # Update the value display and stored data
        new_value = combo.currentData()
        enum_val = enum_type(new_value)
        display_text = f"{enum_val.name} ({new_value})"

        new_item = self.items_table.item(row, 1)
        if new_item:
            new_item.setText(display_text)
            new_item.setData(Qt.ItemDataRole.UserRole, new_value)

        # Update the actual array value
        self.array_values[orig_idx] = new_value
        return True

    def get_array_values(self):
        """Return the edited array values."""
        # The array_values list is kept up-to-date as edits are made
        return self.array_values
|
||
|
||
|
||
class AddMetadataDialog(QDialog):
    """Dialog asking for the key, type and value of a new metadata entry.

    Every GGUFValueType except ARRAY can be selected. Call get_data() after
    the dialog was accepted.
    """

    # Numeric value types mapped to (numpy scalar type, text parser).
    # The 64-bit types are offered in the type dropdown as well, so they
    # must be converted too instead of being returned as raw text.
    _NUMERIC_TYPES = {
        GGUFValueType.UINT8: (np.uint8, int),
        GGUFValueType.INT8: (np.int8, int),
        GGUFValueType.UINT16: (np.uint16, int),
        GGUFValueType.INT16: (np.int16, int),
        GGUFValueType.UINT32: (np.uint32, int),
        GGUFValueType.INT32: (np.int32, int),
        GGUFValueType.FLOAT32: (np.float32, float),
        GGUFValueType.UINT64: (np.uint64, int),
        GGUFValueType.INT64: (np.int64, int),
        GGUFValueType.FLOAT64: (np.float64, float),
    }

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Add Metadata")
        self.resize(400, 200)

        layout = QVBoxLayout(self)

        form_layout = QFormLayout()

        self.key_edit = QLineEdit()
        form_layout.addRow("Key:", self.key_edit)

        self.type_combo = QComboBox()
        for value_type in GGUFValueType:
            if value_type != GGUFValueType.ARRAY:  # Skip array type for simplicity
                self.type_combo.addItem(value_type.name, value_type)
        form_layout.addRow("Type:", self.type_combo)

        self.value_edit = QTextEdit()
        form_layout.addRow("Value:", self.value_edit)

        layout.addLayout(form_layout)

        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
        buttons.accepted.connect(self.accept)
        buttons.rejected.connect(self.reject)
        layout.addWidget(buttons)

    def get_data(self) -> Tuple[str, GGUFValueType, Any]:
        """Return (key, value type, converted value) as entered by the user.

        Numeric types are converted to the matching numpy scalar, BOOL to a
        Python bool ('true', 'yes' and '1' count as True, case-insensitive),
        and everything else is returned as the entered text.

        Raises:
            ValueError: if the text is not a valid number for a numeric type.
        """
        key = self.key_edit.text()
        value_type = self.type_combo.currentData()
        value_text = self.value_edit.toPlainText()

        # Convert value based on type
        if value_type in self._NUMERIC_TYPES:
            numpy_type, parse = self._NUMERIC_TYPES[value_type]
            value = numpy_type(parse(value_text))
        elif value_type == GGUFValueType.BOOL:
            # Strip so that a trailing newline from the text box does not
            # turn "true" into False
            value = value_text.strip().lower() in ('true', 'yes', '1')
        else:
            # STRING and anything unrecognised keep the raw text
            value = value_text

        return key, value_type, value
|
||
|
||
|
||
class GGUFEditorWindow(QMainWindow):
|
||
def __init__(self):
    """Set up an empty editor window with no file loaded."""
    super().__init__()

    # Editing state: nothing is open and nothing is pending yet
    self.current_file = None
    self.reader = None
    self.modified = False
    self.metadata_changes = {}  # Store changes to apply when saving
    self.metadata_to_remove = set()  # Store keys to remove when saving
    self.on_metadata_changed_is_connected = False

    # Window appearance
    self.setWindowTitle("GGUF Editor")
    self.resize(1000, 800)

    self.setup_ui()
|
||
|
||
def setup_ui(self):
    """Build the window layout: file controls, metadata tab, tensors tab."""
    stretch = QHeaderView.ResizeMode.Stretch
    fit_contents = QHeaderView.ResizeMode.ResizeToContents

    root = QWidget()
    self.setCentralWidget(root)
    main_layout = QVBoxLayout(root)

    # --- File controls: read-only path display plus Open / Save As buttons
    file_row = QHBoxLayout()

    self.file_path_edit = QLineEdit()
    self.file_path_edit.setReadOnly(True)
    file_row.addWidget(self.file_path_edit)

    for caption, handler in (("Open GGUF", self.open_file), ("Save As...", self.save_file)):
        file_button = QPushButton(caption)
        file_button.clicked.connect(handler)
        file_row.addWidget(file_button)

    main_layout.addLayout(file_row)

    # --- Tabs for the different views
    self.tabs = QTabWidget()

    # Metadata tab: a four-column table and an "Add Metadata" button below it
    self.metadata_tab = QWidget()
    metadata_layout = QVBoxLayout(self.metadata_tab)

    self.metadata_table = QTableWidget()
    self.metadata_table.setColumnCount(4)
    self.metadata_table.setHorizontalHeaderLabels(["Key", "Type", "Value", "Actions"])
    metadata_header = self.metadata_table.horizontalHeader()
    for column, mode in enumerate((stretch, fit_contents, stretch, fit_contents)):
        metadata_header.setSectionResizeMode(column, mode)
    metadata_layout.addWidget(self.metadata_table)

    metadata_controls = QHBoxLayout()
    add_metadata_button = QPushButton("Add Metadata")
    add_metadata_button.clicked.connect(self.add_metadata)
    metadata_controls.addWidget(add_metadata_button)
    metadata_controls.addStretch()
    metadata_layout.addLayout(metadata_controls)

    # Tensors tab: a five-column table
    self.tensors_tab = QWidget()
    tensors_layout = QVBoxLayout(self.tensors_tab)

    self.tensors_table = QTableWidget()
    self.tensors_table.setColumnCount(5)
    self.tensors_table.setHorizontalHeaderLabels(["Name", "Type", "Shape", "Elements", "Size (bytes)"])
    tensors_header = self.tensors_table.horizontalHeader()
    for column, mode in enumerate((stretch, fit_contents, fit_contents, fit_contents, fit_contents)):
        tensors_header.setSectionResizeMode(column, mode)
    tensors_layout.addWidget(self.tensors_table)

    # Register both tabs and place the tab widget under the file controls
    self.tabs.addTab(self.metadata_tab, "Metadata")
    self.tabs.addTab(self.tensors_tab, "Tensors")
    main_layout.addWidget(self.tabs)

    # Status bar
    self.statusBar().showMessage("Ready")
|
||
|
||
def load_file(self, file_path):
    """Open the GGUF file at *file_path* and show its contents.

    Returns True when the file was loaded, False when any step failed (an
    error dialog is shown in that case).
    """
    status_bar = self.statusBar()
    try:
        status_bar.showMessage(f"Loading {file_path}...")
        # Let the status message repaint before the potentially slow load.
        QApplication.processEvents()

        self.reader = GGUFReader(file_path, 'r')
        self.current_file = file_path
        self.file_path_edit.setText(file_path)

        self.load_metadata()
        self.load_tensors()

        # A freshly loaded file starts without pending edits.
        self.modified = False
        self.metadata_to_remove = set()
        self.metadata_changes = {}

        status_bar.showMessage(f"Loaded {file_path}")
    except Exception as exc:
        QMessageBox.critical(self, "Error", f"Failed to open file: {str(exc)}")
        status_bar.showMessage("Error loading file")
        return False
    return True
|
||
|
||
def open_file(self):
    """Ask the user for a GGUF file and load it; do nothing when cancelled."""
    chosen_path, _selected_filter = QFileDialog.getOpenFileName(
        self, "Open GGUF File", "", "GGUF Files (*.gguf);;All Files (*)"
    )

    # An empty path means the dialog was dismissed.
    if chosen_path:
        self.load_file(chosen_path)
|
||
|
||
def load_metadata(self):
    """Repopulate the metadata table from the fields of the current reader.

    Each field becomes one row: key, type description, formatted value and
    an actions cell holding an optional "Edit" button and a "Remove" button.
    The itemChanged handler is disconnected while the rows are filled in and
    connected again afterwards.
    """
    table = self.metadata_table
    table.setRowCount(0)

    if not self.reader:
        return

    # Disconnect so that filling the table does not look like user edits.
    if self.on_metadata_changed_is_connected:
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore')
            table.itemChanged.disconnect(self.on_metadata_changed)
        self.on_metadata_changed_is_connected = False

    def add_action_button(layout, label, row, key, slot):
        # Every action button carries the row and key it was created for.
        button = QPushButton(label)
        button.setProperty("row", row)
        button.setProperty("key", key)
        button.clicked.connect(slot)
        layout.addWidget(button)
        return button

    for row, (key, field) in enumerate(self.reader.fields.items()):
        table.insertRow(row)

        field_types = field.types
        is_array = bool(field_types) and field_types[0] == GGUFValueType.ARRAY
        enum_type = self.get_enum_for_key(key)

        # Column 0: key (read-only)
        key_cell = QTableWidgetItem(key)
        key_cell.setFlags(key_cell.flags() & ~Qt.ItemFlag.ItemIsEditable)
        table.setItem(row, 0, key_cell)

        # Column 1: type (read-only); INT32 values of enum keys show the enum name
        if not field_types:
            type_text = "N/A"
        elif is_array:
            depth = len(field_types) - 1
            innermost = field_types[-1]
            element_name = innermost.name
            if enum_type is not None and innermost == GGUFValueType.INT32:
                element_name = enum_type.__name__
            type_text = '[' * depth + element_name + ']' * depth
        elif enum_type is not None and field_types[0] == GGUFValueType.INT32:
            type_text = enum_type.__name__
        else:
            type_text = str(field_types[0].name)

        type_cell = QTableWidgetItem(type_text)
        type_cell.setFlags(type_cell.flags() & ~Qt.ItemFlag.ItemIsEditable)
        table.setItem(row, 1, type_cell)

        # Column 2: value; only single non-array values can be edited in place
        value_cell = QTableWidgetItem(self.format_field_value(field))
        is_simple_value = len(field_types) == 1 and field_types[0] != GGUFValueType.ARRAY
        if is_simple_value:
            value_cell.setFlags(value_cell.flags() | Qt.ItemFlag.ItemIsEditable)
        else:
            value_cell.setFlags(value_cell.flags() & ~Qt.ItemFlag.ItemIsEditable)
        table.setItem(row, 2, value_cell)

        # Column 3: action buttons
        actions_widget = QWidget()
        actions_layout = QHBoxLayout(actions_widget)
        actions_layout.setContentsMargins(2, 2, 2, 2)

        if is_array:
            edit_button = add_action_button(actions_layout, "Edit", row, key, self.edit_array_metadata)
            # Linked tokenizer arrays are edited together in one dialog.
            if key in TOKENIZER_LINKED_KEYS:
                edit_button.setText("Edit Tokenizer")
                edit_button.setToolTip("Edit all tokenizer data together")
        elif len(field_types) == 1 and enum_type is not None:
            add_action_button(actions_layout, "Edit", row, key, self.edit_metadata_enum)

        add_action_button(actions_layout, "Remove", row, key, self.remove_metadata)

        table.setCellWidget(row, 3, actions_widget)

    # Reconnect now that the table holds the loaded values.
    table.itemChanged.connect(self.on_metadata_changed)
    self.on_metadata_changed_is_connected = True
|
||
|
||
def extract_array_values(self, field: ReaderField) -> list:
    """Extract all values from an array field.

    Returns an empty list when *field* is not an array, when it carries no
    element type (only the ARRAY entry in ``field.types``, which presumably
    happens for empty arrays - confirm against the reader), or when the
    element type is neither STRING nor a scalar known to the reader.

    String elements are decoded as UTF-8; scalar elements are returned as
    stored in ``field.parts``.
    """
    types = field.types
    # Guard the types[1] access below: without an element type entry there
    # is nothing to extract, and indexing would raise IndexError.
    if not types or types[0] != GGUFValueType.ARRAY or len(types) < 2:
        return []

    element_type = types[1]
    parts = field.parts
    total_elements = len(field.data)

    if element_type == GGUFValueType.STRING:
        # Elements are addressed from the end of `parts`, two parts per
        # string; the last part of each pair holds the text bytes.
        return [
            str(bytes(parts[-1 - (total_elements - pos - 1) * 2]), encoding='utf-8')
            for pos in range(total_elements)
        ]

    if self.reader and element_type in self.reader.gguf_scalar_to_np:
        # Scalars occupy one part each, again addressed from the end.
        return [
            parts[-1 - (total_elements - pos - 1)][0]
            for pos in range(total_elements)
        ]

    return []
|
||
|
||
def get_enum_for_key(self, key: str) -> Optional[Type[enum.Enum]]:
    """Return the enum class registered for *key*, or None if there is none."""
    if key in KEY_TO_ENUM_TYPE:
        return KEY_TO_ENUM_TYPE[key]
    return None
|
||
|
||
def format_enum_value(self, value: Any, enum_type: Type[enum.Enum]) -> str:
    """Format *value* as ``"NAME (value)"`` when it is a member of *enum_type*.

    Python ints, NumPy integer scalars and strings are looked up in the
    enum. Anything else, and any value that is not a valid member, is
    returned as ``str(value)``.
    """
    if isinstance(value, np.integer):
        # NumPy integer scalars are not instances of int, so without this
        # conversion they would never be formatted as enum members.
        value = int(value)

    if isinstance(value, (int, str)):
        try:
            member = enum_type(value)
        except (ValueError, KeyError):
            pass
        else:
            return f"{member.name} ({value})"

    return str(value)
|
||
|
||
def format_field_value(self, field: ReaderField) -> str:
    """Render the value of *field* as the text shown in the value column.

    Returns "N/A" for a field without types, the decoded text for a string,
    the (enum-formatted where applicable) value for a scalar, a preview of
    at most five elements for an array and "Complex value" otherwise.
    """
    types = field.types
    if not types:
        return "N/A"

    head_type = types[0]

    if len(types) == 1:
        if head_type == GGUFValueType.STRING:
            return str(bytes(field.parts[-1]), encoding='utf-8')
        if self.reader and head_type in self.reader.gguf_scalar_to_np:
            scalar = field.parts[-1][0]
            # Keys with a registered enum are shown by member name.
            scalar_enum = self.get_enum_for_key(field.name)
            if scalar_enum is None:
                return str(scalar)
            return self.format_enum_value(scalar, scalar_enum)

    if head_type != GGUFValueType.ARRAY:
        return "Complex value"

    array_values = self.extract_array_values(field)
    shown = array_values[:min(5, len(array_values))]

    # Arrays of enum keys are rendered element by element as enum members.
    array_enum = self.get_enum_for_key(field.name)
    if array_enum is not None:
        rendered = [self.format_enum_value(element, array_enum) for element in shown]
    else:
        rendered = [str(element) for element in shown]

    more_marker = ', ...' if len(array_values) > len(rendered) else ''
    return f"[ {', '.join(rendered).strip()}{more_marker} ]"
|
||
|
||
def load_tensors(self):
    """Repopulate the tensors table with one read-only row per tensor."""
    table = self.tensors_table
    table.setRowCount(0)

    if not self.reader:
        return

    for row, tensor in enumerate(self.reader.tensors):
        table.insertRow(row)

        # Column order: name, type, shape, element count, size in bytes
        cell_texts = (
            tensor.name,
            tensor.tensor_type.name,
            " × ".join(str(d) for d in tensor.shape),
            str(tensor.n_elements),
            f"{tensor.n_bytes:,}",
        )

        for column, text in enumerate(cell_texts):
            cell = QTableWidgetItem(text)
            cell.setFlags(cell.flags() & ~Qt.ItemFlag.ItemIsEditable)
            table.setItem(row, column, cell)
|
||
|
||
def on_metadata_changed(self, item):
    """Validate an edited value cell and record it as a pending change.

    Only the value column (index 2) is handled. The value type comes from
    the pending change for the key when there is one (this covers keys added
    in this session, which have no field in the reader), otherwise from the
    field in the reader.

    Enum keys stored as INT32 accept a member name, a number or the
    displayed "NAME (value)" form. Other scalar types are parsed from the
    text; integers must fit into their type. On invalid input a warning is
    shown and the cell is reset to the value from the file (or to the last
    accepted value for a key that only exists as a pending addition).
    """
    if item.column() != 2:  # Only handle value column changes
        return

    key_item = self.metadata_table.item(item.row(), 0)
    key = key_item.text() if key_item else None
    if not key:
        return
    new_value = item.text()

    field = self.reader.get_field(key) if self.reader else None
    if field is not None and not field.types:
        return

    pending = self.metadata_changes.get(key)
    if pending is not None:
        value_type = pending[0]
    elif field is not None:
        value_type = field.types[0]
    else:
        # Neither a stored field nor a pending addition: nothing to update.
        return

    def revert_cell():
        # Put back the last known-good text after rejected input.
        if field is not None:
            item.setText(self.format_field_value(field))
        else:
            item.setText(str(pending[1]))

    # Check if this is an enum field
    enum_type = self.get_enum_for_key(key)
    if enum_type is not None and value_type == GGUFValueType.INT32:
        try:
            try:
                # A bare member name, e.g. "NORMAL"
                converted_value = enum_type[new_value].value
            except (KeyError, AttributeError):
                if '(' in new_value and ')' in new_value:
                    # "NAME (value)" format: take the number in parentheses
                    value_part = new_value.split('(')[1].split(')')[0].strip()
                    converted_value = int(value_part)
                else:
                    converted_value = int(new_value)

            # Validate that it's a valid enum value
            enum_type(converted_value)
        except (ValueError, KeyError) as e:
            QMessageBox.warning(
                self,
                f"Invalid Enum Value ({e})",
                f"'{new_value}' is not a valid {enum_type.__name__} value.\n"
                f"Valid values are: {', '.join(v.name for v in enum_type)}")
            revert_cell()
            return

        self.metadata_changes[key] = (value_type, converted_value)
        self.modified = True

        # Update display with formatted enum value
        formatted_value = self.format_enum_value(converted_value, enum_type)
        item.setText(formatted_value)

        self.statusBar().showMessage(f"Changed {key} to {formatted_value}")
        return

    integer_types = {
        GGUFValueType.UINT8: np.uint8,
        GGUFValueType.INT8: np.int8,
        GGUFValueType.UINT16: np.uint16,
        GGUFValueType.INT16: np.int16,
        GGUFValueType.UINT32: np.uint32,
        GGUFValueType.INT32: np.int32,
        GGUFValueType.UINT64: np.uint64,
        GGUFValueType.INT64: np.int64,
    }
    float_types = {
        GGUFValueType.FLOAT32: np.float32,
        GGUFValueType.FLOAT64: np.float64,
    }

    try:
        # Convert the string value to the appropriate type
        if value_type in integer_types:
            np_type = integer_types[value_type]
            parsed = int(new_value)
            bounds = np.iinfo(np_type)
            # Reject out-of-range input instead of wrapping or raising an
            # uncaught OverflowError (behaviour depends on the NumPy version).
            if not bounds.min <= parsed <= bounds.max:
                raise ValueError(f"{parsed} is out of range for {value_type.name}")
            converted_value = np_type(parsed)
        elif value_type in float_types:
            converted_value = float_types[value_type](float(new_value))
        elif value_type == GGUFValueType.BOOL:
            converted_value = new_value.lower() in ('true', 'yes', '1')
        elif value_type == GGUFValueType.STRING:
            converted_value = new_value
        else:
            # Unsupported type for editing
            return
    except ValueError:
        QMessageBox.warning(self, "Invalid Value", f"The value '{new_value}' is not valid for type {value_type.name}")
        revert_cell()
        return

    # Store the change
    self.metadata_changes[key] = (value_type, converted_value)
    self.modified = True

    self.statusBar().showMessage(f"Changed {key} to {new_value}")
|
||
|
||
def remove_metadata(self):
    """Ask for confirmation, then mark the clicked row's key for removal.

    The key is read from the "key" property of the button that sent the
    signal. The row is found by searching for that key in the table: the
    "row" property stored on the button is only correct until another row
    above it has been removed.
    """
    button = self.sender()
    key = button.property("key")

    reply = QMessageBox.question(
        self, "Confirm Removal",
        f"Are you sure you want to remove the metadata key '{key}'?",
        QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.No
    )
    if reply != QMessageBox.StandardButton.Yes:
        return

    # Locate the row that currently shows this key.
    for row in range(self.metadata_table.rowCount()):
        key_item = self.metadata_table.item(row, 0)
        if key_item and key_item.text() == key:
            self.metadata_table.removeRow(row)
            break

    self.metadata_to_remove.add(key)

    # If we previously had changes for this key, remove them
    self.metadata_changes.pop(key, None)

    self.modified = True
    self.statusBar().showMessage(f"Marked {key} for removal")
|
||
|
||
def edit_metadata_enum(self):
    """Edit an enum metadata field through a combo-box dialog.

    The key is read from the "key" property of the button that sent the
    signal. The dialog starts on the pending value for the key when there is
    one, otherwise on the value stored in the file. On acceptance the choice
    is recorded as a pending change and the value cell is updated. The cell
    is found by key, because the "row" property stored on the button goes
    stale once rows above it are removed.
    """
    button = self.sender()
    key = button.property("key")

    field = self.reader.get_field(key) if self.reader else None
    if not field or not field.types:
        return

    enum_type = self.get_enum_for_key(key)
    if enum_type is None:
        return

    # Prefer an edit made earlier in this session over the value on disk.
    pending = self.metadata_changes.get(key)
    current_value = pending[1] if pending is not None else field.contents()
    if isinstance(current_value, np.integer):
        # NumPy integer scalars are not int instances; normalise for lookup.
        current_value = int(current_value)

    # Create a dialog with enum options
    dialog = QDialog(self)
    dialog.setWindowTitle(f"Select {enum_type.__name__} Value")
    layout = QVBoxLayout(dialog)

    combo = QComboBox()
    for enum_val in enum_type:
        combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value)

    # Preselect the current value when it is a valid member
    try:
        if isinstance(current_value, (int, str)):
            enum_val = enum_type(current_value)
            combo.setCurrentText(f"{enum_val.name} ({current_value})")
    except (ValueError, KeyError):
        pass

    layout.addWidget(combo)

    buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
    buttons.accepted.connect(dialog.accept)
    buttons.rejected.connect(dialog.reject)
    layout.addWidget(buttons)

    if dialog.exec() != QDialog.DialogCode.Accepted:
        return

    # Get the selected value
    new_value = combo.currentData()
    enum_val = enum_type(new_value)

    # Store the change
    self.metadata_changes[key] = (field.types[0], new_value)
    self.modified = True

    # Update the value cell of the row that currently shows this key
    display_text = f"{enum_val.name} ({new_value})"
    for row in range(self.metadata_table.rowCount()):
        key_item = self.metadata_table.item(row, 0)
        if key_item and key_item.text() == key:
            target_item = self.metadata_table.item(row, 2)
            if target_item:
                target_item.setText(display_text)
            break

    self.statusBar().showMessage(f"Changed {key} to {display_text}")
|
||
|
||
def edit_array_metadata(self):
    """Edit an array metadata field in the array editor dialog.

    The key is read from the "key" property of the button that sent the
    signal. Linked tokenizer keys are delegated to the tokenizer editor.
    For other arrays the editor starts from the pending values for the key
    when there are any, otherwise from the values in the file. On acceptance
    the new values are recorded as a pending change and the value cell is
    updated. The cell is found by key, because the "row" property stored on
    the button goes stale once rows above it are removed.
    """
    button = self.sender()
    key = button.property("key")

    # Check if this is one of the linked tokenizer keys
    if key in TOKENIZER_LINKED_KEYS:
        self.edit_tokenizer_metadata(key)
        return

    field = self.reader.get_field(key) if self.reader else None
    if not field or not field.types or field.types[0] != GGUFValueType.ARRAY:
        return

    if len(field.types) < 2:
        # No element type entry (presumably an empty array - confirm against
        # the reader); indexing types[1] below would raise IndexError.
        QMessageBox.warning(
            self, "Cannot Edit Array",
            f"The element type of '{key}' could not be determined")
        return

    # Get array element type
    element_type = field.types[1]

    # Start from pending edits if the array was already changed this session
    if key in self.metadata_changes:
        _, (_, array_values) = self.metadata_changes[key]
    else:
        array_values = self.extract_array_values(field)

    # Open array editor dialog
    dialog = ArrayEditorDialog(array_values, element_type, key, self)
    if dialog.exec() != QDialog.DialogCode.Accepted:
        return

    new_values = dialog.get_array_values()

    # Store the change
    self.metadata_changes[key] = (GGUFValueType.ARRAY, (element_type, new_values))
    self.modified = True

    # Build the preview text: at most five elements, enum-formatted if needed
    enum_type = self.get_enum_for_key(key)
    if enum_type is not None and element_type == GGUFValueType.INT32:
        preview = ', '.join(self.format_enum_value(v, enum_type) for v in new_values[:5])
    else:
        preview = ', '.join(str(v) for v in new_values[:5])
    suffix = ', ...' if len(new_values) > 5 else ''
    value_str = f"[ {preview}{suffix} ]"

    # Update the value cell of the row that currently shows this key
    for row in range(self.metadata_table.rowCount()):
        key_item = self.metadata_table.item(row, 0)
        if key_item and key_item.text() == key:
            target_item = self.metadata_table.item(row, 2)
            if target_item:
                target_item.setText(value_str)
            break

    self.statusBar().showMessage(f"Updated array values for {key}")
|
||
|
||
def edit_tokenizer_metadata(self, trigger_key):
    """Edit the token list, token types and scores together in one dialog.

    *trigger_key* is the key whose button opened the editor; it is not used
    to decide what is edited. Pending changes take precedence over the
    values stored in the file. On acceptance a change is recorded for each
    of the three arrays that exists in the file and all three value cells
    are refreshed.
    """
    if not self.reader:
        return

    linked_keys = (
        gguf.Keys.Tokenizer.LIST,
        gguf.Keys.Tokenizer.TOKEN_TYPE,
        gguf.Keys.Tokenizer.SCORES,
    )
    linked_fields = [self.reader.get_field(linked_key) for linked_key in linked_keys]

    # Current values per array: from the file, overridden by pending changes
    current_values = []
    for linked_key, linked_field in zip(linked_keys, linked_fields):
        values = self.extract_array_values(linked_field) if linked_field else []
        if linked_key in self.metadata_changes:
            _, (_, values) = self.metadata_changes[linked_key]
        current_values.append(values)
    tokens, token_types, scores = current_values

    # Open the tokenizer editor dialog
    dialog = TokenizerEditorDialog(tokens, token_types, scores, self)
    if dialog.exec() != QDialog.DialogCode.Accepted:
        return

    new_tokens, new_token_types, new_scores = dialog.get_data()
    edited_values = (new_tokens, new_token_types, new_scores)

    # Store a change for every array that exists in the file
    for linked_key, linked_field, new_values in zip(linked_keys, linked_fields, edited_values):
        if linked_field:
            self.metadata_changes[linked_key] = (
                GGUFValueType.ARRAY,
                (linked_field.types[1], new_values)
            )

    self.modified = True

    # Update display for all three fields
    for linked_key, new_values in zip(linked_keys, edited_values):
        self.update_tokenizer_display(linked_key, new_values)

    self.statusBar().showMessage("Updated tokenizer data")
|
||
|
||
def update_tokenizer_display(self, key, values):
    """Show a preview of *values* in the value cell of the row for *key*.

    The preview lists at most five elements, followed by ", ..." when there
    are more. Nothing happens when no row shows *key*.
    """
    table = self.metadata_table
    for row in range(table.rowCount()):
        key_item = table.item(row, 0)
        if not key_item or key_item.text() != key:
            continue

        preview = ', '.join(str(v) for v in values[:5])
        more_marker = ', ...' if len(values) > 5 else ''
        value_item = table.item(row, 2)
        if value_item:
            value_item.setText(f"[ {preview}{more_marker} ]")
        # Only the first matching row is updated.
        break
|
||
|
||
def add_metadata(self):
    """Ask for a new key/type/value and add it as a pending metadata entry.

    The entry is rejected with a warning when the value cannot be converted
    to the selected type, when the key is empty or when the key is already
    shown in the table. Otherwise a row is appended and the value is stored
    in the pending changes. A key that was marked for removal earlier is
    unmarked, so the re-added value is written on save.
    """
    dialog = AddMetadataDialog(self)
    if dialog.exec() != QDialog.DialogCode.Accepted:
        return

    try:
        key, value_type, value = dialog.get_data()
    except (ValueError, OverflowError) as exc:
        # The dialog converts the text to the selected type and raises on
        # input it cannot parse; report it instead of failing in the slot.
        QMessageBox.warning(self, "Invalid Value", f"The value is not valid for the selected type: {exc}")
        return

    if not key:
        QMessageBox.warning(self, "Invalid Key", "Key cannot be empty")
        return

    # Check if key already exists
    for row in range(self.metadata_table.rowCount()):
        orig_item = self.metadata_table.item(row, 0)
        if orig_item and orig_item.text() == key:
            QMessageBox.warning(self, "Duplicate Key", f"Key '{key}' already exists")
            return

    # Building the row is not a user edit: keep itemChanged from reaching
    # on_metadata_changed while the cells are inserted.
    was_blocked = self.metadata_table.blockSignals(True)
    try:
        row = self.metadata_table.rowCount()
        self.metadata_table.insertRow(row)

        # Key
        key_item = QTableWidgetItem(key)
        key_item.setFlags(key_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
        self.metadata_table.setItem(row, 0, key_item)

        # Type
        type_item = QTableWidgetItem(value_type.name)
        type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
        self.metadata_table.setItem(row, 1, type_item)

        # Value
        value_item = QTableWidgetItem(str(value))
        value_item.setFlags(value_item.flags() | Qt.ItemFlag.ItemIsEditable)
        self.metadata_table.setItem(row, 2, value_item)

        # Actions
        actions_widget = QWidget()
        actions_layout = QHBoxLayout(actions_widget)
        actions_layout.setContentsMargins(2, 2, 2, 2)

        remove_button = QPushButton("Remove")
        remove_button.setProperty("row", row)
        remove_button.setProperty("key", key)
        remove_button.clicked.connect(self.remove_metadata)
        actions_layout.addWidget(remove_button)

        self.metadata_table.setCellWidget(row, 3, actions_widget)
    finally:
        self.metadata_table.blockSignals(was_blocked)

    # Store the change. A key removed earlier in this session must no longer
    # be skipped on save now that it has a value again.
    self.metadata_to_remove.discard(key)
    self.metadata_changes[key] = (value_type, value)
    self.modified = True

    self.statusBar().showMessage(f"Added new metadata key {key}")
|
||
|
||
def save_file(self):
    """Write the open file plus all pending metadata edits to a new file.

    Asks for a destination, then copies every metadata field (applying
    changes and skipping removed keys), appends newly added keys and copies
    all tensors. Saving over the currently open file is refused, because
    tensor data is still read from that file while the new one is written.
    The writer is closed even when writing fails. After a successful save
    the user may choose to open the new file.
    """
    if not self.reader:
        QMessageBox.warning(self, "No File Open", "Please open a GGUF file first")
        return

    if not self.modified and not self.metadata_changes and not self.metadata_to_remove:
        QMessageBox.information(self, "No Changes", "No changes to save")
        return

    file_path, _ = QFileDialog.getSaveFileName(
        self, "Save GGUF File As", "", "GGUF Files (*.gguf);;All Files (*)"
    )

    if not file_path:
        return

    # Opening the destination for writing would truncate the source while
    # the reader still serves tensor data from it.
    if self.current_file and os.path.exists(file_path):
        try:
            overwrites_source = os.path.samefile(file_path, self.current_file)
        except OSError:
            overwrites_source = False
        if overwrites_source:
            QMessageBox.warning(
                self, "Invalid Destination",
                "Cannot save over the file that is currently open. Please choose a different file name.")
            return

    try:
        self.statusBar().showMessage(f"Saving to {file_path}...")
        QApplication.processEvents()

        # Get architecture and endianness from the original file
        arch = 'unknown'
        field = self.reader.get_field(gguf.Keys.General.ARCHITECTURE)
        if field:
            arch = field.contents()

        # Create writer
        writer = GGUFWriter(file_path, arch=arch, endianess=self.reader.endianess)
        try:
            # Get alignment if present
            alignment = None
            field = self.reader.get_field(gguf.Keys.General.ALIGNMENT)
            if field:
                alignment = field.contents()
            if alignment is not None:
                writer.data_alignment = alignment

            # Copy metadata with changes
            for field in self.reader.fields.values():
                # Skip virtual fields and fields written by GGUFWriter
                if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
                    continue

                # Skip fields marked for removal
                if field.name in self.metadata_to_remove:
                    continue

                sub_type = None
                if field.name in self.metadata_changes:
                    # Apply the pending change instead of the stored value
                    value_type, value = self.metadata_changes[field.name]
                    if value_type == GGUFValueType.ARRAY:
                        # Array changes are stored as (element type, values)
                        sub_type, value = value
                else:
                    # Copy original value
                    value = field.contents()
                    value_type = field.types[0]
                    if value_type == GGUFValueType.ARRAY:
                        sub_type = field.types[-1]

                if value is not None:
                    writer.add_key_value(field.name, value, value_type, sub_type=sub_type)

            # Add new metadata
            for key, (value_type, value) in self.metadata_changes.items():
                # Skip if the key already existed (we handled it above)
                if self.reader.get_field(key) is not None:
                    continue

                sub_type = None
                if value_type == GGUFValueType.ARRAY:
                    # Array changes are stored as (element type, values)
                    sub_type, value = value

                writer.add_key_value(key, value, value_type, sub_type=sub_type)

            # Add tensors (including data). The tensor data is passed on in
            # the reader's endianness, declared via tensor_endianess.
            for tensor in self.reader.tensors:
                writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type, tensor_endianess=self.reader.endianess)

            # Write header and metadata
            writer.open_output_file(Path(file_path))
            writer.write_header_to_file()
            writer.write_kv_data_to_file()

            # Write tensor data using the optimized method
            writer.write_tensors_to_file(progress=False)
        finally:
            # Release the output file on success and on failure alike.
            writer.close()

        self.statusBar().showMessage(f"Saved to {file_path}")

        # Ask if user wants to open the new file
        reply = QMessageBox.question(
            self, "Open Saved File",
            "Would you like to open the newly saved file?",
            QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.Yes
        )

        if reply == QMessageBox.StandardButton.Yes:
            self.reader = GGUFReader(file_path, 'r')
            self.current_file = file_path
            self.file_path_edit.setText(file_path)

            self.load_metadata()
            self.load_tensors()

            self.metadata_changes = {}
            self.metadata_to_remove = set()
            self.modified = False

    except Exception as e:
        QMessageBox.critical(self, "Error", f"Failed to save file: {str(e)}")
        self.statusBar().showMessage("Error saving file")
|
||
|
||
|
||
def main() -> None:
    """Parse the command line, show the editor window and run the Qt loop.

    When a model path is given it is loaded at startup if it is an existing
    file whose name ends in '.gguf'; otherwise an error is logged and a
    warning dialog is shown. The process exits with the event loop's result.
    """
    parser = argparse.ArgumentParser(description="GUI GGUF Editor")
    parser.add_argument("model_path", nargs="?", help="path to GGUF model file to load at startup")
    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
    args = parser.parse_args()

    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(level=log_level)

    app = QApplication(sys.argv)
    window = GGUFEditorWindow()
    window.show()

    # Load model if specified
    if args.model_path:
        is_gguf_file = os.path.isfile(args.model_path) and args.model_path.endswith('.gguf')
        if is_gguf_file:
            window.load_file(args.model_path)
        else:
            logger.error(f"Invalid model path: {args.model_path}")
            QMessageBox.warning(
                window,
                "Invalid Model Path",
                f"The specified file does not exist or is not a GGUF file: {args.model_path}")

    exit_code = app.exec()
    sys.exit(exit_code)
|
||
|
||
|
||
# Start the GUI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|