
해당 kaggle 대회는 텍스트 설명을 바탕으로 SVG(Scalable Vector Graphics) 코드를 생성하는 모델을 개발하는 과제이다. 우선 핵심 내용은 다음과 같다.
텍스트 설명을 입력으로 받아 그 설명에 맞는 이미지를 렌더링하는 SVG 코드를 생성하는 것이다. SVG는 XML 형식을 사용하여 2D 그래픽을 표현하는 벡터 이미지 포맷으로, 크기를 조정해도 품질 손실이 없다는 장점이 있다.
대회의 테스트 데이터는 약 500개의 일상적인 물체와 다양한 영역의 장면 설명으로 구성되어 있음. 이 데이터의 특성과 제공되는 파일은 다음과 같다.
train.csv :
kaggle_evaluation/test.csv :
Description 데이터의 특성
이 대회의 평가 방식은 SVG Image Fidelity Score 을 통해서 이루어진다. 이는 주어진 텍스트 설명과 제출된 SVG 코드 간의 일치도를 측정한다.
cairosvg Python 라이브러리를 사용하여 SVG를 PNG로 변환추가 제약 조건
- 설명이 모델에 전달된 후 5분 이내에 SVG 결과 반환 필요
- 모든 SVG는 9시간 이내에 생성되어야 함.
기간이 약 1달 정도 남았기 때문에 이미 대회를 진행하고 있는 사람들의 접근법을 빨리 따라가보도록 한다.
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
# Load with optimized scheduler and half precision
stable_diffusion_path = kagglehub.model_download("stabilityai/stable-diffusion-v2/pytorch/1/1")
scheduler = DDIMScheduler.from_pretrained(stable_diffusion_path, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
stable_diffusion_path,
scheduler=scheduler,
torch_dtype=torch.float16, # Use half precision
safety_checker=None # Disable safety checker for speed
)
# Move to GPU and apply optimizations
pipe.to(device)
- GPU 설정
- Diffusion model load
- DDIMScheduler (노이즈 제거 단계 제어)
def generate_bitmap(prompt, negative_prompt="", num_inference_steps=20, guidance_scale=15):
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
).images[0]
return image
생성된 이미지 평가
비트맵 이미지를 SVG 로 변환
참고한 노트북
https://www.kaggle.com/code/pepushi/drawing-with-llm-qwen2-5-32b-instruct-awq
class Model:
def __init__(self):
# self.model_path는 변경하지 않음(대회 요구사항)
self.model_path = kagglehub.model_download('qwen-lm/qwen2.5/Transformers/32b-instruct-awq/1')
# LLM 인스턴스 초기화
self.llm = vllm.LLM(
self.model_path,
quantization="awq", # AWQ 양자화 사용
tensor_parallel_size=2, # GPU 병렬 처리(오류 시 1로 변경)
gpu_memory_utilization=0.95, # GPU 메모리 사용률
trust_remote_code=True,
dtype="half", # 반정밀도(FP16) 사용
enforce_eager=True,
max_model_len=5120, # 최대 컨텍스트 길이
disable_log_stats=True
)
Qwen 2.5 32B AWQ 모델을 사용합니다. 이는 대형 언어 모델로, 양자화를 통해 메모리 효율성을 높였습니다.
vLLM 라이브러리를 사용하여 모델을 효율적으로 로드하고 추론합니다.
두 개의 GPU에 병렬로 모델을 분산하여 속도를 높입니다(tensor_parallel_size=2).
# 베이스 샘플링 파라미터(RL 등으로 최적화 가능성 있음)
self.sampling_params = vllm.SamplingParams(
n=1,
top_k=30, # 상위 30개 토큰만 고려
top_p=0.85, # 누적 확률 0.85까지의 토큰만 고려
temperature=0.75, # 다양성 조절(높을수록 다양)
repetition_penalty=1.05, # 반복 페널티
skip_special_tokens=False,
max_tokens=700, # SVG 코드 생성 최대 토큰 수
)
self.tokenizer = self.llm.get_tokenizer()
다양성과 품질 사이의 균형을 조절하는 파라미터를 설정합니다.
temperature=0.75는 적당한 수준의 창의성을 허용합니다.
max_tokens=700은 충분히 상세한 SVG를 생성할 수 있는 길이입니다.
# 프롬프트: 미적 요소를 강조
self.prompt_template = """Generate a beautiful, aesthetically pleasing SVG code that visually represents the following text description, while respecting these strict constraints:
<constraints>
* **Allowed Elements:** svg, path, circle, rect, ellipse, line, polyline, polygon, g, linearGradient, radialGradient, stop, defs
* **Allowed Attributes:** viewBox, width, height, fill, stroke, stroke-width, d, cx, cy, r, x, y, rx, ry, x1, y1, x2, y2, points, transform, opacity
</constraints>
<example>
<description>"A red circle with a blue square inside"</description>
svg
<svg viewBox="0 0 256 256" width="256" height="256">
<circle cx="50" cy="50" r="40" fill="red"/>
<rect x="30" y="30" width="40" height="40" fill="blue"/>
</svg>
</example>
Focus on visually appealing composition, color usage, and (optionally) subtle gradients or transforms.
Always provide **complete** SVG with nothing omitted, no ellipses.
<description>"{}"</description>
svg
<svg viewBox="0 0 256 256" width="256" height="256">
"""
상세한 프롬프트를 사용하여 모델에게 SVG 생성에 대한 정확한 지시를 제공합니다.
허용된 SVG 요소와 속성을 명시하여 대회 제약 조건을 준수하도록 합니다.
예시를 통해 모델이 출력 형식을 이해하도록 돕습니다.
미적 요소(구성, 색상 사용, 그라디언트, 변형)를 강조합니다.
# <svg> 추출
def _parse_svg(self, response: str) -> str:
matchs = re.findall(r"<svg.*?</svg>", response, re.DOTALL)
return matchs[-1].strip() if matchs else ""
# cairosvg 체크
def _check_svg_valid(self, svg_code: str) -> bool:
try:
cairosvg.svg2png(bytestring=svg_code.encode("utf-8"))
return True
except:
return False
정규 표현식을 사용하여 모델 응답에서 SVG 코드를 추출합니다.
cairosvg 라이브러리를 사용하여 SVG의 렌더링 가능 여부를 검사합니다.
이는 생성된 SVG가 실제로 유효한지 확인하는 중요한 단계입니다.
# 보상 함수: 미적 요소 + 컴팩트함 + 다양성 + gradient/transform 등
def _reward_function(self, svg: str) -> float:
if not self._check_svg_valid(svg):
return 0.0
base_score = 1.0
# 다양한 형태(태그)
shape_tags = ["circle", "rect", "ellipse", "line", "polygon", "polyline", "path"]
used_tags = set()
shape_count = 0
for tag in shape_tags:
c = len(re.findall(f"<{tag}", svg, re.IGNORECASE))
if c > 0:
used_tags.add(tag)
shape_count += c
diversity_score = len(used_tags) # 최대 ~7
shape_count_score = min(shape_count, 20) # 최대 20
# 색상 사용 => fill= / stroke= 패턴
color_matches = re.findall(r'(?:fill|stroke)\s*=\s*"[^"]+"', svg)
color_count = len(color_matches)
color_score = min(color_count / 4.0, 3.0) # 4색마다 +1, 최대 3
# gradient 또는 defs => 미적 점수 큰 증가
gradient_defs_score = 0.5 if any(g in svg for g in ["<linearGradient", "<radialGradient", "<defs"]) else 0.0
# transform => +0.3
transform_score = 0.3 if re.search(r'transform\s*=', svg, re.IGNORECASE) else 0.0
# stroke+fill => +0.3
stroke_fill_score = 0.3 if ('stroke=' in svg and 'fill=' in svg) else 0.0
# 크기 패널티: 3000~4000까지 허용, 그 이상은 급격히 감점
size = len(svg.encode("utf-8"))
if size>4000:
size_score = 0.3
elif size>2000:
size_score = 0.6
else:
size_score = 1.0
total = (
base_score +
0.5*diversity_score +
0.4*shape_count_score +
color_score +
gradient_defs_score +
transform_score +
stroke_fill_score +
size_score
)
return round(total, 4)
생성된 SVG의 품질을 정량적으로 평가하는 복합 점수 시스템입니다.
다양한 형태(태그)의 사용, 색상 다양성, 그라디언트/변형 사용 등을 고려합니다.
파일 크기에 대한 패널티를 부여하여 10KB 제한 내에서 최적화합니다.
이 점수를 기반으로 여러 생성 결과 중 최상의 SVG를 선택합니다.
# SVG 제약 조건 정제
def enforce_constraints(self, svg_string: str) -> str:
logging.info("Sanitizing SVG...")
try:
parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
root = etree.fromstring(svg_string, parser=parser)
except etree.ParseError as e:
logging.error("SVG ParseError => default. %s", e)
return self.default_svg
to_remove = []
for el in root.iter():
tag_name = etree.QName(el.tag).localname
# 불허용 태그
if tag_name not in self.constraints.allowed_elements:
to_remove.append(el)
continue
# 불허용 속성
rm_attrs = []
for attr in el.attrib:
aname = etree.QName(attr).localname
if aname not in self.constraints.allowed_elements[tag_name] \
and aname not in self.constraints.allowed_elements["common"]:
rm_attrs.append(attr)
for a in rm_attrs:
del el.attrib[a]
# href 유효성 검사
for k, v in list(el.attrib.items()):
if etree.QName(k).localname=="href" and not v.startswith("#"):
del el.attrib[k]
# path 체크
if tag_name=="path":
d_val = el.get("d","")
if not d_val:
to_remove.append(el)
continue
path_regex = re.compile(
r'^(?:[MmZzLlHhVvCcSsQqTtAa]\s*'
r'(?:-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?'
r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*'
r')?\s*)+$'
)
if not path_regex.match(d_val.strip()):
to_remove.append(el)
for r in to_remove:
if r.getparent() is not None:
r.getparent().remove(r)
try:
cleaned = etree.tostring(root, encoding="unicode")
return cleaned
except ValueError as e:
logging.error("Sanitize error => default. %s", e)
return self.default_svg
XML 파서를 사용하여 SVG를 구문 분석합니다.
대회 규칙에 따라 허용되지 않은 태그와 속성을 제거합니다.
외부 리소스에 대한 참조를 방지합니다.
path 요소의 문법적 유효성을 검사합니다.
정제 과정에서 오류가 발생하면 기본 SVG를 반환합니다.
# refine: 미적 요소 추가 강화
def _refine_svg(self, svg_code: str) -> str:
refine_prompt = (
"Here is an SVG code:\n"
f"{svg_code}\n\n"
"Please refine or beautify it further (e.g., add subtle gradient, transform, interesting shapes) "
"while preserving the same main composition. Return only the complete <svg> code, no comments or ellipses."
)
msgs = [{"role": "user", "content": refine_prompt}]
refine_txt = self.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
resp = self.llm.generate([refine_txt], self.sampling_params, use_tqdm=False)
if not resp:
return svg_code
new_svg = self._parse_svg(resp[0].outputs[0].text)
if not new_svg:
return svg_code
new_svg = self.enforce_constraints(new_svg)
if self._check_svg_valid(new_svg):
old_sc = self._reward_function(svg_code)
new_sc = self._reward_function(new_svg)
if new_sc>old_sc:
logging.info(f"Refine success: {old_sc} => {new_sc}")
return new_svg
return svg_code
생성된 최상의 SVG를 추가로 개선하는 단계입니다.
모델에게 미적 요소(그라디언트, 변형, 흥미로운 형태)를 추가하도록 요청합니다.
개선된 버전이 원본보다 높은 보상 점수를 받는 경우에만 채택합니다.
이 단계는 SVG의 예술적 품질을 높이는 데 중요합니다.
# RL: 파라미터 탐색
def train_with_rl(self, sample_prompts: List[str], search_size: int=5):
best_params = None
best_score = -9999.0
cands = []
for _ in range(search_size):
temp = random.uniform(0.6,1.0)
top_p = random.uniform(0.75,0.99)
top_k = random.randint(20,70)
cands.append((temp, top_p, top_k))
for (temp, t_p, t_k) in cands:
logging.info(f"Trying param => t={temp}, p={t_p}, k={t_k}")
sp = vllm.SamplingParams(
n=1,
top_k=t_k,
top_p=t_p,
temperature=temp,
repetition_penalty=1.05,
skip_special_tokens=False,
max_tokens=700
)
ssum = 0.0
cc = 0
for pr in sample_prompts:
txt = self._apply_template(pr)
outp = self.llm.generate([txt]*2, sp, use_tqdm=False)
for r in outp:
svg_raw = self._parse_svg(r.outputs[0].text)
cleaned = self.enforce_constraints(svg_raw)
ssum += self._reward_function(cleaned)
cc += 1
avg = ssum/cc if cc>0 else 0.0
logging.info(f"Avg reward => {avg}")
if avg>best_score:
best_score = avg
best_params = (temp, t_p, t_k)
if best_params is not None:
logging.info(f"Best found => {best_params}, score={best_score}")
(temp, p, k) = best_params
self.sampling_params = vllm.SamplingParams(
n=1,
top_k=k,
top_p=p,
temperature=temp,
repetition_penalty=1.05,
skip_special_tokens=False,
max_tokens=700
)
샘플링 파라미터(temperature, top_p, top_k)를 최적화하는 간단한 RL 접근법입니다.
다양한 파라미터 조합을 시도하고 평균 보상 점수를 측정합니다.
가장 높은 평균 점수를 얻은 파라미터 세트를 선택합니다.
이는 모델이 더 높은 품질의 SVG를 생성하도록 조정하는 메타 최적화입니다.
# predict (실제 SVG 생성)
def predict(self, description: str) -> str:
def generate_svg():
try:
prompt_txt = self._apply_template(description)
# 동일 프롬프트로 여러 번 생성
responses = self.llm.generate([prompt_txt]*num_attempt,
self.sampling_params,
use_tqdm=False)
cands = []
for r in responses:
raw_svg = self._parse_svg(r.outputs[0].text)
if not raw_svg:
continue
cleaned = self.enforce_constraints(raw_svg)
if self._check_svg_valid(cleaned):
cands.append(cleaned)
if not cands:
logging.warning("No valid => default return")
return self.default_svg
# 최상 선택
best_svg = None
best_sc = -9999.0
for c in cands:
sc = self._reward_function(c)
if sc>best_sc:
best_sc = sc
best_svg = c
# refine
refined = self._refine_svg(best_svg)
ref_sc = self._reward_function(refined)
if ref_sc>best_sc:
return refined
else:
return best_svg
except Exception as e:
logging.error("Exception in generate_svg: %s", e)
return self.default_svg
return generate_svg()
텍스트 설명을 SVG 코드로 변환합니다.
num_attempt(10)번 동일한 프롬프트로 여러 SVG를 생성합니다.
각 생성물에 대해 제약 조건을 적용하고 유효성을 검사합니다.
보상 함수를 사용하여 가장 높은 점수의 SVG를 선택합니다.
선택된 SVG를 개선하는 시도를 하고, 개선된 버전이 더 나은 경우 사용합니다.
오류 발생 시 기본 SVG를 반환하여 안정성을 확보합니다.
다른 노트북들을 살펴 봤을 때도 전체적인 흐름이 이 2가지로 정리되는 듯하다.
Text → Image → SVG
( Stable Diffusion 활용 )
LLM 활용