Fine-tuning text-to-SQL with the JAX ecosystem - Part 1
Introduction
In this series, I will take a small LLM and fine-tune it to answer user requests by generating the right SQL query.
I will use the Spider 1 dataset, which has a leaderboard and benchmarks for comparing the effectiveness of each method. It has since been
replaced by Spider 2, but I will use the older version for this example because it is a smaller and more established dataset.

Steps
Here is the current plan for this series:
- Show the data structure and load the data using Grain
- Prompting + tokenization + zero-shot generation
- LoRA fine-tuning with Tunix
- SQL fine tune evaluation with examples
- Execute generated SQL on SQLite and compare results
Spider dataset
This dataset comes from the 2018 paper "Spider: A Large-Scale Human-Labeled Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task".
Here is what a train/test sample in the Spider dataset looks like. We've shortened the original JSON structure to highlight the most relevant information for human readability.
{
  "db_id": "geo",
  "query": "SELECT city_name FROM city WHERE population = ( SELECT MAX ( population ) FROM city WHERE state_name = \"wyoming\" ) AND state_name = \"wyoming\";",
  "question": "what is the biggest city in wyoming"
}
Explanation of Shortened Parts:
We've focused on the db_id, query, and question fields, as these are the most directly interpretable for understanding the dataset's core purpose: mapping natural language questions to SQL queries for a specific database.
The following fields were omitted for brevity:
- query_toks: This field contains the tokenized version of the SQL query. While useful for machine processing, it's redundant for human understanding when the query field is present.
- query_toks_no_value: Similar to query_toks, but with values replaced by a generic "value" token. This is primarily for model training and not for human review.
- question_toks: The tokenized version of the natural language question. Again, the question field itself is sufficient for human comprehension.
- sql: This is a highly detailed, nested JSON representation of the SQL query's abstract syntax tree (AST). While crucial for semantic parsing tasks, it's overly complex for a quick human overview of the data samples.
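As a small illustration of the shortening described above, here is a sketch that keeps only the three readable fields of a record. The function name `shorten_record` is my own, and the token lists and `sql` value in the demo record are abbreviated placeholders, not values copied from the dataset.

```python
import json

# Fields kept for human readability; everything else is dropped.
KEEP_FIELDS = ("db_id", "query", "question")


def shorten_record(raw, keep_fields=KEEP_FIELDS):
    """Return a copy of a Spider record containing only the kept fields."""
    return {key: raw[key] for key in keep_fields if key in raw}


# Demo record: the first three fields match the sample shown earlier,
# the remaining fields are abbreviated placeholders for illustration only.
raw = {
    "db_id": "geo",
    "query": (
        "SELECT city_name FROM city WHERE population = "
        '( SELECT MAX ( population ) FROM city WHERE state_name = "wyoming" ) '
        'AND state_name = "wyoming";'
    ),
    "question": "what is the biggest city in wyoming",
    "query_toks": ["SELECT", "city_name", "..."],
    "query_toks_no_value": ["select", "city_name", "..."],
    "question_toks": ["what", "is", "..."],
    "sql": {},
}

short = shorten_record(raw)
print(json.dumps(short, indent=2))
```

The same dictionary comprehension is what the data source in the next section uses to drop the unused fields.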
Grain
Grain is a data loading library in the JAX ecosystem. I wrote this simple data source to read the Spider JSON files, keep only the fields I need, and attach the database schema definition for each record.
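import os


class JsonDataSource:
    """Random-access source over Spider JSON records, with the schema attached."""

    def __init__(self, json_paths, keep_fields=("db_id", "query", "question")):
        # load_json_files is a helper that is not shown in this post.
        self.records = load_json_files(json_paths)
        self.keep_fields = keep_fields
        self.db_records = {}  # cache: schema path -> CREATE TABLE text

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        raw = self.records[index]
        record = {
            key: raw[key]
            for key in self.keep_fields
            if key in raw
        }
        # Assumes the SPIDER_PATH environment variable points at the Spider root.
        file_path = f'{os.environ["SPIDER_PATH"]}/database/{record["db_id"]}/schema.sql'
        if file_path not in self.db_records:
            # get_create_table_blocks is a helper that is not shown in this post.
            record["db_definitions"] = "\n".join(get_create_table_blocks(file_path))
            self.db_records[file_path] = record["db_definitions"]
        else:
            record["db_definitions"] = self.db_records[file_path]
        return record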
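The class above calls two helpers, `load_json_files` and `get_create_table_blocks`, whose bodies are not shown here. The following is only a guess at what they could look like, not the original implementation: it assumes each JSON file holds a list of records and that every `CREATE TABLE` statement in `schema.sql` ends at the first semicolon. `extract_create_table_blocks` is a hypothetical extra function I added so the parsing can be tried on a plain string.

```python
import json
import re
import tempfile
from pathlib import Path

# Matches one CREATE TABLE statement up to its first semicolon (assumption:
# no semicolons appear inside the column definitions).
CREATE_TABLE_RE = re.compile(r"CREATE\s+TABLE\b.*?;", re.IGNORECASE | re.DOTALL)


def load_json_files(json_paths):
    """Concatenate the records of several JSON files, each holding a list."""
    records = []
    for path in json_paths:
        with open(path, encoding="utf-8") as handle:
            records.extend(json.load(handle))
    return records


def extract_create_table_blocks(sql_text):
    """Return every CREATE TABLE statement found in a schema dump."""
    return [m.group(0).strip() for m in CREATE_TABLE_RE.finditer(sql_text)]


def get_create_table_blocks(file_path):
    """Read a schema.sql file and return its CREATE TABLE statements."""
    text = Path(file_path).read_text(encoding="utf-8", errors="replace")
    return extract_create_table_blocks(text)


# Small demo on temporary files; the schema below is made up for illustration.
demo_schema = """
PRAGMA foreign_keys = ON;
CREATE TABLE "stadium" (
  "Stadium_ID" int,
  PRIMARY KEY ("Stadium_ID")
);
INSERT INTO "stadium" VALUES (1);
CREATE TABLE "singer" (
  "Singer_ID" int,
  PRIMARY KEY ("Singer_ID")
);
"""

with tempfile.TemporaryDirectory() as tmp:
    schema_path = Path(tmp) / "schema.sql"
    schema_path.write_text(demo_schema, encoding="utf-8")
    blocks = get_create_table_blocks(schema_path)

    json_path = Path(tmp) / "dev.json"
    json_path.write_text(json.dumps([{"db_id": "concert_singer"}]), encoding="utf-8")
    records = load_json_files([json_path, json_path])

print(len(blocks), len(records))
```

Keeping only the `CREATE TABLE` statements drops the `INSERT` rows and pragmas, which keeps the schema text short enough to place in a prompt later.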
After defining the data source, I can pass it to Grain and create a loader. For now, I use the development split, disable shuffling, and set batch_size=1 so it is easy to inspect one example at a time.
dev_source = JsonDataSource([
    base_path + "dev.json",
])
dev_loader = grain.load(
    dev_source,
    num_epochs=1,
    shuffle=False,
    batch_size=1,
    worker_count=0,
)
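A loader like this reads from the source through `__len__` and `__getitem__`, so it can be useful to sanity-check those two methods before handing the source over. The sketch below is my own addition and does not use Grain: `InMemorySource` is a hypothetical stand-in for `JsonDataSource`, and `check_random_access` is an illustrative helper.

```python
class InMemorySource:
    """Hypothetical stand-in exposing the same two methods as JsonDataSource."""

    def __init__(self, records):
        self.records = list(records)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        return dict(self.records[index])


def check_random_access(source, required_keys=("db_id", "query", "question")):
    """Check that a source is non-empty, indexable, and has the expected keys."""
    size = len(source)
    assert size > 0, "source is empty"
    # Probe the first and last index to catch off-by-one problems.
    for index in (0, size - 1):
        record = source[index]
        missing = [key for key in required_keys if key not in record]
        assert not missing, f"record {index} is missing {missing}"
    # Reading the same index twice should return the same record.
    assert source[0] == source[0], "source is not deterministic"
    return size


demo_source = InMemorySource([
    {"db_id": "concert_singer", "query": "SELECT count(*) FROM singer",
     "question": "How many singers do we have?"},
    {"db_id": "geo", "query": "SELECT city_name FROM city",
     "question": "list the cities"},
])
checked = check_random_access(demo_source)
print(f"source looks fine, {checked} records")
```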
To inspect one sample from the loader, I can use Python’s iter and next:
sample = next(iter(dev_loader))
sample
Here is what one loaded sample looks like with batch_size=1:
{
"db_id": array(["concert_singer"], dtype="<U14"),
"query": array(["SELECT count(*) FROM singer"], dtype="<U27"),
"question": array(["How many singers do we have?"], dtype="<U28"),
"db_definitions": array([
'''CREATE TABLE "stadium" (
"Stadium_ID" int,
"Location" text,
"Name" text,
"Capacity" int,
"Highest" int,
"Lowest" int,
"Average" int,
PRIMARY KEY ("Stadium_ID")
);
CREATE TABLE "singer" (
"Singer_ID" int,
"Name" text,
"Country" text,
"Song_Name" text,
"Song_release_year" text,
"Age" int,
"Is_male" bool,
PRIMARY KEY ("Singer_ID")
);
CREATE TABLE "concert" (
"concert_ID" int,
"concert_Name" text,
"Theme" text,
"Stadium_ID" text,
"Year" text,
PRIMARY KEY ("concert_ID"),
FOREIGN KEY ("Stadium_ID") REFERENCES "stadium"("Stadium_ID")
);
CREATE TABLE "singer_in_concert" (
"concert_ID" int,
"Singer_ID" text,
PRIMARY KEY ("concert_ID","Singer_ID"),
FOREIGN KEY ("concert_ID") REFERENCES "concert"("concert_ID"),
FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
);'''
], dtype="<U771"),
}
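Because of batching, every field comes back as a one-element array rather than a plain string. A minimal sketch for unwrapping such a batch is shown here; `unbatch_first` is a name I made up, and the demo uses Python lists as a stand-in for the arrays so that it runs without extra dependencies.

```python
def unbatch_first(batch):
    """Take element 0 of every field and convert it to a plain Python str."""
    return {key: str(values[0]) for key, values in batch.items()}


# Lists stand in for the one-element arrays shown above; indexing with [0]
# works the same way for this purpose.
batch = {
    "db_id": ["concert_singer"],
    "query": ["SELECT count(*) FROM singer"],
    "question": ["How many singers do we have?"],
}

record = unbatch_first(batch)
print(record["question"])
print(record["query"])
```

Plain strings are easier to work with when building prompts, which is the next step in this series.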